diff --git a/.github/workflows/autofix.yml b/.github/workflows/autofix.yml index 82ba95444f..e1b7c8f4e8 100644 --- a/.github/workflows/autofix.yml +++ b/.github/workflows/autofix.yml @@ -20,14 +20,40 @@ jobs: cd api uv sync --dev # Fix lint errors - uv run ruff check --fix-only . + uv run ruff check --fix . # Format code uv run ruff format . + - name: ast-grep run: | uvx --from ast-grep-cli sg --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all uvx --from ast-grep-cli sg --pattern 'session.query($WHATEVER).filter($HERE)' --rewrite 'session.query($WHATEVER).where($HERE)' -l py --update-all + + - name: mdformat run: | uvx mdformat . + + - name: Install pnpm + uses: pnpm/action-setup@v4 + with: + package_json_file: web/package.json + run_install: false + + - name: Setup NodeJS + uses: actions/setup-node@v4 + with: + node-version: 22 + cache: pnpm + cache-dependency-path: ./web/package.json + + - name: Web dependencies + working-directory: ./web + run: pnpm install --frozen-lockfile + + - name: oxlint + working-directory: ./web + run: | + pnpx oxlint --fix + - uses: autofix-ci/action@635ffb0c9798bd160680f18fd73371e355b85f27 diff --git a/.github/workflows/deploy-enterprise.yml b/.github/workflows/deploy-enterprise.yml index 98fa7c3b49..9cff3a3482 100644 --- a/.github/workflows/deploy-enterprise.yml +++ b/.github/workflows/deploy-enterprise.yml @@ -19,11 +19,23 @@ jobs: github.event.workflow_run.head_branch == 'deploy/enterprise' steps: - - name: Deploy to server - uses: appleboy/ssh-action@v0.1.8 - with: - host: ${{ secrets.ENTERPRISE_SSH_HOST }} - username: ${{ secrets.ENTERPRISE_SSH_USER }} - password: ${{ secrets.ENTERPRISE_SSH_PASSWORD }} - script: | - ${{ vars.ENTERPRISE_SSH_SCRIPT || secrets.ENTERPRISE_SSH_SCRIPT }} + - name: trigger deployments + env: + DEV_ENV_ADDRS: ${{ vars.DEV_ENV_ADDRS }} + DEPLOY_SECRET: ${{ secrets.DEPLOY_SECRET }} + run: | + IFS=',' read -ra ENDPOINTS <<< "${DEV_ENV_ADDRS:-}" + BODY='{"project":"dify-api","tag":"deploy-enterprise"}' + + for ENDPOINT in "${ENDPOINTS[@]}"; do + ENDPOINT="$(echo "$ENDPOINT" | xargs)" + [ -z "$ENDPOINT" ] && continue + + API_SIGNATURE=$(printf '%s' "$BODY" | openssl dgst -sha256 -hmac "$DEPLOY_SECRET" | awk '{print "sha256="$2}') + + curl -sSf -X POST \ + -H "Content-Type: application/json" \ + -H "X-Hub-Signature-256: $API_SIGNATURE" \ + -d "$BODY" \ + "$ENDPOINT" + done diff --git a/.gitignore b/.gitignore index bc354e639e..cbb7b4dac0 100644 --- a/.gitignore +++ b/.gitignore @@ -227,3 +227,7 @@ web/public/fallback-*.js .roo/ api/.env.backup /clickzetta + +# Benchmark +scripts/stress-test/setup/config/ +scripts/stress-test/reports/ \ No newline at end of file diff --git a/Makefile b/Makefile index d82f6f24ad..ec7df3e03d 100644 --- a/Makefile +++ b/Makefile @@ -4,10 +4,13 @@ WEB_IMAGE=$(DOCKER_REGISTRY)/dify-web API_IMAGE=$(DOCKER_REGISTRY)/dify-api VERSION=latest +# Default target - show help +.DEFAULT_GOAL := help + # Backend Development Environment Setup .PHONY: dev-setup prepare-docker prepare-web prepare-api -# Default dev setup target +# Dev setup target dev-setup: prepare-docker prepare-web prepare-api @echo "✅ Backend development environment setup complete!" @@ -46,6 +49,27 @@ dev-clean: @rm -rf api/storage @echo "✅ Cleanup complete" +# Backend Code Quality Commands +format: + @echo "🎨 Running ruff format..." + @uv run --project api --dev ruff format ./api + @echo "✅ Code formatting complete" + +check: + @echo "🔍 Running ruff check..." + @uv run --project api --dev ruff check ./api + @echo "✅ Code check complete" + +lint: + @echo "🔧 Running ruff format and check with fixes..." + @uv run --directory api --dev sh -c 'ruff format ./api && ruff check --fix ./api' + @echo "✅ Linting complete" + +type-check: + @echo "📝 Running type check with basedpyright..." + @uv run --directory api --dev basedpyright + @echo "✅ Type check complete" + # Build Docker images build-web: @echo "Building web Docker image: $(WEB_IMAGE):$(VERSION)..." @@ -90,6 +114,12 @@ help: @echo " make prepare-api - Set up API environment" @echo " make dev-clean - Stop Docker middleware containers" @echo "" + @echo "Backend Code Quality:" + @echo " make format - Format code with ruff" + @echo " make check - Check code with ruff" + @echo " make lint - Format and fix code with ruff" + @echo " make type-check - Run type checking with basedpyright" + @echo "" @echo "Docker Build Targets:" @echo " make build-web - Build web Docker image" @echo " make build-api - Build API Docker image" @@ -98,4 +128,4 @@ help: @echo " make build-push-all - Build and push all Docker images" # Phony targets -.PHONY: build-web build-api push-web push-api build-all push-all build-push-all dev-setup prepare-docker prepare-web prepare-api dev-clean help +.PHONY: build-web build-api push-web push-api build-all push-all build-push-all dev-setup prepare-docker prepare-web prepare-api dev-clean help format check lint type-check diff --git a/api/.env.example b/api/.env.example index 8d783af134..2986402e9e 100644 --- a/api/.env.example +++ b/api/.env.example @@ -530,6 +530,7 @@ ENDPOINT_URL_TEMPLATE=http://localhost:5002/e/{hook_id} # Reset password token expiry minutes RESET_PASSWORD_TOKEN_EXPIRY_MINUTES=5 +EMAIL_REGISTER_TOKEN_EXPIRY_MINUTES=5 CHANGE_EMAIL_TOKEN_EXPIRY_MINUTES=5 OWNER_TRANSFER_TOKEN_EXPIRY_MINUTES=5 diff --git a/api/.ruff.toml b/api/.ruff.toml index 9668dc9f76..9a15754d9a 100644 --- a/api/.ruff.toml +++ b/api/.ruff.toml @@ -45,6 +45,7 @@ select = [ "G001", # don't use str format to logging messages "G003", # don't use + in logging messages "G004", # don't use f-strings to format logging messages + "UP042", # use StrEnum ] ignore = [ diff --git a/api/commands.py b/api/commands.py index 9b13cc2e1a..d46750316b 100644 --- a/api/commands.py +++ b/api/commands.py @@ -212,7 +212,9 @@ def migrate_annotation_vector_database(): if not dataset_collection_binding: click.echo(f"App annotation collection binding not found: {app.id}") continue - annotations = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app.id).all() + annotations = db.session.scalars( + select(MessageAnnotation).where(MessageAnnotation.app_id == app.id) + ).all() dataset = Dataset( id=app.id, tenant_id=app.tenant_id, @@ -367,29 +369,25 @@ def migrate_knowledge_vector_database(): ) raise e - dataset_documents = ( - db.session.query(DatasetDocument) - .where( + dataset_documents = db.session.scalars( + select(DatasetDocument).where( DatasetDocument.dataset_id == dataset.id, DatasetDocument.indexing_status == "completed", DatasetDocument.enabled == True, DatasetDocument.archived == False, ) - .all() - ) + ).all() documents = [] segments_count = 0 for dataset_document in dataset_documents: - segments = ( - db.session.query(DocumentSegment) - .where( + segments = db.session.scalars( + select(DocumentSegment).where( DocumentSegment.document_id == dataset_document.id, DocumentSegment.status == "completed", DocumentSegment.enabled == True, ) - .all() - ) + ).all() for segment in segments: document = Document( @@ -479,12 +477,12 @@ def convert_to_agent_apps(): click.echo(f"Converting app: {app.id}") try: - app.mode = AppMode.AGENT_CHAT.value + app.mode = AppMode.AGENT_CHAT db.session.commit() # update conversation mode to agent db.session.query(Conversation).where(Conversation.app_id == app.id).update( - {Conversation.mode: AppMode.AGENT_CHAT.value} + {Conversation.mode: AppMode.AGENT_CHAT} ) db.session.commit() @@ -511,7 +509,7 @@ def add_qdrant_index(field: str): from qdrant_client.http.exceptions import UnexpectedResponse from qdrant_client.http.models import PayloadSchemaType - from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantConfig + from core.rag.datasource.vdb.qdrant.qdrant_vector import PathQdrantParams, QdrantConfig for binding in bindings: if dify_config.QDRANT_URL is None: @@ -525,7 +523,21 @@ def add_qdrant_index(field: str): prefer_grpc=dify_config.QDRANT_GRPC_ENABLED, ) try: - client = qdrant_client.QdrantClient(**qdrant_config.to_qdrant_params()) + params = qdrant_config.to_qdrant_params() + # Check the type before using + if isinstance(params, PathQdrantParams): + # PathQdrantParams case + client = qdrant_client.QdrantClient(path=params.path) + else: + # UrlQdrantParams case - params is UrlQdrantParams + client = qdrant_client.QdrantClient( + url=params.url, + api_key=params.api_key, + timeout=int(params.timeout), + verify=params.verify, + grpc_port=params.grpc_port, + prefer_grpc=params.prefer_grpc, + ) # create payload index client.create_payload_index(binding.collection_name, field, field_schema=PayloadSchemaType.KEYWORD) create_count += 1 diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index 899fecea7c..0d6f4e416e 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -31,6 +31,12 @@ class SecurityConfig(BaseSettings): description="Duration in minutes for which a password reset token remains valid", default=5, ) + + EMAIL_REGISTER_TOKEN_EXPIRY_MINUTES: PositiveInt = Field( + description="Duration in minutes for which a email register token remains valid", + default=5, + ) + CHANGE_EMAIL_TOKEN_EXPIRY_MINUTES: PositiveInt = Field( description="Duration in minutes for which a change email token remains valid", default=5, @@ -639,6 +645,11 @@ class AuthConfig(BaseSettings): default=86400, ) + EMAIL_REGISTER_LOCKOUT_DURATION: PositiveInt = Field( + description="Time (in seconds) a user must wait before retrying email register after exceeding the rate limit.", + default=86400, + ) + class ModerationConfig(BaseSettings): """ diff --git a/api/configs/middleware/vdb/opensearch_config.py b/api/configs/middleware/vdb/opensearch_config.py index 9fd9b60194..9700447a4c 100644 --- a/api/configs/middleware/vdb/opensearch_config.py +++ b/api/configs/middleware/vdb/opensearch_config.py @@ -1,4 +1,4 @@ -import enum +from enum import Enum from typing import Literal, Optional from pydantic import Field, PositiveInt @@ -10,7 +10,7 @@ class OpenSearchConfig(BaseSettings): Configuration settings for OpenSearch """ - class AuthMethod(enum.StrEnum): + class AuthMethod(Enum): """ Authentication method for OpenSearch """ diff --git a/api/constants/__init__.py b/api/constants/__init__.py index c98f4d55c8..fe8f4f8785 100644 --- a/api/constants/__init__.py +++ b/api/constants/__init__.py @@ -16,14 +16,14 @@ AUDIO_EXTENSIONS = ["mp3", "m4a", "wav", "amr", "mpga"] AUDIO_EXTENSIONS.extend([ext.upper() for ext in AUDIO_EXTENSIONS]) +_doc_extensions: list[str] if dify_config.ETL_TYPE == "Unstructured": - DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "mdx", "pdf", "html", "htm", "xlsx", "xls", "vtt", "properties"] - DOCUMENT_EXTENSIONS.extend(("doc", "docx", "csv", "eml", "msg", "pptx", "xml", "epub")) + _doc_extensions = ["txt", "markdown", "md", "mdx", "pdf", "html", "htm", "xlsx", "xls", "vtt", "properties"] + _doc_extensions.extend(("doc", "docx", "csv", "eml", "msg", "pptx", "xml", "epub")) if dify_config.UNSTRUCTURED_API_URL: - DOCUMENT_EXTENSIONS.append("ppt") - DOCUMENT_EXTENSIONS.extend([ext.upper() for ext in DOCUMENT_EXTENSIONS]) + _doc_extensions.append("ppt") else: - DOCUMENT_EXTENSIONS = [ + _doc_extensions = [ "txt", "markdown", "md", @@ -38,4 +38,4 @@ else: "vtt", "properties", ] - DOCUMENT_EXTENSIONS.extend([ext.upper() for ext in DOCUMENT_EXTENSIONS]) +DOCUMENT_EXTENSIONS = _doc_extensions + [ext.upper() for ext in _doc_extensions] diff --git a/api/constants/model_template.py b/api/constants/model_template.py index c26d8c0186..cacf6b6874 100644 --- a/api/constants/model_template.py +++ b/api/constants/model_template.py @@ -7,7 +7,7 @@ default_app_templates: Mapping[AppMode, Mapping] = { # workflow default mode AppMode.WORKFLOW: { "app": { - "mode": AppMode.WORKFLOW.value, + "mode": AppMode.WORKFLOW, "enable_site": True, "enable_api": True, } @@ -15,7 +15,7 @@ default_app_templates: Mapping[AppMode, Mapping] = { # completion default mode AppMode.COMPLETION: { "app": { - "mode": AppMode.COMPLETION.value, + "mode": AppMode.COMPLETION, "enable_site": True, "enable_api": True, }, @@ -44,7 +44,7 @@ default_app_templates: Mapping[AppMode, Mapping] = { # chat default mode AppMode.CHAT: { "app": { - "mode": AppMode.CHAT.value, + "mode": AppMode.CHAT, "enable_site": True, "enable_api": True, }, @@ -60,7 +60,7 @@ default_app_templates: Mapping[AppMode, Mapping] = { # advanced-chat default mode AppMode.ADVANCED_CHAT: { "app": { - "mode": AppMode.ADVANCED_CHAT.value, + "mode": AppMode.ADVANCED_CHAT, "enable_site": True, "enable_api": True, }, @@ -68,7 +68,7 @@ default_app_templates: Mapping[AppMode, Mapping] = { # agent-chat default mode AppMode.AGENT_CHAT: { "app": { - "mode": AppMode.AGENT_CHAT.value, + "mode": AppMode.AGENT_CHAT, "enable_site": True, "enable_api": True, }, diff --git a/api/contexts/__init__.py b/api/contexts/__init__.py index ae41a2c03a..a07e6a08a6 100644 --- a/api/contexts/__init__.py +++ b/api/contexts/__init__.py @@ -8,7 +8,6 @@ if TYPE_CHECKING: from core.model_runtime.entities.model_entities import AIModelEntity from core.plugin.entities.plugin_daemon import PluginModelProviderEntity from core.tools.plugin_tool.provider import PluginToolProviderController - from core.workflow.entities.variable_pool import VariablePool """ diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index 5ad7645969..e13edf6a37 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -1,4 +1,5 @@ from flask import Blueprint +from flask_restx import Namespace from libs.external_api import ExternalApi @@ -26,7 +27,16 @@ from .files import FileApi, FilePreviewApi, FileSupportTypeApi from .remote_files import RemoteFileInfoApi, RemoteFileUploadApi bp = Blueprint("console", __name__, url_prefix="/console/api") -api = ExternalApi(bp) + +api = ExternalApi( + bp, + version="1.0", + title="Console API", + description="Console management APIs for app configuration, monitoring, and administration", +) + +# Create namespace +console_ns = Namespace("console", description="Console management API operations", path="/") # File api.add_resource(FileApi, "/files/upload") @@ -43,7 +53,16 @@ api.add_resource(AppImportConfirmApi, "/apps/imports//confirm" api.add_resource(AppImportCheckDependenciesApi, "/apps/imports//check-dependencies") # Import other controllers -from . import admin, apikey, extension, feature, ping, setup, version +from . import ( + admin, + apikey, + extension, + feature, + init_validate, + ping, + setup, + version, +) # Import app controllers from .app import ( @@ -70,7 +89,16 @@ from .app import ( ) # Import auth controllers -from .auth import activate, data_source_bearer_auth, data_source_oauth, forgot_password, login, oauth, oauth_server +from .auth import ( + activate, + data_source_bearer_auth, + data_source_oauth, + email_register, + forgot_password, + login, + oauth, + oauth_server, +) # Import billing controllers from .billing import billing, compliance @@ -95,6 +123,23 @@ from .explore import ( saved_message, ) +# Import tag controllers +from .tag import tags + +# Import workspace controllers +from .workspace import ( + account, + agent_providers, + endpoint, + load_balancing_config, + members, + model_providers, + models, + plugin, + tool_providers, + workspace, +) + # Explore Audio api.add_resource(ChatAudioApi, "/installed-apps//audio-to-text", endpoint="installed_app_audio") api.add_resource(ChatTextApi, "/installed-apps//text-to-audio", endpoint="installed_app_text") @@ -166,19 +211,71 @@ api.add_resource( InstalledAppWorkflowTaskStopApi, "/installed-apps//workflows/tasks//stop" ) -# Import tag controllers -from .tag import tags +api.add_namespace(console_ns) -# Import workspace controllers -from .workspace import ( - account, - agent_providers, - endpoint, - load_balancing_config, - members, - model_providers, - models, - plugin, - tool_providers, - workspace, -) +__all__ = [ + "account", + "activate", + "admin", + "advanced_prompt_template", + "agent", + "agent_providers", + "annotation", + "api", + "apikey", + "app", + "audio", + "billing", + "bp", + "completion", + "compliance", + "console_ns", + "conversation", + "conversation_variables", + "data_source", + "data_source_bearer_auth", + "data_source_oauth", + "datasets", + "datasets_document", + "datasets_segments", + "email_register", + "endpoint", + "extension", + "external", + "feature", + "forgot_password", + "generator", + "hit_testing", + "init_validate", + "installed_app", + "load_balancing_config", + "login", + "mcp_server", + "members", + "message", + "metadata", + "model_config", + "model_providers", + "models", + "oauth", + "oauth_server", + "ops_trace", + "parameter", + "ping", + "plugin", + "recommended_app", + "saved_message", + "setup", + "site", + "statistic", + "tags", + "tool_providers", + "version", + "website", + "workflow", + "workflow_app_log", + "workflow_draft_variable", + "workflow_run", + "workflow_statistic", + "workspace", +] diff --git a/api/controllers/console/admin.py b/api/controllers/console/admin.py index 1306efacf4..93f242ad28 100644 --- a/api/controllers/console/admin.py +++ b/api/controllers/console/admin.py @@ -3,7 +3,7 @@ from functools import wraps from typing import ParamSpec, TypeVar from flask import request -from flask_restx import Resource, reqparse +from flask_restx import Resource, fields, reqparse from sqlalchemy import select from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound, Unauthorized @@ -12,7 +12,7 @@ P = ParamSpec("P") R = TypeVar("R") from configs import dify_config from constants.languages import supported_language -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.wraps import only_edition_cloud from extensions.ext_database import db from models.model import App, InstalledApp, RecommendedApp @@ -45,7 +45,28 @@ def admin_required(view: Callable[P, R]): return decorated +@console_ns.route("/admin/insert-explore-apps") class InsertExploreAppListApi(Resource): + @api.doc("insert_explore_app") + @api.doc(description="Insert or update an app in the explore list") + @api.expect( + api.model( + "InsertExploreAppRequest", + { + "app_id": fields.String(required=True, description="Application ID"), + "desc": fields.String(description="App description"), + "copyright": fields.String(description="Copyright information"), + "privacy_policy": fields.String(description="Privacy policy"), + "custom_disclaimer": fields.String(description="Custom disclaimer"), + "language": fields.String(required=True, description="Language code"), + "category": fields.String(required=True, description="App category"), + "position": fields.Integer(required=True, description="Display position"), + }, + ) + ) + @api.response(200, "App updated successfully") + @api.response(201, "App inserted successfully") + @api.response(404, "App not found") @only_edition_cloud @admin_required def post(self): @@ -115,7 +136,12 @@ class InsertExploreAppListApi(Resource): return {"result": "success"}, 200 +@console_ns.route("/admin/insert-explore-apps/") class InsertExploreAppApi(Resource): + @api.doc("delete_explore_app") + @api.doc(description="Remove an app from the explore list") + @api.doc(params={"app_id": "Application ID to remove"}) + @api.response(204, "App removed successfully") @only_edition_cloud @admin_required def delete(self, app_id): @@ -152,7 +178,3 @@ class InsertExploreAppApi(Resource): db.session.commit() return {"result": "success"}, 204 - - -api.add_resource(InsertExploreAppListApi, "/admin/insert-explore-apps") -api.add_resource(InsertExploreAppApi, "/admin/insert-explore-apps/") diff --git a/api/controllers/console/apikey.py b/api/controllers/console/apikey.py index cfd5f73ade..56c61c2886 100644 --- a/api/controllers/console/apikey.py +++ b/api/controllers/console/apikey.py @@ -1,8 +1,9 @@ -from typing import Any, Optional +from typing import Optional import flask_restx from flask_login import current_user from flask_restx import Resource, fields, marshal_with +from flask_restx._http import HTTPStatus from sqlalchemy import select from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden @@ -13,7 +14,7 @@ from libs.login import login_required from models.dataset import Dataset from models.model import ApiToken, App -from . import api +from . import api, console_ns from .wraps import account_initialization_required, setup_required api_key_fields = { @@ -40,7 +41,7 @@ def _get_resource(resource_id, tenant_id, resource_model): ).scalar_one_or_none() if resource is None: - flask_restx.abort(404, message=f"{resource_model.__name__} not found.") + flask_restx.abort(HTTPStatus.NOT_FOUND, message=f"{resource_model.__name__} not found.") return resource @@ -49,7 +50,7 @@ class BaseApiKeyListResource(Resource): method_decorators = [account_initialization_required, login_required, setup_required] resource_type: str | None = None - resource_model: Optional[Any] = None + resource_model: Optional[type] = None resource_id_field: str | None = None token_prefix: str | None = None max_keys = 10 @@ -59,11 +60,11 @@ class BaseApiKeyListResource(Resource): assert self.resource_id_field is not None, "resource_id_field must be set" resource_id = str(resource_id) _get_resource(resource_id, current_user.current_tenant_id, self.resource_model) - keys = ( - db.session.query(ApiToken) - .where(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id) - .all() - ) + keys = db.session.scalars( + select(ApiToken).where( + ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id + ) + ).all() return {"items": keys} @marshal_with(api_key_fields) @@ -82,7 +83,7 @@ class BaseApiKeyListResource(Resource): if current_key_count >= self.max_keys: flask_restx.abort( - 400, + HTTPStatus.BAD_REQUEST, message=f"Cannot create more than {self.max_keys} API keys for this resource type.", custom="max_keys_exceeded", ) @@ -102,7 +103,7 @@ class BaseApiKeyResource(Resource): method_decorators = [account_initialization_required, login_required, setup_required] resource_type: str | None = None - resource_model: Optional[Any] = None + resource_model: Optional[type] = None resource_id_field: str | None = None def delete(self, resource_id, api_key_id): @@ -126,7 +127,7 @@ class BaseApiKeyResource(Resource): ) if key is None: - flask_restx.abort(404, message="API key not found") + flask_restx.abort(HTTPStatus.NOT_FOUND, message="API key not found") db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete() db.session.commit() @@ -134,7 +135,25 @@ class BaseApiKeyResource(Resource): return {"result": "success"}, 204 +@console_ns.route("/apps//api-keys") class AppApiKeyListResource(BaseApiKeyListResource): + @api.doc("get_app_api_keys") + @api.doc(description="Get all API keys for an app") + @api.doc(params={"resource_id": "App ID"}) + @api.response(200, "Success", api_key_list) + def get(self, resource_id): + """Get all API keys for an app""" + return super().get(resource_id) + + @api.doc("create_app_api_key") + @api.doc(description="Create a new API key for an app") + @api.doc(params={"resource_id": "App ID"}) + @api.response(201, "API key created successfully", api_key_fields) + @api.response(400, "Maximum keys exceeded") + def post(self, resource_id): + """Create a new API key for an app""" + return super().post(resource_id) + def after_request(self, resp): resp.headers["Access-Control-Allow-Origin"] = "*" resp.headers["Access-Control-Allow-Credentials"] = "true" @@ -146,7 +165,16 @@ class AppApiKeyListResource(BaseApiKeyListResource): token_prefix = "app-" +@console_ns.route("/apps//api-keys/") class AppApiKeyResource(BaseApiKeyResource): + @api.doc("delete_app_api_key") + @api.doc(description="Delete an API key for an app") + @api.doc(params={"resource_id": "App ID", "api_key_id": "API key ID"}) + @api.response(204, "API key deleted successfully") + def delete(self, resource_id, api_key_id): + """Delete an API key for an app""" + return super().delete(resource_id, api_key_id) + def after_request(self, resp): resp.headers["Access-Control-Allow-Origin"] = "*" resp.headers["Access-Control-Allow-Credentials"] = "true" @@ -157,7 +185,25 @@ class AppApiKeyResource(BaseApiKeyResource): resource_id_field = "app_id" +@console_ns.route("/datasets//api-keys") class DatasetApiKeyListResource(BaseApiKeyListResource): + @api.doc("get_dataset_api_keys") + @api.doc(description="Get all API keys for a dataset") + @api.doc(params={"resource_id": "Dataset ID"}) + @api.response(200, "Success", api_key_list) + def get(self, resource_id): + """Get all API keys for a dataset""" + return super().get(resource_id) + + @api.doc("create_dataset_api_key") + @api.doc(description="Create a new API key for a dataset") + @api.doc(params={"resource_id": "Dataset ID"}) + @api.response(201, "API key created successfully", api_key_fields) + @api.response(400, "Maximum keys exceeded") + def post(self, resource_id): + """Create a new API key for a dataset""" + return super().post(resource_id) + def after_request(self, resp): resp.headers["Access-Control-Allow-Origin"] = "*" resp.headers["Access-Control-Allow-Credentials"] = "true" @@ -169,7 +215,16 @@ class DatasetApiKeyListResource(BaseApiKeyListResource): token_prefix = "ds-" +@console_ns.route("/datasets//api-keys/") class DatasetApiKeyResource(BaseApiKeyResource): + @api.doc("delete_dataset_api_key") + @api.doc(description="Delete an API key for a dataset") + @api.doc(params={"resource_id": "Dataset ID", "api_key_id": "API key ID"}) + @api.response(204, "API key deleted successfully") + def delete(self, resource_id, api_key_id): + """Delete an API key for a dataset""" + return super().delete(resource_id, api_key_id) + def after_request(self, resp): resp.headers["Access-Control-Allow-Origin"] = "*" resp.headers["Access-Control-Allow-Credentials"] = "true" @@ -178,9 +233,3 @@ class DatasetApiKeyResource(BaseApiKeyResource): resource_type = "dataset" resource_model = Dataset resource_id_field = "dataset_id" - - -api.add_resource(AppApiKeyListResource, "/apps//api-keys") -api.add_resource(AppApiKeyResource, "/apps//api-keys/") -api.add_resource(DatasetApiKeyListResource, "/datasets//api-keys") -api.add_resource(DatasetApiKeyResource, "/datasets//api-keys/") diff --git a/api/controllers/console/app/advanced_prompt_template.py b/api/controllers/console/app/advanced_prompt_template.py index c6cb6f6e3a..315825db79 100644 --- a/api/controllers/console/app/advanced_prompt_template.py +++ b/api/controllers/console/app/advanced_prompt_template.py @@ -1,12 +1,26 @@ -from flask_restx import Resource, reqparse +from flask_restx import Resource, fields, reqparse -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.wraps import account_initialization_required, setup_required from libs.login import login_required from services.advanced_prompt_template_service import AdvancedPromptTemplateService +@console_ns.route("/app/prompt-templates") class AdvancedPromptTemplateList(Resource): + @api.doc("get_advanced_prompt_templates") + @api.doc(description="Get advanced prompt templates based on app mode and model configuration") + @api.expect( + api.parser() + .add_argument("app_mode", type=str, required=True, location="args", help="Application mode") + .add_argument("model_mode", type=str, required=True, location="args", help="Model mode") + .add_argument("has_context", type=str, default="true", location="args", help="Whether has context") + .add_argument("model_name", type=str, required=True, location="args", help="Model name") + ) + @api.response( + 200, "Prompt templates retrieved successfully", fields.List(fields.Raw(description="Prompt template data")) + ) + @api.response(400, "Invalid request parameters") @setup_required @login_required @account_initialization_required @@ -19,6 +33,3 @@ class AdvancedPromptTemplateList(Resource): args = parser.parse_args() return AdvancedPromptTemplateService.get_prompt(args) - - -api.add_resource(AdvancedPromptTemplateList, "/app/prompt-templates") diff --git a/api/controllers/console/app/agent.py b/api/controllers/console/app/agent.py index a964154207..c063f336c7 100644 --- a/api/controllers/console/app/agent.py +++ b/api/controllers/console/app/agent.py @@ -1,6 +1,6 @@ -from flask_restx import Resource, reqparse +from flask_restx import Resource, fields, reqparse -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required from libs.helper import uuid_value @@ -9,7 +9,18 @@ from models.model import AppMode from services.agent_service import AgentService +@console_ns.route("/apps//agent/logs") class AgentLogApi(Resource): + @api.doc("get_agent_logs") + @api.doc(description="Get agent execution logs for an application") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.parser() + .add_argument("message_id", type=str, required=True, location="args", help="Message UUID") + .add_argument("conversation_id", type=str, required=True, location="args", help="Conversation UUID") + ) + @api.response(200, "Agent logs retrieved successfully", fields.List(fields.Raw(description="Agent log entries"))) + @api.response(400, "Invalid request parameters") @setup_required @login_required @account_initialization_required @@ -23,6 +34,3 @@ class AgentLogApi(Resource): args = parser.parse_args() return AgentService.get_agent_logs(app_model, args["conversation_id"], args["message_id"]) - - -api.add_resource(AgentLogApi, "/apps//agent/logs") diff --git a/api/controllers/console/app/annotation.py b/api/controllers/console/app/annotation.py index 37d23ccd9f..d0ee11fe75 100644 --- a/api/controllers/console/app/annotation.py +++ b/api/controllers/console/app/annotation.py @@ -2,11 +2,11 @@ from typing import Literal from flask import request from flask_login import current_user -from flask_restx import Resource, marshal, marshal_with, reqparse +from flask_restx import Resource, fields, marshal, marshal_with, reqparse from werkzeug.exceptions import Forbidden from controllers.common.errors import NoFileUploadedError, TooManyFilesError -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.wraps import ( account_initialization_required, cloud_edition_billing_resource_check, @@ -21,7 +21,23 @@ from libs.login import login_required from services.annotation_service import AppAnnotationService +@console_ns.route("/apps//annotation-reply/") class AnnotationReplyActionApi(Resource): + @api.doc("annotation_reply_action") + @api.doc(description="Enable or disable annotation reply for an app") + @api.doc(params={"app_id": "Application ID", "action": "Action to perform (enable/disable)"}) + @api.expect( + api.model( + "AnnotationReplyActionRequest", + { + "score_threshold": fields.Float(required=True, description="Score threshold for annotation matching"), + "embedding_provider_name": fields.String(required=True, description="Embedding provider name"), + "embedding_model_name": fields.String(required=True, description="Embedding model name"), + }, + ) + ) + @api.response(200, "Action completed successfully") + @api.response(403, "Insufficient permissions") @setup_required @login_required @account_initialization_required @@ -43,7 +59,13 @@ class AnnotationReplyActionApi(Resource): return result, 200 +@console_ns.route("/apps//annotation-setting") class AppAnnotationSettingDetailApi(Resource): + @api.doc("get_annotation_setting") + @api.doc(description="Get annotation settings for an app") + @api.doc(params={"app_id": "Application ID"}) + @api.response(200, "Annotation settings retrieved successfully") + @api.response(403, "Insufficient permissions") @setup_required @login_required @account_initialization_required @@ -56,7 +78,23 @@ class AppAnnotationSettingDetailApi(Resource): return result, 200 +@console_ns.route("/apps//annotation-settings/") class AppAnnotationSettingUpdateApi(Resource): + @api.doc("update_annotation_setting") + @api.doc(description="Update annotation settings for an app") + @api.doc(params={"app_id": "Application ID", "annotation_setting_id": "Annotation setting ID"}) + @api.expect( + api.model( + "AnnotationSettingUpdateRequest", + { + "score_threshold": fields.Float(required=True, description="Score threshold"), + "embedding_provider_name": fields.String(required=True, description="Embedding provider"), + "embedding_model_name": fields.String(required=True, description="Embedding model"), + }, + ) + ) + @api.response(200, "Settings updated successfully") + @api.response(403, "Insufficient permissions") @setup_required @login_required @account_initialization_required @@ -75,7 +113,13 @@ class AppAnnotationSettingUpdateApi(Resource): return result, 200 +@console_ns.route("/apps//annotation-reply//status/") class AnnotationReplyActionStatusApi(Resource): + @api.doc("get_annotation_reply_action_status") + @api.doc(description="Get status of annotation reply action job") + @api.doc(params={"app_id": "Application ID", "job_id": "Job ID", "action": "Action type"}) + @api.response(200, "Job status retrieved successfully") + @api.response(403, "Insufficient permissions") @setup_required @login_required @account_initialization_required @@ -99,7 +143,19 @@ class AnnotationReplyActionStatusApi(Resource): return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200 +@console_ns.route("/apps//annotations") class AnnotationApi(Resource): + @api.doc("list_annotations") + @api.doc(description="Get annotations for an app with pagination") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.parser() + .add_argument("page", type=int, location="args", default=1, help="Page number") + .add_argument("limit", type=int, location="args", default=20, help="Page size") + .add_argument("keyword", type=str, location="args", default="", help="Search keyword") + ) + @api.response(200, "Annotations retrieved successfully") + @api.response(403, "Insufficient permissions") @setup_required @login_required @account_initialization_required @@ -122,6 +178,21 @@ class AnnotationApi(Resource): } return response, 200 + @api.doc("create_annotation") + @api.doc(description="Create a new annotation for an app") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.model( + "CreateAnnotationRequest", + { + "question": fields.String(required=True, description="Question text"), + "answer": fields.String(required=True, description="Answer text"), + "annotation_reply": fields.Raw(description="Annotation reply data"), + }, + ) + ) + @api.response(201, "Annotation created successfully", annotation_fields) + @api.response(403, "Insufficient permissions") @setup_required @login_required @account_initialization_required @@ -168,7 +239,13 @@ class AnnotationApi(Resource): return {"result": "success"}, 204 +@console_ns.route("/apps//annotations/export") class AnnotationExportApi(Resource): + @api.doc("export_annotations") + @api.doc(description="Export all annotations for an app") + @api.doc(params={"app_id": "Application ID"}) + @api.response(200, "Annotations exported successfully", fields.List(fields.Nested(annotation_fields))) + @api.response(403, "Insufficient permissions") @setup_required @login_required @account_initialization_required @@ -182,7 +259,14 @@ class AnnotationExportApi(Resource): return response, 200 +@console_ns.route("/apps//annotations/") class AnnotationUpdateDeleteApi(Resource): + @api.doc("update_delete_annotation") + @api.doc(description="Update or delete an annotation") + @api.doc(params={"app_id": "Application ID", "annotation_id": "Annotation ID"}) + @api.response(200, "Annotation updated successfully", annotation_fields) + @api.response(204, "Annotation deleted successfully") + @api.response(403, "Insufficient permissions") @setup_required @login_required @account_initialization_required @@ -214,7 +298,14 @@ class AnnotationUpdateDeleteApi(Resource): return {"result": "success"}, 204 +@console_ns.route("/apps//annotations/batch-import") class AnnotationBatchImportApi(Resource): + @api.doc("batch_import_annotations") + @api.doc(description="Batch import annotations from CSV file") + @api.doc(params={"app_id": "Application ID"}) + @api.response(200, "Batch import started successfully") + @api.response(403, "Insufficient permissions") + @api.response(400, "No file uploaded or too many files") @setup_required @login_required @account_initialization_required @@ -239,7 +330,13 @@ class AnnotationBatchImportApi(Resource): return AppAnnotationService.batch_import_app_annotations(app_id, file) +@console_ns.route("/apps//annotations/batch-import-status/") class AnnotationBatchImportStatusApi(Resource): + @api.doc("get_batch_import_status") + @api.doc(description="Get status of batch import job") + @api.doc(params={"app_id": "Application ID", "job_id": "Job ID"}) + @api.response(200, "Job status retrieved successfully") + @api.response(403, "Insufficient permissions") @setup_required @login_required @account_initialization_required @@ -262,7 +359,20 @@ class AnnotationBatchImportStatusApi(Resource): return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200 +@console_ns.route("/apps//annotations//hit-histories") class AnnotationHitHistoryListApi(Resource): + @api.doc("list_annotation_hit_histories") + @api.doc(description="Get hit histories for an annotation") + @api.doc(params={"app_id": "Application ID", "annotation_id": "Annotation ID"}) + @api.expect( + api.parser() + .add_argument("page", type=int, location="args", default=1, help="Page number") + .add_argument("limit", type=int, location="args", default=20, help="Page size") + ) + @api.response( + 200, "Hit histories retrieved successfully", fields.List(fields.Nested(annotation_hit_history_fields)) + ) + @api.response(403, "Insufficient permissions") @setup_required @login_required @account_initialization_required @@ -285,17 +395,3 @@ class AnnotationHitHistoryListApi(Resource): "page": page, } return response - - -api.add_resource(AnnotationReplyActionApi, "/apps//annotation-reply/") -api.add_resource( - AnnotationReplyActionStatusApi, "/apps//annotation-reply//status/" -) -api.add_resource(AnnotationApi, "/apps//annotations") -api.add_resource(AnnotationExportApi, "/apps//annotations/export") -api.add_resource(AnnotationUpdateDeleteApi, "/apps//annotations/") -api.add_resource(AnnotationBatchImportApi, "/apps//annotations/batch-import") -api.add_resource(AnnotationBatchImportStatusApi, "/apps//annotations/batch-import-status/") -api.add_resource(AnnotationHitHistoryListApi, "/apps//annotations//hit-histories") -api.add_resource(AppAnnotationSettingDetailApi, "/apps//annotation-setting") -api.add_resource(AppAnnotationSettingUpdateApi, "/apps//annotation-settings/") diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 10753d2f95..2d2e4b448a 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -2,12 +2,12 @@ import uuid from typing import cast from flask_login import current_user -from flask_restx import Resource, inputs, marshal, marshal_with, reqparse +from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse from sqlalchemy import select from sqlalchemy.orm import Session from werkzeug.exceptions import BadRequest, Forbidden, abort -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.app.wraps import get_app_model from controllers.console.wraps import ( account_initialization_required, @@ -34,7 +34,27 @@ def _validate_description_length(description): return description +@console_ns.route("/apps") class AppListApi(Resource): + @api.doc("list_apps") + @api.doc(description="Get list of applications with pagination and filtering") + @api.expect( + api.parser() + .add_argument("page", type=int, location="args", help="Page number (1-99999)", default=1) + .add_argument("limit", type=int, location="args", help="Page size (1-100)", default=20) + .add_argument( + "mode", + type=str, + location="args", + choices=["completion", "chat", "advanced-chat", "workflow", "agent-chat", "channel", "all"], + default="all", + help="App mode filter", + ) + .add_argument("name", type=str, location="args", help="Filter by app name") + .add_argument("tag_ids", type=str, location="args", help="Comma-separated tag IDs") + .add_argument("is_created_by_me", type=bool, location="args", help="Filter by creator") + ) + @api.response(200, "Success", app_pagination_fields) @setup_required @login_required @account_initialization_required @@ -91,6 +111,24 @@ class AppListApi(Resource): return marshal(app_pagination, app_pagination_fields), 200 + @api.doc("create_app") + @api.doc(description="Create a new application") + @api.expect( + api.model( + "CreateAppRequest", + { + "name": fields.String(required=True, description="App name"), + "description": fields.String(description="App description (max 400 chars)"), + "mode": fields.String(required=True, enum=ALLOW_CREATE_APP_MODES, description="App mode"), + "icon_type": fields.String(description="Icon type"), + "icon": fields.String(description="Icon"), + "icon_background": fields.String(description="Icon background color"), + }, + ) + ) + @api.response(201, "App created successfully", app_detail_fields) + @api.response(403, "Insufficient permissions") + @api.response(400, "Invalid request parameters") @setup_required @login_required @account_initialization_required @@ -115,12 +153,21 @@ class AppListApi(Resource): raise BadRequest("mode is required") app_service = AppService() + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") + if current_user.current_tenant_id is None: + raise ValueError("current_user.current_tenant_id cannot be None") app = app_service.create_app(current_user.current_tenant_id, args, current_user) return app, 201 +@console_ns.route("/apps/") class AppApi(Resource): + @api.doc("get_app_detail") + @api.doc(description="Get application details") + @api.doc(params={"app_id": "Application ID"}) + @api.response(200, "Success", app_detail_fields_with_site) @setup_required @login_required @account_initialization_required @@ -139,6 +186,26 @@ class AppApi(Resource): return app_model + @api.doc("update_app") + @api.doc(description="Update application details") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.model( + "UpdateAppRequest", + { + "name": fields.String(required=True, description="App name"), + "description": fields.String(description="App description (max 400 chars)"), + "icon_type": fields.String(description="Icon type"), + "icon": fields.String(description="Icon"), + "icon_background": fields.String(description="Icon background color"), + "use_icon_as_answer_icon": fields.Boolean(description="Use icon as answer icon"), + "max_active_requests": fields.Integer(description="Maximum active requests"), + }, + ) + ) + @api.response(200, "App updated successfully", app_detail_fields_with_site) + @api.response(403, "Insufficient permissions") + @api.response(400, "Invalid request parameters") @setup_required @login_required @account_initialization_required @@ -161,14 +228,31 @@ class AppApi(Resource): args = parser.parse_args() app_service = AppService() - app_model = app_service.update_app(app_model, args) + # Construct ArgsDict from parsed arguments + from services.app_service import AppService as AppServiceType + + args_dict: AppServiceType.ArgsDict = { + "name": args["name"], + "description": args.get("description", ""), + "icon_type": args.get("icon_type", ""), + "icon": args.get("icon", ""), + "icon_background": args.get("icon_background", ""), + "use_icon_as_answer_icon": args.get("use_icon_as_answer_icon", False), + "max_active_requests": args.get("max_active_requests", 0), + } + app_model = app_service.update_app(app_model, args_dict) return app_model + @api.doc("delete_app") + @api.doc(description="Delete application") + @api.doc(params={"app_id": "Application ID"}) + @api.response(204, "App deleted successfully") + @api.response(403, "Insufficient permissions") + @get_app_model @setup_required @login_required @account_initialization_required - @get_app_model def delete(self, app_model): """Delete app""" # The role of the current user in the ta table must be admin, owner, or editor @@ -181,7 +265,25 @@ class AppApi(Resource): return {"result": "success"}, 204 +@console_ns.route("/apps//copy") class AppCopyApi(Resource): + @api.doc("copy_app") + @api.doc(description="Create a copy of an existing application") + @api.doc(params={"app_id": "Application ID to copy"}) + @api.expect( + api.model( + "CopyAppRequest", + { + "name": fields.String(description="Name for the copied app"), + "description": fields.String(description="Description for the copied app"), + "icon_type": fields.String(description="Icon type"), + "icon": fields.String(description="Icon"), + "icon_background": fields.String(description="Icon background color"), + }, + ) + ) + @api.response(201, "App copied successfully", app_detail_fields_with_site) + @api.response(403, "Insufficient permissions") @setup_required @login_required @account_initialization_required @@ -223,11 +325,26 @@ class AppCopyApi(Resource): return app, 201 +@console_ns.route("/apps//export") class AppExportApi(Resource): + @api.doc("export_app") + @api.doc(description="Export application configuration as DSL") + @api.doc(params={"app_id": "Application ID to export"}) + @api.expect( + api.parser() + .add_argument("include_secret", type=bool, location="args", default=False, help="Include secrets in export") + .add_argument("workflow_id", type=str, location="args", help="Specific workflow ID to export") + ) + @api.response( + 200, + "App exported successfully", + api.model("AppExportResponse", {"data": fields.String(description="DSL export data")}), + ) + @api.response(403, "Insufficient permissions") + @get_app_model @setup_required @login_required @account_initialization_required - @get_app_model def get(self, app_model): """Export app""" # The role of the current user in the ta table must be admin, owner, or editor @@ -247,7 +364,13 @@ class AppExportApi(Resource): } +@console_ns.route("/apps//name") class AppNameApi(Resource): + @api.doc("check_app_name") + @api.doc(description="Check if app name is available") + @api.doc(params={"app_id": "Application ID"}) + @api.expect(api.parser().add_argument("name", type=str, required=True, location="args", help="Name to check")) + @api.response(200, "Name availability checked") @setup_required @login_required @account_initialization_required @@ -263,12 +386,28 @@ class AppNameApi(Resource): args = parser.parse_args() app_service = AppService() - app_model = app_service.update_app_name(app_model, args.get("name")) + app_model = app_service.update_app_name(app_model, args["name"]) return app_model +@console_ns.route("/apps//icon") class AppIconApi(Resource): + @api.doc("update_app_icon") + @api.doc(description="Update application icon") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.model( + "AppIconRequest", + { + "icon": fields.String(required=True, description="Icon data"), + "icon_type": fields.String(description="Icon type"), + "icon_background": fields.String(description="Icon background color"), + }, + ) + ) + @api.response(200, "Icon updated successfully") + @api.response(403, "Insufficient permissions") @setup_required @login_required @account_initialization_required @@ -285,12 +424,23 @@ class AppIconApi(Resource): args = parser.parse_args() app_service = AppService() - app_model = app_service.update_app_icon(app_model, args.get("icon"), args.get("icon_background")) + app_model = app_service.update_app_icon(app_model, args.get("icon") or "", args.get("icon_background") or "") return app_model +@console_ns.route("/apps//site-enable") class AppSiteStatus(Resource): + @api.doc("update_app_site_status") + @api.doc(description="Enable or disable app site") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.model( + "AppSiteStatusRequest", {"enable_site": fields.Boolean(required=True, description="Enable or disable site")} + ) + ) + @api.response(200, "Site status updated successfully", app_detail_fields) + @api.response(403, "Insufficient permissions") @setup_required @login_required @account_initialization_required @@ -306,12 +456,23 @@ class AppSiteStatus(Resource): args = parser.parse_args() app_service = AppService() - app_model = app_service.update_app_site_status(app_model, args.get("enable_site")) + app_model = app_service.update_app_site_status(app_model, args["enable_site"]) return app_model +@console_ns.route("/apps//api-enable") class AppApiStatus(Resource): + @api.doc("update_app_api_status") + @api.doc(description="Enable or disable app API") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.model( + "AppApiStatusRequest", {"enable_api": fields.Boolean(required=True, description="Enable or disable API")} + ) + ) + @api.response(200, "API status updated successfully", app_detail_fields) + @api.response(403, "Insufficient permissions") @setup_required @login_required @account_initialization_required @@ -327,12 +488,17 @@ class AppApiStatus(Resource): args = parser.parse_args() app_service = AppService() - app_model = app_service.update_app_api_status(app_model, args.get("enable_api")) + app_model = app_service.update_app_api_status(app_model, args["enable_api"]) return app_model +@console_ns.route("/apps//trace") class AppTraceApi(Resource): + @api.doc("get_app_trace") + @api.doc(description="Get app tracing configuration") + @api.doc(params={"app_id": "Application ID"}) + @api.response(200, "Trace configuration retrieved successfully") @setup_required @login_required @account_initialization_required @@ -342,6 +508,20 @@ class AppTraceApi(Resource): return app_trace_config + @api.doc("update_app_trace") + @api.doc(description="Update app tracing configuration") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.model( + "AppTraceRequest", + { + "enabled": fields.Boolean(required=True, description="Enable or disable tracing"), + "tracing_provider": fields.String(required=True, description="Tracing provider"), + }, + ) + ) + @api.response(200, "Trace configuration updated successfully") + @api.response(403, "Insufficient permissions") @setup_required @login_required @account_initialization_required @@ -361,14 +541,3 @@ class AppTraceApi(Resource): ) return {"result": "success"} - - -api.add_resource(AppListApi, "/apps") -api.add_resource(AppApi, "/apps/") -api.add_resource(AppCopyApi, "/apps//copy") -api.add_resource(AppExportApi, "/apps//export") -api.add_resource(AppNameApi, "/apps//name") -api.add_resource(AppIconApi, "/apps//icon") -api.add_resource(AppSiteStatus, "/apps//site-enable") -api.add_resource(AppApiStatus, "/apps//api-enable") -api.add_resource(AppTraceApi, "/apps//trace") diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py index aaf5c3dfaa..7d659dae0d 100644 --- a/api/controllers/console/app/audio.py +++ b/api/controllers/console/app/audio.py @@ -1,11 +1,11 @@ import logging from flask import request -from flask_restx import Resource, reqparse +from flask_restx import Resource, fields, reqparse from werkzeug.exceptions import InternalServerError import services -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.app.error import ( AppUnavailableError, AudioTooLargeError, @@ -34,7 +34,18 @@ from services.errors.audio import ( logger = logging.getLogger(__name__) +@console_ns.route("/apps//audio-to-text") class ChatMessageAudioApi(Resource): + @api.doc("chat_message_audio_transcript") + @api.doc(description="Transcript audio to text for chat messages") + @api.doc(params={"app_id": "App ID"}) + @api.response( + 200, + "Audio transcription successful", + api.model("AudioTranscriptResponse", {"text": fields.String(description="Transcribed text from audio")}), + ) + @api.response(400, "Bad request - No audio uploaded or unsupported type") + @api.response(413, "Audio file too large") @setup_required @login_required @account_initialization_required @@ -76,11 +87,28 @@ class ChatMessageAudioApi(Resource): raise InternalServerError() +@console_ns.route("/apps//text-to-audio") class ChatMessageTextApi(Resource): + @api.doc("chat_message_text_to_speech") + @api.doc(description="Convert text to speech for chat messages") + @api.doc(params={"app_id": "App ID"}) + @api.expect( + api.model( + "TextToSpeechRequest", + { + "message_id": fields.String(description="Message ID"), + "text": fields.String(required=True, description="Text to convert to speech"), + "voice": fields.String(description="Voice to use for TTS"), + "streaming": fields.Boolean(description="Whether to stream the audio"), + }, + ) + ) + @api.response(200, "Text to speech conversion successful") + @api.response(400, "Bad request - Invalid parameters") + @get_app_model @setup_required @login_required @account_initialization_required - @get_app_model def post(self, app_model: App): try: parser = reqparse.RequestParser() @@ -124,11 +152,18 @@ class ChatMessageTextApi(Resource): raise InternalServerError() +@console_ns.route("/apps//text-to-audio/voices") class TextModesApi(Resource): + @api.doc("get_text_to_speech_voices") + @api.doc(description="Get available TTS voices for a specific language") + @api.doc(params={"app_id": "App ID"}) + @api.expect(api.parser().add_argument("language", type=str, required=True, location="args", help="Language code")) + @api.response(200, "TTS voices retrieved successfully", fields.List(fields.Raw(description="Available voices"))) + @api.response(400, "Invalid language parameter") + @get_app_model @setup_required @login_required @account_initialization_required - @get_app_model def get(self, app_model): try: parser = reqparse.RequestParser() @@ -164,8 +199,3 @@ class TextModesApi(Resource): except Exception as e: logger.exception("Failed to handle get request to TextModesApi") raise InternalServerError() - - -api.add_resource(ChatMessageAudioApi, "/apps//audio-to-text") -api.add_resource(ChatMessageTextApi, "/apps//text-to-audio") -api.add_resource(TextModesApi, "/apps//text-to-audio/voices") diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py index 701ebb0b4a..2f7b90e7fb 100644 --- a/api/controllers/console/app/completion.py +++ b/api/controllers/console/app/completion.py @@ -1,12 +1,11 @@ import logging -import flask_login from flask import request -from flask_restx import Resource, reqparse -from werkzeug.exceptions import InternalServerError, NotFound +from flask_restx import Resource, fields, reqparse +from werkzeug.exceptions import Forbidden, InternalServerError, NotFound import services -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.app.error import ( AppUnavailableError, CompletionRequestError, @@ -29,7 +28,8 @@ from core.helper.trace_id_helper import get_external_trace_id from core.model_runtime.errors.invoke import InvokeError from libs import helper from libs.helper import uuid_value -from libs.login import login_required +from libs.login import current_user, login_required +from models import Account from models.model import AppMode from services.app_generate_service import AppGenerateService from services.errors.llm import InvokeRateLimitError @@ -38,7 +38,27 @@ logger = logging.getLogger(__name__) # define completion message api for user +@console_ns.route("/apps//completion-messages") class CompletionMessageApi(Resource): + @api.doc("create_completion_message") + @api.doc(description="Generate completion message for debugging") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.model( + "CompletionMessageRequest", + { + "inputs": fields.Raw(required=True, description="Input variables"), + "query": fields.String(description="Query text", default=""), + "files": fields.List(fields.Raw(), description="Uploaded files"), + "model_config": fields.Raw(required=True, description="Model configuration"), + "response_mode": fields.String(enum=["blocking", "streaming"], description="Response mode"), + "retriever_from": fields.String(default="dev", description="Retriever source"), + }, + ) + ) + @api.response(200, "Completion generated successfully") + @api.response(400, "Invalid request parameters") + @api.response(404, "App not found") @setup_required @login_required @account_initialization_required @@ -56,11 +76,11 @@ class CompletionMessageApi(Resource): streaming = args["response_mode"] != "blocking" args["auto_generate_name"] = False - account = flask_login.current_user - try: + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account or EndUser instance") response = AppGenerateService.generate( - app_model=app_model, user=account, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=streaming + app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=streaming ) return helper.compact_generate_response(response) @@ -86,25 +106,58 @@ class CompletionMessageApi(Resource): raise InternalServerError() +@console_ns.route("/apps//completion-messages//stop") class CompletionMessageStopApi(Resource): + @api.doc("stop_completion_message") + @api.doc(description="Stop a running completion message generation") + @api.doc(params={"app_id": "Application ID", "task_id": "Task ID to stop"}) + @api.response(200, "Task stopped successfully") @setup_required @login_required @account_initialization_required @get_app_model(mode=AppMode.COMPLETION) def post(self, app_model, task_id): - account = flask_login.current_user - - AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id) + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") + AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id) return {"result": "success"}, 200 +@console_ns.route("/apps//chat-messages") class ChatMessageApi(Resource): + @api.doc("create_chat_message") + @api.doc(description="Generate chat message for debugging") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.model( + "ChatMessageRequest", + { + "inputs": fields.Raw(required=True, description="Input variables"), + "query": fields.String(required=True, description="User query"), + "files": fields.List(fields.Raw(), description="Uploaded files"), + "model_config": fields.Raw(required=True, description="Model configuration"), + "conversation_id": fields.String(description="Conversation ID"), + "parent_message_id": fields.String(description="Parent message ID"), + "response_mode": fields.String(enum=["blocking", "streaming"], description="Response mode"), + "retriever_from": fields.String(default="dev", description="Retriever source"), + }, + ) + ) + @api.response(200, "Chat message generated successfully") + @api.response(400, "Invalid request parameters") + @api.response(404, "App or conversation not found") @setup_required @login_required @account_initialization_required @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT]) def post(self, app_model): + if not isinstance(current_user, Account): + raise Forbidden() + + if not current_user.has_edit_permission: + raise Forbidden() + parser = reqparse.RequestParser() parser.add_argument("inputs", type=dict, required=True, location="json") parser.add_argument("query", type=str, required=True, location="json") @@ -123,11 +176,11 @@ class ChatMessageApi(Resource): if external_trace_id: args["external_trace_id"] = external_trace_id - account = flask_login.current_user - try: + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account or EndUser instance") response = AppGenerateService.generate( - app_model=app_model, user=account, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=streaming + app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=streaming ) return helper.compact_generate_response(response) @@ -155,20 +208,19 @@ class ChatMessageApi(Resource): raise InternalServerError() +@console_ns.route("/apps//chat-messages//stop") class ChatMessageStopApi(Resource): + @api.doc("stop_chat_message") + @api.doc(description="Stop a running chat message generation") + @api.doc(params={"app_id": "Application ID", "task_id": "Task ID to stop"}) + @api.response(200, "Task stopped successfully") @setup_required @login_required @account_initialization_required @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) def post(self, app_model, task_id): - account = flask_login.current_user - - AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id) + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") + AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id) return {"result": "success"}, 200 - - -api.add_resource(CompletionMessageApi, "/apps//completion-messages") -api.add_resource(CompletionMessageStopApi, "/apps//completion-messages//stop") -api.add_resource(ChatMessageApi, "/apps//chat-messages") -api.add_resource(ChatMessageStopApi, "/apps//chat-messages//stop") diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index bc825effad..c0cbf6613e 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -8,7 +8,7 @@ from sqlalchemy import func, or_ from sqlalchemy.orm import joinedload from werkzeug.exceptions import Forbidden, NotFound -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required from core.app.entities.app_invoke_entities import InvokeFrom @@ -22,13 +22,35 @@ from fields.conversation_fields import ( from libs.datetime_utils import naive_utc_now from libs.helper import DatetimeString from libs.login import login_required -from models import Conversation, EndUser, Message, MessageAnnotation +from models import Account, Conversation, EndUser, Message, MessageAnnotation from models.model import AppMode from services.conversation_service import ConversationService from services.errors.conversation import ConversationNotExistsError +@console_ns.route("/apps//completion-conversations") class CompletionConversationApi(Resource): + @api.doc("list_completion_conversations") + @api.doc(description="Get completion conversations with pagination and filtering") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.parser() + .add_argument("keyword", type=str, location="args", help="Search keyword") + .add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)") + .add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)") + .add_argument( + "annotation_status", + type=str, + location="args", + choices=["annotated", "not_annotated", "all"], + default="all", + help="Annotation status filter", + ) + .add_argument("page", type=int, location="args", default=1, help="Page number") + .add_argument("limit", type=int, location="args", default=20, help="Page size (1-100)") + ) + @api.response(200, "Success", conversation_pagination_fields) + @api.response(403, "Insufficient permissions") @setup_required @login_required @account_initialization_required @@ -101,7 +123,14 @@ class CompletionConversationApi(Resource): return conversations +@console_ns.route("/apps//completion-conversations/") class CompletionConversationDetailApi(Resource): + @api.doc("get_completion_conversation") + @api.doc(description="Get completion conversation details with messages") + @api.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"}) + @api.response(200, "Success", conversation_message_detail_fields) + @api.response(403, "Insufficient permissions") + @api.response(404, "Conversation not found") @setup_required @login_required @account_initialization_required @@ -114,6 +143,12 @@ class CompletionConversationDetailApi(Resource): return _get_conversation(app_model, conversation_id) + @api.doc("delete_completion_conversation") + @api.doc(description="Delete a completion conversation") + @api.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"}) + @api.response(204, "Conversation deleted successfully") + @api.response(403, "Insufficient permissions") + @api.response(404, "Conversation not found") @setup_required @login_required @account_initialization_required @@ -124,6 +159,8 @@ class CompletionConversationDetailApi(Resource): conversation_id = str(conversation_id) try: + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") ConversationService.delete(app_model, conversation_id, current_user) except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") @@ -131,7 +168,38 @@ class CompletionConversationDetailApi(Resource): return {"result": "success"}, 204 +@console_ns.route("/apps//chat-conversations") class ChatConversationApi(Resource): + @api.doc("list_chat_conversations") + @api.doc(description="Get chat conversations with pagination, filtering and summary") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.parser() + .add_argument("keyword", type=str, location="args", help="Search keyword") + .add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)") + .add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)") + .add_argument( + "annotation_status", + type=str, + location="args", + choices=["annotated", "not_annotated", "all"], + default="all", + help="Annotation status filter", + ) + .add_argument("message_count_gte", type=int, location="args", help="Minimum message count") + .add_argument("page", type=int, location="args", default=1, help="Page number") + .add_argument("limit", type=int, location="args", default=20, help="Page size (1-100)") + .add_argument( + "sort_by", + type=str, + location="args", + choices=["created_at", "-created_at", "updated_at", "-updated_at"], + default="-updated_at", + help="Sort field and direction", + ) + ) + @api.response(200, "Success", conversation_with_summary_pagination_fields) + @api.response(403, "Insufficient permissions") @setup_required @login_required @account_initialization_required @@ -239,7 +307,7 @@ class ChatConversationApi(Resource): .having(func.count(Message.id) >= args["message_count_gte"]) ) - if app_model.mode == AppMode.ADVANCED_CHAT.value: + if app_model.mode == AppMode.ADVANCED_CHAT: query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER.value) match args["sort_by"]: @@ -259,7 +327,14 @@ class ChatConversationApi(Resource): return conversations +@console_ns.route("/apps//chat-conversations/") class ChatConversationDetailApi(Resource): + @api.doc("get_chat_conversation") + @api.doc(description="Get chat conversation details") + @api.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"}) + @api.response(200, "Success", conversation_detail_fields) + @api.response(403, "Insufficient permissions") + @api.response(404, "Conversation not found") @setup_required @login_required @account_initialization_required @@ -272,6 +347,12 @@ class ChatConversationDetailApi(Resource): return _get_conversation(app_model, conversation_id) + @api.doc("delete_chat_conversation") + @api.doc(description="Delete a chat conversation") + @api.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"}) + @api.response(204, "Conversation deleted successfully") + @api.response(403, "Insufficient permissions") + @api.response(404, "Conversation not found") @setup_required @login_required @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) @@ -282,6 +363,8 @@ class ChatConversationDetailApi(Resource): conversation_id = str(conversation_id) try: + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") ConversationService.delete(app_model, conversation_id, current_user) except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") @@ -289,12 +372,6 @@ class ChatConversationDetailApi(Resource): return {"result": "success"}, 204 -api.add_resource(CompletionConversationApi, "/apps//completion-conversations") -api.add_resource(CompletionConversationDetailApi, "/apps//completion-conversations/") -api.add_resource(ChatConversationApi, "/apps//chat-conversations") -api.add_resource(ChatConversationDetailApi, "/apps//chat-conversations/") - - def _get_conversation(app_model, conversation_id): conversation = ( db.session.query(Conversation) diff --git a/api/controllers/console/app/conversation_variables.py b/api/controllers/console/app/conversation_variables.py index 5ca4c33f87..8a65a89963 100644 --- a/api/controllers/console/app/conversation_variables.py +++ b/api/controllers/console/app/conversation_variables.py @@ -2,7 +2,7 @@ from flask_restx import Resource, marshal_with, reqparse from sqlalchemy import select from sqlalchemy.orm import Session -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required from extensions.ext_database import db @@ -12,7 +12,17 @@ from models import ConversationVariable from models.model import AppMode +@console_ns.route("/apps//conversation-variables") class ConversationVariablesApi(Resource): + @api.doc("get_conversation_variables") + @api.doc(description="Get conversation variables for an application") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.parser().add_argument( + "conversation_id", type=str, location="args", help="Conversation ID to filter variables" + ) + ) + @api.response(200, "Conversation variables retrieved successfully", paginated_conversation_variable_fields) @setup_required @login_required @account_initialization_required @@ -55,6 +65,3 @@ class ConversationVariablesApi(Resource): for row in rows ], } - - -api.add_resource(ConversationVariablesApi, "/apps//conversation-variables") diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py index a2cb226014..d911b25028 100644 --- a/api/controllers/console/app/generator.py +++ b/api/controllers/console/app/generator.py @@ -1,9 +1,9 @@ from collections.abc import Sequence from flask_login import current_user -from flask_restx import Resource, reqparse +from flask_restx import Resource, fields, reqparse -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.app.error import ( CompletionRequestError, ProviderModelCurrentlyNotSupportError, @@ -19,7 +19,23 @@ from core.model_runtime.errors.invoke import InvokeError from libs.login import login_required +@console_ns.route("/rule-generate") class RuleGenerateApi(Resource): + @api.doc("generate_rule_config") + @api.doc(description="Generate rule configuration using LLM") + @api.expect( + api.model( + "RuleGenerateRequest", + { + "instruction": fields.String(required=True, description="Rule generation instruction"), + "model_config": fields.Raw(required=True, description="Model configuration"), + "no_variable": fields.Boolean(required=True, default=False, description="Whether to exclude variables"), + }, + ) + ) + @api.response(200, "Rule configuration generated successfully") + @api.response(400, "Invalid request parameters") + @api.response(402, "Provider quota exceeded") @setup_required @login_required @account_initialization_required @@ -50,7 +66,26 @@ class RuleGenerateApi(Resource): return rules +@console_ns.route("/rule-code-generate") class RuleCodeGenerateApi(Resource): + @api.doc("generate_rule_code") + @api.doc(description="Generate code rules using LLM") + @api.expect( + api.model( + "RuleCodeGenerateRequest", + { + "instruction": fields.String(required=True, description="Code generation instruction"), + "model_config": fields.Raw(required=True, description="Model configuration"), + "no_variable": fields.Boolean(required=True, default=False, description="Whether to exclude variables"), + "code_language": fields.String( + default="javascript", description="Programming language for code generation" + ), + }, + ) + ) + @api.response(200, "Code rules generated successfully") + @api.response(400, "Invalid request parameters") + @api.response(402, "Provider quota exceeded") @setup_required @login_required @account_initialization_required @@ -82,7 +117,22 @@ class RuleCodeGenerateApi(Resource): return code_result +@console_ns.route("/rule-structured-output-generate") class RuleStructuredOutputGenerateApi(Resource): + @api.doc("generate_structured_output") + @api.doc(description="Generate structured output rules using LLM") + @api.expect( + api.model( + "StructuredOutputGenerateRequest", + { + "instruction": fields.String(required=True, description="Structured output generation instruction"), + "model_config": fields.Raw(required=True, description="Model configuration"), + }, + ) + ) + @api.response(200, "Structured output generated successfully") + @api.response(400, "Invalid request parameters") + @api.response(402, "Provider quota exceeded") @setup_required @login_required @account_initialization_required @@ -111,7 +161,27 @@ class RuleStructuredOutputGenerateApi(Resource): return structured_output +@console_ns.route("/instruction-generate") class InstructionGenerateApi(Resource): + @api.doc("generate_instruction") + @api.doc(description="Generate instruction for workflow nodes or general use") + @api.expect( + api.model( + "InstructionGenerateRequest", + { + "flow_id": fields.String(required=True, description="Workflow/Flow ID"), + "node_id": fields.String(description="Node ID for workflow context"), + "current": fields.String(description="Current instruction text"), + "language": fields.String(default="javascript", description="Programming language (javascript/python)"), + "instruction": fields.String(required=True, description="Instruction for generation"), + "model_config": fields.Raw(required=True, description="Model configuration"), + "ideal_output": fields.String(description="Expected ideal output"), + }, + ) + ) + @api.response(200, "Instruction generated successfully") + @api.response(400, "Invalid request parameters or flow/workflow not found") + @api.response(402, "Provider quota exceeded") @setup_required @login_required @account_initialization_required @@ -203,7 +273,21 @@ class InstructionGenerateApi(Resource): raise CompletionRequestError(e.description) +@console_ns.route("/instruction-generate/template") class InstructionGenerationTemplateApi(Resource): + @api.doc("get_instruction_template") + @api.doc(description="Get instruction generation template") + @api.expect( + api.model( + "InstructionTemplateRequest", + { + "instruction": fields.String(required=True, description="Template instruction"), + "ideal_output": fields.String(description="Expected ideal output"), + }, + ) + ) + @api.response(200, "Template retrieved successfully") + @api.response(400, "Invalid request parameters") @setup_required @login_required @account_initialization_required @@ -222,10 +306,3 @@ class InstructionGenerationTemplateApi(Resource): return {"data": INSTRUCTION_GENERATE_TEMPLATE_CODE} case _: raise ValueError(f"Invalid type: {args['type']}") - - -api.add_resource(RuleGenerateApi, "/rule-generate") -api.add_resource(RuleCodeGenerateApi, "/rule-code-generate") -api.add_resource(RuleStructuredOutputGenerateApi, "/rule-structured-output-generate") -api.add_resource(InstructionGenerateApi, "/instruction-generate") -api.add_resource(InstructionGenerationTemplateApi, "/instruction-generate/template") diff --git a/api/controllers/console/app/mcp_server.py b/api/controllers/console/app/mcp_server.py index 541803e539..b9a383ee61 100644 --- a/api/controllers/console/app/mcp_server.py +++ b/api/controllers/console/app/mcp_server.py @@ -2,10 +2,10 @@ import json from enum import StrEnum from flask_login import current_user -from flask_restx import Resource, marshal_with, reqparse +from flask_restx import Resource, fields, marshal_with, reqparse from werkzeug.exceptions import NotFound -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required from extensions.ext_database import db @@ -19,7 +19,12 @@ class AppMCPServerStatus(StrEnum): INACTIVE = "inactive" +@console_ns.route("/apps//server") class AppMCPServerController(Resource): + @api.doc("get_app_mcp_server") + @api.doc(description="Get MCP server configuration for an application") + @api.doc(params={"app_id": "Application ID"}) + @api.response(200, "MCP server configuration retrieved successfully", app_server_fields) @setup_required @login_required @account_initialization_required @@ -29,6 +34,20 @@ class AppMCPServerController(Resource): server = db.session.query(AppMCPServer).where(AppMCPServer.app_id == app_model.id).first() return server + @api.doc("create_app_mcp_server") + @api.doc(description="Create MCP server configuration for an application") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.model( + "MCPServerCreateRequest", + { + "description": fields.String(description="Server description"), + "parameters": fields.Raw(required=True, description="Server parameters configuration"), + }, + ) + ) + @api.response(201, "MCP server configuration created successfully", app_server_fields) + @api.response(403, "Insufficient permissions") @setup_required @login_required @account_initialization_required @@ -59,6 +78,23 @@ class AppMCPServerController(Resource): db.session.commit() return server + @api.doc("update_app_mcp_server") + @api.doc(description="Update MCP server configuration for an application") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.model( + "MCPServerUpdateRequest", + { + "id": fields.String(required=True, description="Server ID"), + "description": fields.String(description="Server description"), + "parameters": fields.Raw(required=True, description="Server parameters configuration"), + "status": fields.String(description="Server status"), + }, + ) + ) + @api.response(200, "MCP server configuration updated successfully", app_server_fields) + @api.response(403, "Insufficient permissions") + @api.response(404, "Server not found") @setup_required @login_required @account_initialization_required @@ -94,7 +130,14 @@ class AppMCPServerController(Resource): return server +@console_ns.route("/apps//server/refresh") class AppMCPServerRefreshController(Resource): + @api.doc("refresh_app_mcp_server") + @api.doc(description="Refresh MCP server configuration and regenerate server code") + @api.doc(params={"server_id": "Server ID"}) + @api.response(200, "MCP server refreshed successfully", app_server_fields) + @api.response(403, "Insufficient permissions") + @api.response(404, "Server not found") @setup_required @login_required @account_initialization_required @@ -113,7 +156,3 @@ class AppMCPServerRefreshController(Resource): server.server_code = AppMCPServer.generate_server_code(16) db.session.commit() return server - - -api.add_resource(AppMCPServerController, "/apps//server") -api.add_resource(AppMCPServerRefreshController, "/apps//server/refresh") diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index f0605a37f9..3bd9c53a85 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -1,12 +1,11 @@ import logging -from flask_login import current_user from flask_restx import Resource, fields, marshal_with, reqparse from flask_restx.inputs import int_range from sqlalchemy import exists, select from werkzeug.exceptions import Forbidden, InternalServerError, NotFound -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.app.error import ( CompletionRequestError, ProviderModelCurrentlyNotSupportError, @@ -27,7 +26,8 @@ from extensions.ext_database import db from fields.conversation_fields import annotation_fields, message_detail_fields from libs.helper import uuid_value from libs.infinite_scroll_pagination import InfiniteScrollPagination -from libs.login import login_required +from libs.login import current_user, login_required +from models.account import Account from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback from services.annotation_service import AppAnnotationService from services.errors.conversation import ConversationNotExistsError @@ -37,6 +37,7 @@ from services.message_service import MessageService logger = logging.getLogger(__name__) +@console_ns.route("/apps//chat-messages") class ChatMessageListApi(Resource): message_infinite_scroll_pagination_fields = { "limit": fields.Integer, @@ -44,6 +45,17 @@ class ChatMessageListApi(Resource): "data": fields.List(fields.Nested(message_detail_fields)), } + @api.doc("list_chat_messages") + @api.doc(description="Get chat messages for a conversation with pagination") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.parser() + .add_argument("conversation_id", type=str, required=True, location="args", help="Conversation ID") + .add_argument("first_id", type=str, location="args", help="First message ID for pagination") + .add_argument("limit", type=int, location="args", default=20, help="Number of messages to return (1-100)") + ) + @api.response(200, "Success", message_infinite_scroll_pagination_fields) + @api.response(404, "Conversation not found") @setup_required @login_required @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) @@ -117,12 +129,31 @@ class ChatMessageListApi(Resource): return InfiniteScrollPagination(data=history_messages, limit=args["limit"], has_more=has_more) +@console_ns.route("/apps//feedbacks") class MessageFeedbackApi(Resource): + @api.doc("create_message_feedback") + @api.doc(description="Create or update message feedback (like/dislike)") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.model( + "MessageFeedbackRequest", + { + "message_id": fields.String(required=True, description="Message ID"), + "rating": fields.String(enum=["like", "dislike"], description="Feedback rating"), + }, + ) + ) + @api.response(200, "Feedback updated successfully") + @api.response(404, "Message not found") + @api.response(403, "Insufficient permissions") + @get_app_model @setup_required @login_required @account_initialization_required - @get_app_model def post(self, app_model): + if current_user is None: + raise Forbidden() + parser = reqparse.RequestParser() parser.add_argument("message_id", required=True, type=uuid_value, location="json") parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json") @@ -159,7 +190,24 @@ class MessageFeedbackApi(Resource): return {"result": "success"} +@console_ns.route("/apps//annotations") class MessageAnnotationApi(Resource): + @api.doc("create_message_annotation") + @api.doc(description="Create message annotation") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.model( + "MessageAnnotationRequest", + { + "message_id": fields.String(description="Message ID"), + "question": fields.String(required=True, description="Question text"), + "answer": fields.String(required=True, description="Answer text"), + "annotation_reply": fields.Raw(description="Annotation reply"), + }, + ) + ) + @api.response(200, "Annotation created successfully", annotation_fields) + @api.response(403, "Insufficient permissions") @setup_required @login_required @account_initialization_required @@ -167,7 +215,9 @@ class MessageAnnotationApi(Resource): @get_app_model @marshal_with(annotation_fields) def post(self, app_model): - if not current_user.is_editor: + if not isinstance(current_user, Account): + raise Forbidden() + if not current_user.has_edit_permission: raise Forbidden() parser = reqparse.RequestParser() @@ -181,18 +231,37 @@ class MessageAnnotationApi(Resource): return annotation +@console_ns.route("/apps//annotations/count") class MessageAnnotationCountApi(Resource): + @api.doc("get_annotation_count") + @api.doc(description="Get count of message annotations for the app") + @api.doc(params={"app_id": "Application ID"}) + @api.response( + 200, + "Annotation count retrieved successfully", + api.model("AnnotationCountResponse", {"count": fields.Integer(description="Number of annotations")}), + ) + @get_app_model @setup_required @login_required @account_initialization_required - @get_app_model def get(self, app_model): count = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app_model.id).count() return {"count": count} +@console_ns.route("/apps//chat-messages//suggested-questions") class MessageSuggestedQuestionApi(Resource): + @api.doc("get_message_suggested_questions") + @api.doc(description="Get suggested questions for a message") + @api.doc(params={"app_id": "Application ID", "message_id": "Message ID"}) + @api.response( + 200, + "Suggested questions retrieved successfully", + api.model("SuggestedQuestionsResponse", {"data": fields.List(fields.String(description="Suggested question"))}), + ) + @api.response(404, "Message or conversation not found") @setup_required @login_required @account_initialization_required @@ -225,7 +294,13 @@ class MessageSuggestedQuestionApi(Resource): return {"data": questions} +@console_ns.route("/apps//messages/") class MessageApi(Resource): + @api.doc("get_message") + @api.doc(description="Get message details by ID") + @api.doc(params={"app_id": "Application ID", "message_id": "Message ID"}) + @api.response(200, "Message retrieved successfully", message_detail_fields) + @api.response(404, "Message not found") @setup_required @login_required @account_initialization_required @@ -240,11 +315,3 @@ class MessageApi(Resource): raise NotFound("Message Not Exists.") return message - - -api.add_resource(MessageSuggestedQuestionApi, "/apps//chat-messages//suggested-questions") -api.add_resource(ChatMessageListApi, "/apps//chat-messages", endpoint="console_chat_messages") -api.add_resource(MessageFeedbackApi, "/apps//feedbacks") -api.add_resource(MessageAnnotationApi, "/apps//annotations") -api.add_resource(MessageAnnotationCountApi, "/apps//annotations/count") -api.add_resource(MessageApi, "/apps//messages/", endpoint="console_message") diff --git a/api/controllers/console/app/model_config.py b/api/controllers/console/app/model_config.py index 52ff9b923d..11df511840 100644 --- a/api/controllers/console/app/model_config.py +++ b/api/controllers/console/app/model_config.py @@ -3,9 +3,10 @@ from typing import cast from flask import request from flask_login import current_user -from flask_restx import Resource +from flask_restx import Resource, fields +from werkzeug.exceptions import Forbidden -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required from core.agent.entities import AgentToolEntity @@ -14,17 +15,51 @@ from core.tools.utils.configuration import ToolParameterConfigurationManager from events.app_event import app_model_config_was_updated from extensions.ext_database import db from libs.login import login_required +from models.account import Account from models.model import AppMode, AppModelConfig from services.app_model_config_service import AppModelConfigService +@console_ns.route("/apps//model-config") class ModelConfigResource(Resource): + @api.doc("update_app_model_config") + @api.doc(description="Update application model configuration") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.model( + "ModelConfigRequest", + { + "provider": fields.String(description="Model provider"), + "model": fields.String(description="Model name"), + "configs": fields.Raw(description="Model configuration parameters"), + "opening_statement": fields.String(description="Opening statement"), + "suggested_questions": fields.List(fields.String(), description="Suggested questions"), + "more_like_this": fields.Raw(description="More like this configuration"), + "speech_to_text": fields.Raw(description="Speech to text configuration"), + "text_to_speech": fields.Raw(description="Text to speech configuration"), + "retrieval_model": fields.Raw(description="Retrieval model configuration"), + "tools": fields.List(fields.Raw(), description="Available tools"), + "dataset_configs": fields.Raw(description="Dataset configurations"), + "agent_mode": fields.Raw(description="Agent mode configuration"), + }, + ) + ) + @api.response(200, "Model configuration updated successfully") + @api.response(400, "Invalid configuration") + @api.response(404, "App not found") @setup_required @login_required @account_initialization_required @get_app_model(mode=[AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION]) def post(self, app_model): """Modify app model config""" + if not isinstance(current_user, Account): + raise Forbidden() + + if not current_user.has_edit_permission: + raise Forbidden() + + assert current_user.current_tenant_id is not None, "The tenant information should be loaded." # validate config model_configuration = AppModelConfigService.validate_configuration( tenant_id=current_user.current_tenant_id, @@ -39,7 +74,7 @@ class ModelConfigResource(Resource): ) new_app_model_config = new_app_model_config.from_model_config_dict(model_configuration) - if app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent: + if app_model.mode == AppMode.AGENT_CHAT or app_model.is_agent: # get original app model config original_app_model_config = ( db.session.query(AppModelConfig).where(AppModelConfig.id == app_model.app_model_config_id).first() @@ -142,6 +177,3 @@ class ModelConfigResource(Resource): app_model_config_was_updated.send(app_model, app_model_config=new_app_model_config) return {"result": "success"} - - -api.add_resource(ModelConfigResource, "/apps//model-config") diff --git a/api/controllers/console/app/ops_trace.py b/api/controllers/console/app/ops_trace.py index 74c2867c2f..981974e842 100644 --- a/api/controllers/console/app/ops_trace.py +++ b/api/controllers/console/app/ops_trace.py @@ -1,18 +1,31 @@ -from flask_restx import Resource, reqparse +from flask_restx import Resource, fields, reqparse from werkzeug.exceptions import BadRequest -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.app.error import TracingConfigCheckError, TracingConfigIsExist, TracingConfigNotExist from controllers.console.wraps import account_initialization_required, setup_required from libs.login import login_required from services.ops_service import OpsService +@console_ns.route("/apps//trace-config") class TraceAppConfigApi(Resource): """ Manage trace app configurations """ + @api.doc("get_trace_app_config") + @api.doc(description="Get tracing configuration for an application") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.parser().add_argument( + "tracing_provider", type=str, required=True, location="args", help="Tracing provider name" + ) + ) + @api.response( + 200, "Tracing configuration retrieved successfully", fields.Raw(description="Tracing configuration data") + ) + @api.response(400, "Invalid request parameters") @setup_required @login_required @account_initialization_required @@ -29,6 +42,22 @@ class TraceAppConfigApi(Resource): except Exception as e: raise BadRequest(str(e)) + @api.doc("create_trace_app_config") + @api.doc(description="Create a new tracing configuration for an application") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.model( + "TraceConfigCreateRequest", + { + "tracing_provider": fields.String(required=True, description="Tracing provider name"), + "tracing_config": fields.Raw(required=True, description="Tracing configuration data"), + }, + ) + ) + @api.response( + 201, "Tracing configuration created successfully", fields.Raw(description="Created configuration data") + ) + @api.response(400, "Invalid request parameters or configuration already exists") @setup_required @login_required @account_initialization_required @@ -51,6 +80,20 @@ class TraceAppConfigApi(Resource): except Exception as e: raise BadRequest(str(e)) + @api.doc("update_trace_app_config") + @api.doc(description="Update an existing tracing configuration for an application") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.model( + "TraceConfigUpdateRequest", + { + "tracing_provider": fields.String(required=True, description="Tracing provider name"), + "tracing_config": fields.Raw(required=True, description="Updated tracing configuration data"), + }, + ) + ) + @api.response(200, "Tracing configuration updated successfully", fields.Raw(description="Success response")) + @api.response(400, "Invalid request parameters or configuration not found") @setup_required @login_required @account_initialization_required @@ -71,6 +114,16 @@ class TraceAppConfigApi(Resource): except Exception as e: raise BadRequest(str(e)) + @api.doc("delete_trace_app_config") + @api.doc(description="Delete an existing tracing configuration for an application") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.parser().add_argument( + "tracing_provider", type=str, required=True, location="args", help="Tracing provider name" + ) + ) + @api.response(204, "Tracing configuration deleted successfully") + @api.response(400, "Invalid request parameters or configuration not found") @setup_required @login_required @account_initialization_required @@ -87,6 +140,3 @@ class TraceAppConfigApi(Resource): return {"result": "success"}, 204 except Exception as e: raise BadRequest(str(e)) - - -api.add_resource(TraceAppConfigApi, "/apps//trace-config") diff --git a/api/controllers/console/app/site.py b/api/controllers/console/app/site.py index 778ce92da6..95befc5df9 100644 --- a/api/controllers/console/app/site.py +++ b/api/controllers/console/app/site.py @@ -1,16 +1,16 @@ from flask_login import current_user -from flask_restx import Resource, marshal_with, reqparse +from flask_restx import Resource, fields, marshal_with, reqparse from werkzeug.exceptions import Forbidden, NotFound from constants.languages import supported_language -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required from extensions.ext_database import db from fields.app_fields import app_site_fields from libs.datetime_utils import naive_utc_now from libs.login import login_required -from models import Site +from models import Account, Site def parse_app_site_args(): @@ -36,7 +36,39 @@ def parse_app_site_args(): return parser.parse_args() +@console_ns.route("/apps//site") class AppSite(Resource): + @api.doc("update_app_site") + @api.doc(description="Update application site configuration") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.model( + "AppSiteRequest", + { + "title": fields.String(description="Site title"), + "icon_type": fields.String(description="Icon type"), + "icon": fields.String(description="Icon"), + "icon_background": fields.String(description="Icon background color"), + "description": fields.String(description="Site description"), + "default_language": fields.String(description="Default language"), + "chat_color_theme": fields.String(description="Chat color theme"), + "chat_color_theme_inverted": fields.Boolean(description="Inverted chat color theme"), + "customize_domain": fields.String(description="Custom domain"), + "copyright": fields.String(description="Copyright text"), + "privacy_policy": fields.String(description="Privacy policy"), + "custom_disclaimer": fields.String(description="Custom disclaimer"), + "customize_token_strategy": fields.String( + enum=["must", "allow", "not_allow"], description="Token strategy" + ), + "prompt_public": fields.Boolean(description="Make prompt public"), + "show_workflow_steps": fields.Boolean(description="Show workflow steps"), + "use_icon_as_answer_icon": fields.Boolean(description="Use icon as answer icon"), + }, + ) + ) + @api.response(200, "Site configuration updated successfully", app_site_fields) + @api.response(403, "Insufficient permissions") + @api.response(404, "App not found") @setup_required @login_required @account_initialization_required @@ -75,6 +107,8 @@ class AppSite(Resource): if value is not None: setattr(site, attr_name, value) + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") site.updated_by = current_user.id site.updated_at = naive_utc_now() db.session.commit() @@ -82,7 +116,14 @@ class AppSite(Resource): return site +@console_ns.route("/apps//site/access-token-reset") class AppSiteAccessTokenReset(Resource): + @api.doc("reset_app_site_access_token") + @api.doc(description="Reset access token for application site") + @api.doc(params={"app_id": "Application ID"}) + @api.response(200, "Access token reset successfully", app_site_fields) + @api.response(403, "Insufficient permissions (admin/owner required)") + @api.response(404, "App or site not found") @setup_required @login_required @account_initialization_required @@ -99,12 +140,10 @@ class AppSiteAccessTokenReset(Resource): raise NotFound site.code = Site.generate_code(16) + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") site.updated_by = current_user.id site.updated_at = naive_utc_now() db.session.commit() return site - - -api.add_resource(AppSite, "/apps//site") -api.add_resource(AppSiteAccessTokenReset, "/apps//site/access-token-reset") diff --git a/api/controllers/console/app/statistic.py b/api/controllers/console/app/statistic.py index 27e405af38..6894458578 100644 --- a/api/controllers/console/app/statistic.py +++ b/api/controllers/console/app/statistic.py @@ -5,9 +5,9 @@ import pytz import sqlalchemy as sa from flask import jsonify from flask_login import current_user -from flask_restx import Resource, reqparse +from flask_restx import Resource, fields, reqparse -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required from core.app.entities.app_invoke_entities import InvokeFrom @@ -17,11 +17,25 @@ from libs.login import login_required from models import AppMode, Message +@console_ns.route("/apps//statistics/daily-messages") class DailyMessageStatistic(Resource): + @api.doc("get_daily_message_statistics") + @api.doc(description="Get daily message statistics for an application") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.parser() + .add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)") + .add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)") + ) + @api.response( + 200, + "Daily message statistics retrieved successfully", + fields.List(fields.Raw(description="Daily message count data")), + ) + @get_app_model @setup_required @login_required @account_initialization_required - @get_app_model def get(self, app_model): account = current_user @@ -74,11 +88,25 @@ WHERE return jsonify({"data": response_data}) +@console_ns.route("/apps//statistics/daily-conversations") class DailyConversationStatistic(Resource): + @api.doc("get_daily_conversation_statistics") + @api.doc(description="Get daily conversation statistics for an application") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.parser() + .add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)") + .add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)") + ) + @api.response( + 200, + "Daily conversation statistics retrieved successfully", + fields.List(fields.Raw(description="Daily conversation count data")), + ) + @get_app_model @setup_required @login_required @account_initialization_required - @get_app_model def get(self, app_model): account = current_user @@ -126,11 +154,25 @@ class DailyConversationStatistic(Resource): return jsonify({"data": response_data}) +@console_ns.route("/apps//statistics/daily-end-users") class DailyTerminalsStatistic(Resource): + @api.doc("get_daily_terminals_statistics") + @api.doc(description="Get daily terminal/end-user statistics for an application") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.parser() + .add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)") + .add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)") + ) + @api.response( + 200, + "Daily terminal statistics retrieved successfully", + fields.List(fields.Raw(description="Daily terminal count data")), + ) + @get_app_model @setup_required @login_required @account_initialization_required - @get_app_model def get(self, app_model): account = current_user @@ -183,11 +225,25 @@ WHERE return jsonify({"data": response_data}) +@console_ns.route("/apps//statistics/token-costs") class DailyTokenCostStatistic(Resource): + @api.doc("get_daily_token_cost_statistics") + @api.doc(description="Get daily token cost statistics for an application") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.parser() + .add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)") + .add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)") + ) + @api.response( + 200, + "Daily token cost statistics retrieved successfully", + fields.List(fields.Raw(description="Daily token cost data")), + ) + @get_app_model @setup_required @login_required @account_initialization_required - @get_app_model def get(self, app_model): account = current_user @@ -243,7 +299,21 @@ WHERE return jsonify({"data": response_data}) +@console_ns.route("/apps//statistics/average-session-interactions") class AverageSessionInteractionStatistic(Resource): + @api.doc("get_average_session_interaction_statistics") + @api.doc(description="Get average session interaction statistics for an application") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.parser() + .add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)") + .add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)") + ) + @api.response( + 200, + "Average session interaction statistics retrieved successfully", + fields.List(fields.Raw(description="Average session interaction data")), + ) @setup_required @login_required @account_initialization_required @@ -319,11 +389,25 @@ ORDER BY return jsonify({"data": response_data}) +@console_ns.route("/apps//statistics/user-satisfaction-rate") class UserSatisfactionRateStatistic(Resource): + @api.doc("get_user_satisfaction_rate_statistics") + @api.doc(description="Get user satisfaction rate statistics for an application") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.parser() + .add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)") + .add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)") + ) + @api.response( + 200, + "User satisfaction rate statistics retrieved successfully", + fields.List(fields.Raw(description="User satisfaction rate data")), + ) + @get_app_model @setup_required @login_required @account_initialization_required - @get_app_model def get(self, app_model): account = current_user @@ -385,7 +469,21 @@ WHERE return jsonify({"data": response_data}) +@console_ns.route("/apps//statistics/average-response-time") class AverageResponseTimeStatistic(Resource): + @api.doc("get_average_response_time_statistics") + @api.doc(description="Get average response time statistics for an application") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.parser() + .add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)") + .add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)") + ) + @api.response( + 200, + "Average response time statistics retrieved successfully", + fields.List(fields.Raw(description="Average response time data")), + ) @setup_required @login_required @account_initialization_required @@ -442,11 +540,25 @@ WHERE return jsonify({"data": response_data}) +@console_ns.route("/apps//statistics/tokens-per-second") class TokensPerSecondStatistic(Resource): + @api.doc("get_tokens_per_second_statistics") + @api.doc(description="Get tokens per second statistics for an application") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.parser() + .add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)") + .add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)") + ) + @api.response( + 200, + "Tokens per second statistics retrieved successfully", + fields.List(fields.Raw(description="Tokens per second data")), + ) + @get_app_model @setup_required @login_required @account_initialization_required - @get_app_model def get(self, app_model): account = current_user @@ -500,13 +612,3 @@ WHERE response_data.append({"date": str(i.date), "tps": round(i.tokens_per_second, 4)}) return jsonify({"data": response_data}) - - -api.add_resource(DailyMessageStatistic, "/apps//statistics/daily-messages") -api.add_resource(DailyConversationStatistic, "/apps//statistics/daily-conversations") -api.add_resource(DailyTerminalsStatistic, "/apps//statistics/daily-end-users") -api.add_resource(DailyTokenCostStatistic, "/apps//statistics/token-costs") -api.add_resource(AverageSessionInteractionStatistic, "/apps//statistics/average-session-interactions") -api.add_resource(UserSatisfactionRateStatistic, "/apps//statistics/user-satisfaction-rate") -api.add_resource(AverageResponseTimeStatistic, "/apps//statistics/average-response-time") -api.add_resource(TokensPerSecondStatistic, "/apps//statistics/tokens-per-second") diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index 05178328fe..0967d09e9c 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -69,7 +69,7 @@ class DraftWorkflowApi(Resource): """ # The role of the current user in the ta table must be admin, owner, or editor assert isinstance(current_user, Account) - if not current_user.is_editor: + if not current_user.has_edit_permission: raise Forbidden() # fetch draft workflow by app_model @@ -92,7 +92,7 @@ class DraftWorkflowApi(Resource): """ # The role of the current user in the ta table must be admin, owner, or editor assert isinstance(current_user, Account) - if not current_user.is_editor: + if not current_user.has_edit_permission: raise Forbidden() content_type = request.headers.get("Content-Type", "") @@ -170,7 +170,7 @@ class AdvancedChatDraftWorkflowRunApi(Resource): """ # The role of the current user in the ta table must be admin, owner, or editor assert isinstance(current_user, Account) - if not current_user.is_editor: + if not current_user.has_edit_permission: raise Forbidden() if not isinstance(current_user, Account): @@ -220,7 +220,7 @@ class AdvancedChatDraftRunIterationNodeApi(Resource): if not isinstance(current_user, Account): raise Forbidden() # The role of the current user in the ta table must be admin, owner, or editor - if not current_user.is_editor: + if not current_user.has_edit_permission: raise Forbidden() parser = reqparse.RequestParser() @@ -256,7 +256,7 @@ class WorkflowDraftRunIterationNodeApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor if not isinstance(current_user, Account): raise Forbidden() - if not current_user.is_editor: + if not current_user.has_edit_permission: raise Forbidden() parser = reqparse.RequestParser() @@ -293,7 +293,7 @@ class AdvancedChatDraftRunLoopNodeApi(Resource): if not isinstance(current_user, Account): raise Forbidden() # The role of the current user in the ta table must be admin, owner, or editor - if not current_user.is_editor: + if not current_user.has_edit_permission: raise Forbidden() parser = reqparse.RequestParser() @@ -330,7 +330,7 @@ class WorkflowDraftRunLoopNodeApi(Resource): if not isinstance(current_user, Account): raise Forbidden() # The role of the current user in the ta table must be admin, owner, or editor - if not current_user.is_editor: + if not current_user.has_edit_permission: raise Forbidden() parser = reqparse.RequestParser() @@ -367,7 +367,7 @@ class DraftWorkflowRunApi(Resource): if not isinstance(current_user, Account): raise Forbidden() # The role of the current user in the ta table must be admin, owner, or editor - if not current_user.is_editor: + if not current_user.has_edit_permission: raise Forbidden() parser = reqparse.RequestParser() @@ -406,7 +406,7 @@ class WorkflowTaskStopApi(Resource): if not isinstance(current_user, Account): raise Forbidden() # The role of the current user in the ta table must be admin, owner, or editor - if not current_user.is_editor: + if not current_user.has_edit_permission: raise Forbidden() AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id) @@ -428,7 +428,7 @@ class DraftWorkflowNodeRunApi(Resource): if not isinstance(current_user, Account): raise Forbidden() # The role of the current user in the ta table must be admin, owner, or editor - if not current_user.is_editor: + if not current_user.has_edit_permission: raise Forbidden() parser = reqparse.RequestParser() @@ -476,7 +476,7 @@ class PublishedWorkflowApi(Resource): if not isinstance(current_user, Account): raise Forbidden() # The role of the current user in the ta table must be admin, owner, or editor - if not current_user.is_editor: + if not current_user.has_edit_permission: raise Forbidden() # fetch published workflow by app_model @@ -497,7 +497,7 @@ class PublishedWorkflowApi(Resource): if not isinstance(current_user, Account): raise Forbidden() # The role of the current user in the ta table must be admin, owner, or editor - if not current_user.is_editor: + if not current_user.has_edit_permission: raise Forbidden() parser = reqparse.RequestParser() @@ -547,7 +547,7 @@ class DefaultBlockConfigsApi(Resource): if not isinstance(current_user, Account): raise Forbidden() # The role of the current user in the ta table must be admin, owner, or editor - if not current_user.is_editor: + if not current_user.has_edit_permission: raise Forbidden() # Get default block configs @@ -567,7 +567,7 @@ class DefaultBlockConfigApi(Resource): if not isinstance(current_user, Account): raise Forbidden() # The role of the current user in the ta table must be admin, owner, or editor - if not current_user.is_editor: + if not current_user.has_edit_permission: raise Forbidden() parser = reqparse.RequestParser() @@ -602,7 +602,7 @@ class ConvertToWorkflowApi(Resource): if not isinstance(current_user, Account): raise Forbidden() # The role of the current user in the ta table must be admin, owner, or editor - if not current_user.is_editor: + if not current_user.has_edit_permission: raise Forbidden() if request.data: @@ -651,7 +651,7 @@ class PublishedAllWorkflowApi(Resource): if not isinstance(current_user, Account): raise Forbidden() - if not current_user.is_editor: + if not current_user.has_edit_permission: raise Forbidden() parser = reqparse.RequestParser() @@ -702,7 +702,7 @@ class WorkflowByIdApi(Resource): if not isinstance(current_user, Account): raise Forbidden() # Check permission - if not current_user.is_editor: + if not current_user.has_edit_permission: raise Forbidden() parser = reqparse.RequestParser() @@ -715,7 +715,6 @@ class WorkflowByIdApi(Resource): raise ValueError("Marked name cannot exceed 20 characters") if args.marked_comment and len(args.marked_comment) > 100: raise ValueError("Marked comment cannot exceed 100 characters") - args = parser.parse_args() # Prepare update data update_data = {} @@ -758,7 +757,7 @@ class WorkflowByIdApi(Resource): if not isinstance(current_user, Account): raise Forbidden() # Check permission - if not current_user.is_editor: + if not current_user.has_edit_permission: raise Forbidden() workflow_service = WorkflowService() diff --git a/api/controllers/console/app/workflow_draft_variable.py b/api/controllers/console/app/workflow_draft_variable.py index 5fced3e90f..bdcd79f339 100644 --- a/api/controllers/console/app/workflow_draft_variable.py +++ b/api/controllers/console/app/workflow_draft_variable.py @@ -137,7 +137,7 @@ def _api_prerequisite(f): @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) def wrapper(*args, **kwargs): assert isinstance(current_user, Account) - if not current_user.is_editor: + if not current_user.has_edit_permission: raise Forbidden() return f(*args, **kwargs) diff --git a/api/controllers/console/app/workflow_statistic.py b/api/controllers/console/app/workflow_statistic.py index 7cef175c14..da7216086e 100644 --- a/api/controllers/console/app/workflow_statistic.py +++ b/api/controllers/console/app/workflow_statistic.py @@ -18,10 +18,10 @@ from models.model import AppMode class WorkflowDailyRunsStatistic(Resource): + @get_app_model @setup_required @login_required @account_initialization_required - @get_app_model def get(self, app_model): account = current_user @@ -80,10 +80,10 @@ WHERE class WorkflowDailyTerminalsStatistic(Resource): + @get_app_model @setup_required @login_required @account_initialization_required - @get_app_model def get(self, app_model): account = current_user @@ -142,10 +142,10 @@ WHERE class WorkflowDailyTokenCostStatistic(Resource): + @get_app_model @setup_required @login_required @account_initialization_required - @get_app_model def get(self, app_model): account = current_user diff --git a/api/controllers/console/app/wraps.py b/api/controllers/console/app/wraps.py index c7e300279a..5a871f896a 100644 --- a/api/controllers/console/app/wraps.py +++ b/api/controllers/console/app/wraps.py @@ -1,6 +1,6 @@ from collections.abc import Callable from functools import wraps -from typing import Optional, Union +from typing import Optional, ParamSpec, TypeVar, Union from controllers.console.app.error import AppNotFoundError from extensions.ext_database import db @@ -8,6 +8,9 @@ from libs.login import current_user from models import App, AppMode from models.account import Account +P = ParamSpec("P") +R = TypeVar("R") + def _load_app_model(app_id: str) -> Optional[App]: assert isinstance(current_user, Account) @@ -19,10 +22,10 @@ def _load_app_model(app_id: str) -> Optional[App]: return app_model -def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[AppMode], None] = None): - def decorator(view_func): +def get_app_model(view: Optional[Callable[P, R]] = None, *, mode: Union[AppMode, list[AppMode], None] = None): + def decorator(view_func: Callable[P, R]): @wraps(view_func) - def decorated_view(*args, **kwargs): + def decorated_view(*args: P.args, **kwargs: P.kwargs): if not kwargs.get("app_id"): raise ValueError("missing app_id in path parameters") diff --git a/api/controllers/console/auth/activate.py b/api/controllers/console/auth/activate.py index e82e403ec2..8cdadfb03c 100644 --- a/api/controllers/console/auth/activate.py +++ b/api/controllers/console/auth/activate.py @@ -1,8 +1,8 @@ from flask import request -from flask_restx import Resource, reqparse +from flask_restx import Resource, fields, reqparse from constants.languages import supported_language -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.error import AlreadyActivateError from extensions.ext_database import db from libs.datetime_utils import naive_utc_now @@ -10,14 +10,36 @@ from libs.helper import StrLen, email, extract_remote_ip, timezone from models.account import AccountStatus from services.account_service import AccountService, RegisterService +active_check_parser = reqparse.RequestParser() +active_check_parser.add_argument( + "workspace_id", type=str, required=False, nullable=True, location="args", help="Workspace ID" +) +active_check_parser.add_argument( + "email", type=email, required=False, nullable=True, location="args", help="Email address" +) +active_check_parser.add_argument( + "token", type=str, required=True, nullable=False, location="args", help="Activation token" +) + +@console_ns.route("/activate/check") class ActivateCheckApi(Resource): + @api.doc("check_activation_token") + @api.doc(description="Check if activation token is valid") + @api.expect(active_check_parser) + @api.response( + 200, + "Success", + api.model( + "ActivationCheckResponse", + { + "is_valid": fields.Boolean(description="Whether token is valid"), + "data": fields.Raw(description="Activation data if valid"), + }, + ), + ) def get(self): - parser = reqparse.RequestParser() - parser.add_argument("workspace_id", type=str, required=False, nullable=True, location="args") - parser.add_argument("email", type=email, required=False, nullable=True, location="args") - parser.add_argument("token", type=str, required=True, nullable=False, location="args") - args = parser.parse_args() + args = active_check_parser.parse_args() workspaceId = args["workspace_id"] reg_email = args["email"] @@ -38,18 +60,36 @@ class ActivateCheckApi(Resource): return {"is_valid": False} +active_parser = reqparse.RequestParser() +active_parser.add_argument("workspace_id", type=str, required=False, nullable=True, location="json") +active_parser.add_argument("email", type=email, required=False, nullable=True, location="json") +active_parser.add_argument("token", type=str, required=True, nullable=False, location="json") +active_parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json") +active_parser.add_argument( + "interface_language", type=supported_language, required=True, nullable=False, location="json" +) +active_parser.add_argument("timezone", type=timezone, required=True, nullable=False, location="json") + + +@console_ns.route("/activate") class ActivateApi(Resource): + @api.doc("activate_account") + @api.doc(description="Activate account with invitation token") + @api.expect(active_parser) + @api.response( + 200, + "Account activated successfully", + api.model( + "ActivationResponse", + { + "result": fields.String(description="Operation result"), + "data": fields.Raw(description="Login token data"), + }, + ), + ) + @api.response(400, "Already activated or invalid token") def post(self): - parser = reqparse.RequestParser() - parser.add_argument("workspace_id", type=str, required=False, nullable=True, location="json") - parser.add_argument("email", type=email, required=False, nullable=True, location="json") - parser.add_argument("token", type=str, required=True, nullable=False, location="json") - parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json") - parser.add_argument( - "interface_language", type=supported_language, required=True, nullable=False, location="json" - ) - parser.add_argument("timezone", type=timezone, required=True, nullable=False, location="json") - args = parser.parse_args() + args = active_parser.parse_args() invitation = RegisterService.get_invitation_if_token_valid(args["workspace_id"], args["email"], args["token"]) if invitation is None: @@ -70,7 +110,3 @@ class ActivateApi(Resource): token_pair = AccountService.login(account, ip_address=extract_remote_ip(request)) return {"result": "success", "data": token_pair.model_dump()} - - -api.add_resource(ActivateCheckApi, "/activate/check") -api.add_resource(ActivateApi, "/activate") diff --git a/api/controllers/console/auth/data_source_oauth.py b/api/controllers/console/auth/data_source_oauth.py index 8f57b3d03e..fc4ba3a2c7 100644 --- a/api/controllers/console/auth/data_source_oauth.py +++ b/api/controllers/console/auth/data_source_oauth.py @@ -3,11 +3,11 @@ import logging import requests from flask import current_app, redirect, request from flask_login import current_user -from flask_restx import Resource +from flask_restx import Resource, fields from werkzeug.exceptions import Forbidden from configs import dify_config -from controllers.console import api +from controllers.console import api, console_ns from libs.login import login_required from libs.oauth_data_source import NotionOAuth @@ -28,7 +28,21 @@ def get_oauth_providers(): return OAUTH_PROVIDERS +@console_ns.route("/oauth/data-source/") class OAuthDataSource(Resource): + @api.doc("oauth_data_source") + @api.doc(description="Get OAuth authorization URL for data source provider") + @api.doc(params={"provider": "Data source provider name (notion)"}) + @api.response( + 200, + "Authorization URL or internal setup success", + api.model( + "OAuthDataSourceResponse", + {"data": fields.Raw(description="Authorization URL or 'internal' for internal setup")}, + ), + ) + @api.response(400, "Invalid provider") + @api.response(403, "Admin privileges required") def get(self, provider: str): # The role of the current user in the table must be admin or owner if not current_user.is_admin_or_owner: @@ -49,7 +63,19 @@ class OAuthDataSource(Resource): return {"data": auth_url}, 200 +@console_ns.route("/oauth/data-source/callback/") class OAuthDataSourceCallback(Resource): + @api.doc("oauth_data_source_callback") + @api.doc(description="Handle OAuth callback from data source provider") + @api.doc( + params={ + "provider": "Data source provider name (notion)", + "code": "Authorization code from OAuth provider", + "error": "Error message from OAuth provider", + } + ) + @api.response(302, "Redirect to console with result") + @api.response(400, "Invalid provider") def get(self, provider: str): OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers() with current_app.app_context(): @@ -68,7 +94,19 @@ class OAuthDataSourceCallback(Resource): return redirect(f"{dify_config.CONSOLE_WEB_URL}?type=notion&error=Access denied") +@console_ns.route("/oauth/data-source/binding/") class OAuthDataSourceBinding(Resource): + @api.doc("oauth_data_source_binding") + @api.doc(description="Bind OAuth data source with authorization code") + @api.doc( + params={"provider": "Data source provider name (notion)", "code": "Authorization code from OAuth provider"} + ) + @api.response( + 200, + "Data source binding success", + api.model("OAuthDataSourceBindingResponse", {"result": fields.String(description="Operation result")}), + ) + @api.response(400, "Invalid provider or code") def get(self, provider: str): OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers() with current_app.app_context(): @@ -90,7 +128,17 @@ class OAuthDataSourceBinding(Resource): return {"result": "success"}, 200 +@console_ns.route("/oauth/data-source///sync") class OAuthDataSourceSync(Resource): + @api.doc("oauth_data_source_sync") + @api.doc(description="Sync data from OAuth data source") + @api.doc(params={"provider": "Data source provider name (notion)", "binding_id": "Data source binding ID"}) + @api.response( + 200, + "Data source sync success", + api.model("OAuthDataSourceSyncResponse", {"result": fields.String(description="Operation result")}), + ) + @api.response(400, "Invalid provider or sync failed") @setup_required @login_required @account_initialization_required @@ -111,9 +159,3 @@ class OAuthDataSourceSync(Resource): return {"error": "OAuth data source process failed"}, 400 return {"result": "success"}, 200 - - -api.add_resource(OAuthDataSource, "/oauth/data-source/") -api.add_resource(OAuthDataSourceCallback, "/oauth/data-source/callback/") -api.add_resource(OAuthDataSourceBinding, "/oauth/data-source/binding/") -api.add_resource(OAuthDataSourceSync, "/oauth/data-source///sync") diff --git a/api/controllers/console/auth/email_register.py b/api/controllers/console/auth/email_register.py new file mode 100644 index 0000000000..91de19a78a --- /dev/null +++ b/api/controllers/console/auth/email_register.py @@ -0,0 +1,155 @@ +from flask import request +from flask_restx import Resource, reqparse +from sqlalchemy import select +from sqlalchemy.orm import Session + +from configs import dify_config +from constants.languages import languages +from controllers.console import api +from controllers.console.auth.error import ( + EmailAlreadyInUseError, + EmailCodeError, + EmailRegisterLimitError, + InvalidEmailError, + InvalidTokenError, + PasswordMismatchError, +) +from controllers.console.error import AccountInFreezeError, EmailSendIpLimitError +from controllers.console.wraps import email_password_login_enabled, email_register_enabled, setup_required +from extensions.ext_database import db +from libs.helper import email, extract_remote_ip +from libs.password import valid_password +from models.account import Account +from services.account_service import AccountService +from services.billing_service import BillingService +from services.errors.account import AccountNotFoundError, AccountRegisterError + + +class EmailRegisterSendEmailApi(Resource): + @setup_required + @email_password_login_enabled + @email_register_enabled + def post(self): + parser = reqparse.RequestParser() + parser.add_argument("email", type=email, required=True, location="json") + parser.add_argument("language", type=str, required=False, location="json") + args = parser.parse_args() + + ip_address = extract_remote_ip(request) + if AccountService.is_email_send_ip_limit(ip_address): + raise EmailSendIpLimitError() + language = "en-US" + if args["language"] in languages: + language = args["language"] + + if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args["email"]): + raise AccountInFreezeError() + + with Session(db.engine) as session: + account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none() + token = None + token = AccountService.send_email_register_email(email=args["email"], account=account, language=language) + return {"result": "success", "data": token} + + +class EmailRegisterCheckApi(Resource): + @setup_required + @email_password_login_enabled + @email_register_enabled + def post(self): + parser = reqparse.RequestParser() + parser.add_argument("email", type=str, required=True, location="json") + parser.add_argument("code", type=str, required=True, location="json") + parser.add_argument("token", type=str, required=True, nullable=False, location="json") + args = parser.parse_args() + + user_email = args["email"] + + is_email_register_error_rate_limit = AccountService.is_email_register_error_rate_limit(args["email"]) + if is_email_register_error_rate_limit: + raise EmailRegisterLimitError() + + token_data = AccountService.get_email_register_data(args["token"]) + if token_data is None: + raise InvalidTokenError() + + if user_email != token_data.get("email"): + raise InvalidEmailError() + + if args["code"] != token_data.get("code"): + AccountService.add_email_register_error_rate_limit(args["email"]) + raise EmailCodeError() + + # Verified, revoke the first token + AccountService.revoke_email_register_token(args["token"]) + + # Refresh token data by generating a new token + _, new_token = AccountService.generate_email_register_token( + user_email, code=args["code"], additional_data={"phase": "register"} + ) + + AccountService.reset_email_register_error_rate_limit(args["email"]) + return {"is_valid": True, "email": token_data.get("email"), "token": new_token} + + +class EmailRegisterResetApi(Resource): + @setup_required + @email_password_login_enabled + @email_register_enabled + def post(self): + parser = reqparse.RequestParser() + parser.add_argument("token", type=str, required=True, nullable=False, location="json") + parser.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json") + parser.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json") + args = parser.parse_args() + + # Validate passwords match + if args["new_password"] != args["password_confirm"]: + raise PasswordMismatchError() + + # Validate token and get register data + register_data = AccountService.get_email_register_data(args["token"]) + if not register_data: + raise InvalidTokenError() + # Must use token in reset phase + if register_data.get("phase", "") != "register": + raise InvalidTokenError() + + # Revoke token to prevent reuse + AccountService.revoke_email_register_token(args["token"]) + + email = register_data.get("email", "") + + with Session(db.engine) as session: + account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none() + + if account: + raise EmailAlreadyInUseError() + else: + account = self._create_new_account(email, args["password_confirm"]) + if not account: + raise AccountNotFoundError() + token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request)) + AccountService.reset_login_error_rate_limit(email) + + return {"result": "success", "data": token_pair.model_dump()} + + def _create_new_account(self, email, password) -> Account | None: + # Create new account if allowed + account = None + try: + account = AccountService.create_account_and_tenant( + email=email, + name=email, + password=password, + interface_language=languages[0], + ) + except AccountRegisterError: + raise AccountInFreezeError() + + return account + + +api.add_resource(EmailRegisterSendEmailApi, "/email-register/send-email") +api.add_resource(EmailRegisterCheckApi, "/email-register/validity") +api.add_resource(EmailRegisterResetApi, "/email-register") diff --git a/api/controllers/console/auth/error.py b/api/controllers/console/auth/error.py index 7853bef917..81f1c6e70f 100644 --- a/api/controllers/console/auth/error.py +++ b/api/controllers/console/auth/error.py @@ -27,21 +27,43 @@ class InvalidTokenError(BaseHTTPException): class PasswordResetRateLimitExceededError(BaseHTTPException): error_code = "password_reset_rate_limit_exceeded" - description = "Too many password reset emails have been sent. Please try again in 1 minute." + description = "Too many password reset emails have been sent. Please try again in {minutes} minutes." code = 429 + def __init__(self, minutes: int = 1): + description = self.description.format(minutes=int(minutes)) if self.description else None + super().__init__(description=description) + + +class EmailRegisterRateLimitExceededError(BaseHTTPException): + error_code = "email_register_rate_limit_exceeded" + description = "Too many email register emails have been sent. Please try again in {minutes} minutes." + code = 429 + + def __init__(self, minutes: int = 1): + description = self.description.format(minutes=int(minutes)) if self.description else None + super().__init__(description=description) + class EmailChangeRateLimitExceededError(BaseHTTPException): error_code = "email_change_rate_limit_exceeded" - description = "Too many email change emails have been sent. Please try again in 1 minute." + description = "Too many email change emails have been sent. Please try again in {minutes} minutes." code = 429 + def __init__(self, minutes: int = 1): + description = self.description.format(minutes=int(minutes)) if self.description else None + super().__init__(description=description) + class OwnerTransferRateLimitExceededError(BaseHTTPException): error_code = "owner_transfer_rate_limit_exceeded" - description = "Too many owner transfer emails have been sent. Please try again in 1 minute." + description = "Too many owner transfer emails have been sent. Please try again in {minutes} minutes." code = 429 + def __init__(self, minutes: int = 1): + description = self.description.format(minutes=int(minutes)) if self.description else None + super().__init__(description=description) + class EmailCodeError(BaseHTTPException): error_code = "email_code_error" @@ -69,15 +91,23 @@ class EmailPasswordLoginLimitError(BaseHTTPException): class EmailCodeLoginRateLimitExceededError(BaseHTTPException): error_code = "email_code_login_rate_limit_exceeded" - description = "Too many login emails have been sent. Please try again in 5 minutes." + description = "Too many login emails have been sent. Please try again in {minutes} minutes." code = 429 + def __init__(self, minutes: int = 5): + description = self.description.format(minutes=int(minutes)) if self.description else None + super().__init__(description=description) + class EmailCodeAccountDeletionRateLimitExceededError(BaseHTTPException): error_code = "email_code_account_deletion_rate_limit_exceeded" - description = "Too many account deletion emails have been sent. Please try again in 5 minutes." + description = "Too many account deletion emails have been sent. Please try again in {minutes} minutes." code = 429 + def __init__(self, minutes: int = 5): + description = self.description.format(minutes=int(minutes)) if self.description else None + super().__init__(description=description) + class EmailPasswordResetLimitError(BaseHTTPException): error_code = "email_password_reset_limit" @@ -85,6 +115,12 @@ class EmailPasswordResetLimitError(BaseHTTPException): code = 429 +class EmailRegisterLimitError(BaseHTTPException): + error_code = "email_register_limit" + description = "Too many failed email register attempts. Please try again in 24 hours." + code = 429 + + class EmailChangeLimitError(BaseHTTPException): error_code = "email_change_limit" description = "Too many failed email change attempts. Please try again in 24 hours." diff --git a/api/controllers/console/auth/forgot_password.py b/api/controllers/console/auth/forgot_password.py index ede0696854..36ccb1d562 100644 --- a/api/controllers/console/auth/forgot_password.py +++ b/api/controllers/console/auth/forgot_password.py @@ -2,12 +2,11 @@ import base64 import secrets from flask import request -from flask_restx import Resource, reqparse +from flask_restx import Resource, fields, reqparse from sqlalchemy import select from sqlalchemy.orm import Session -from constants.languages import languages -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.auth.error import ( EmailCodeError, EmailPasswordResetLimitError, @@ -15,7 +14,7 @@ from controllers.console.auth.error import ( InvalidTokenError, PasswordMismatchError, ) -from controllers.console.error import AccountInFreezeError, AccountNotFound, EmailSendIpLimitError +from controllers.console.error import AccountNotFound, EmailSendIpLimitError from controllers.console.wraps import email_password_login_enabled, setup_required from events.tenant_event import tenant_was_created from extensions.ext_database import db @@ -23,12 +22,35 @@ from libs.helper import email, extract_remote_ip from libs.password import hash_password, valid_password from models.account import Account from services.account_service import AccountService, TenantService -from services.errors.account import AccountRegisterError -from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError from services.feature_service import FeatureService +@console_ns.route("/forgot-password") class ForgotPasswordSendEmailApi(Resource): + @api.doc("send_forgot_password_email") + @api.doc(description="Send password reset email") + @api.expect( + api.model( + "ForgotPasswordEmailRequest", + { + "email": fields.String(required=True, description="Email address"), + "language": fields.String(description="Language for email (zh-Hans/en-US)"), + }, + ) + ) + @api.response( + 200, + "Email sent successfully", + api.model( + "ForgotPasswordEmailResponse", + { + "result": fields.String(description="Operation result"), + "data": fields.String(description="Reset token"), + "code": fields.String(description="Error code if account not found"), + }, + ), + ) + @api.response(400, "Invalid email or rate limit exceeded") @setup_required @email_password_login_enabled def post(self): @@ -48,20 +70,44 @@ class ForgotPasswordSendEmailApi(Resource): with Session(db.engine) as session: account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none() - token = None - if account is None: - if FeatureService.get_system_features().is_allow_register: - token = AccountService.send_reset_password_email(email=args["email"], language=language) - return {"result": "fail", "data": token, "code": "account_not_found"} - else: - raise AccountNotFound() - else: - token = AccountService.send_reset_password_email(account=account, email=args["email"], language=language) + + token = AccountService.send_reset_password_email( + account=account, + email=args["email"], + language=language, + is_allow_register=FeatureService.get_system_features().is_allow_register, + ) return {"result": "success", "data": token} +@console_ns.route("/forgot-password/validity") class ForgotPasswordCheckApi(Resource): + @api.doc("check_forgot_password_code") + @api.doc(description="Verify password reset code") + @api.expect( + api.model( + "ForgotPasswordCheckRequest", + { + "email": fields.String(required=True, description="Email address"), + "code": fields.String(required=True, description="Verification code"), + "token": fields.String(required=True, description="Reset token"), + }, + ) + ) + @api.response( + 200, + "Code verified successfully", + api.model( + "ForgotPasswordCheckResponse", + { + "is_valid": fields.Boolean(description="Whether code is valid"), + "email": fields.String(description="Email address"), + "token": fields.String(description="New reset token"), + }, + ), + ) + @api.response(400, "Invalid code or token") @setup_required @email_password_login_enabled def post(self): @@ -100,7 +146,26 @@ class ForgotPasswordCheckApi(Resource): return {"is_valid": True, "email": token_data.get("email"), "token": new_token} +@console_ns.route("/forgot-password/resets") class ForgotPasswordResetApi(Resource): + @api.doc("reset_password") + @api.doc(description="Reset password with verification token") + @api.expect( + api.model( + "ForgotPasswordResetRequest", + { + "token": fields.String(required=True, description="Verification token"), + "new_password": fields.String(required=True, description="New password"), + "password_confirm": fields.String(required=True, description="Password confirmation"), + }, + ) + ) + @api.response( + 200, + "Password reset successfully", + api.model("ForgotPasswordResetResponse", {"result": fields.String(description="Operation result")}), + ) + @api.response(400, "Invalid token or password mismatch") @setup_required @email_password_login_enabled def post(self): @@ -137,7 +202,7 @@ class ForgotPasswordResetApi(Resource): if account: self._update_existing_account(account, password_hashed, salt, session) else: - self._create_new_account(email, args["password_confirm"]) + raise AccountNotFound() return {"result": "success"} @@ -157,22 +222,6 @@ class ForgotPasswordResetApi(Resource): account.current_tenant = tenant tenant_was_created.send(tenant) - def _create_new_account(self, email, password): - # Create new account if allowed - try: - AccountService.create_account_and_tenant( - email=email, - name=email, - password=password, - interface_language=languages[0], - ) - except WorkSpaceNotAllowedCreateError: - pass - except WorkspacesLimitExceededError: - pass - except AccountRegisterError: - raise AccountInFreezeError() - api.add_resource(ForgotPasswordSendEmailApi, "/forgot-password") api.add_resource(ForgotPasswordCheckApi, "/forgot-password/validity") diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py index b11bc0c6ac..3b35ab3c23 100644 --- a/api/controllers/console/auth/login.py +++ b/api/controllers/console/auth/login.py @@ -26,7 +26,6 @@ from controllers.console.error import ( from controllers.console.wraps import email_password_login_enabled, setup_required from events.tenant_event import tenant_was_created from libs.helper import email, extract_remote_ip -from libs.password import valid_password from models.account import Account from services.account_service import AccountService, RegisterService, TenantService from services.billing_service import BillingService @@ -44,10 +43,9 @@ class LoginApi(Resource): """Authenticate user and login.""" parser = reqparse.RequestParser() parser.add_argument("email", type=email, required=True, location="json") - parser.add_argument("password", type=valid_password, required=True, location="json") + parser.add_argument("password", type=str, required=True, location="json") parser.add_argument("remember_me", type=bool, required=False, default=False, location="json") parser.add_argument("invite_token", type=str, required=False, default=None, location="json") - parser.add_argument("language", type=str, required=False, default="en-US", location="json") args = parser.parse_args() if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args["email"]): @@ -61,11 +59,6 @@ class LoginApi(Resource): if invitation: invitation = RegisterService.get_invitation_if_token_valid(None, args["email"], invitation) - if args["language"] is not None and args["language"] == "zh-Hans": - language = "zh-Hans" - else: - language = "en-US" - try: if invitation: data = invitation.get("data", {}) @@ -80,12 +73,6 @@ class LoginApi(Resource): except services.errors.account.AccountPasswordError: AccountService.add_login_error_rate_limit(args["email"]) raise AuthenticationFailedError() - except services.errors.account.AccountNotFoundError: - if FeatureService.get_system_features().is_allow_register: - token = AccountService.send_reset_password_email(email=args["email"], language=language) - return {"result": "fail", "data": token, "code": "account_not_found"} - else: - raise AccountNotFound() # SELF_HOSTED only have one workspace tenants = TenantService.get_join_tenants(account) if len(tenants) == 0: @@ -133,13 +120,12 @@ class ResetPasswordSendEmailApi(Resource): except AccountRegisterError: raise AccountInFreezeError() - if account is None: - if FeatureService.get_system_features().is_allow_register: - token = AccountService.send_reset_password_email(email=args["email"], language=language) - else: - raise AccountNotFound() - else: - token = AccountService.send_reset_password_email(account=account, language=language) + token = AccountService.send_reset_password_email( + email=args["email"], + account=account, + language=language, + is_allow_register=FeatureService.get_system_features().is_allow_register, + ) return {"result": "success", "data": token} diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py index 332a98c474..4b94cc0eb6 100644 --- a/api/controllers/console/auth/oauth.py +++ b/api/controllers/console/auth/oauth.py @@ -18,11 +18,12 @@ from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo from models import Account from models.account import AccountStatus from services.account_service import AccountService, RegisterService, TenantService +from services.billing_service import BillingService from services.errors.account import AccountNotFoundError, AccountRegisterError from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkSpaceNotFoundError from services.feature_service import FeatureService -from .. import api +from .. import api, console_ns logger = logging.getLogger(__name__) @@ -50,7 +51,13 @@ def get_oauth_providers(): return OAUTH_PROVIDERS +@console_ns.route("/oauth/login/") class OAuthLogin(Resource): + @api.doc("oauth_login") + @api.doc(description="Initiate OAuth login process") + @api.doc(params={"provider": "OAuth provider name (github/google)", "invite_token": "Optional invitation token"}) + @api.response(302, "Redirect to OAuth authorization URL") + @api.response(400, "Invalid provider") def get(self, provider: str): invite_token = request.args.get("invite_token") or None OAUTH_PROVIDERS = get_oauth_providers() @@ -63,7 +70,19 @@ class OAuthLogin(Resource): return redirect(auth_url) +@console_ns.route("/oauth/authorize/") class OAuthCallback(Resource): + @api.doc("oauth_callback") + @api.doc(description="Handle OAuth callback and complete login process") + @api.doc( + params={ + "provider": "OAuth provider name (github/google)", + "code": "Authorization code from OAuth provider", + "state": "Optional state parameter (used for invite token)", + } + ) + @api.response(302, "Redirect to console with access token") + @api.response(400, "OAuth process failed") def get(self, provider: str): OAUTH_PROVIDERS = get_oauth_providers() with current_app.app_context(): @@ -77,6 +96,9 @@ class OAuthCallback(Resource): if state: invite_token = state + if not code: + return {"error": "Authorization code is required"}, 400 + try: token = oauth_provider.get_access_token(code) user_info = oauth_provider.get_user_info(token) @@ -86,7 +108,7 @@ class OAuthCallback(Resource): return {"error": "OAuth process failed"}, 400 if invite_token and RegisterService.is_valid_invite_token(invite_token): - invitation = RegisterService._get_invitation_by_token(token=invite_token) + invitation = RegisterService.get_invitation_by_token(token=invite_token) if invitation: invitation_email = invitation.get("email", None) if invitation_email != user_info.email: @@ -162,7 +184,15 @@ def _generate_account(provider: str, user_info: OAuthUserInfo): if not account: if not FeatureService.get_system_features().is_allow_register: - raise AccountNotFoundError() + if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(user_info.email): + raise AccountRegisterError( + description=( + "This email account has been deleted within the past " + "30 days and is temporarily unavailable for new account registration" + ) + ) + else: + raise AccountRegisterError(description=("Invalid email or password")) account_name = user_info.name or "Dify" account = RegisterService.register( email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider @@ -181,7 +211,3 @@ def _generate_account(provider: str, user_info: OAuthUserInfo): AccountService.link_account_integrate(provider, user_info.id, account) return account - - -api.add_resource(OAuthLogin, "/oauth/login/") -api.add_resource(OAuthCallback, "/oauth/authorize/") diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py index e4d5f1be6e..6e49bfa510 100644 --- a/api/controllers/console/datasets/data_source.py +++ b/api/controllers/console/datasets/data_source.py @@ -29,14 +29,12 @@ class DataSourceApi(Resource): @marshal_with(integrate_list_fields) def get(self): # get workspace data source integrates - data_source_integrates = ( - db.session.query(DataSourceOauthBinding) - .where( + data_source_integrates = db.session.scalars( + select(DataSourceOauthBinding).where( DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, DataSourceOauthBinding.disabled == False, ) - .all() - ) + ).all() base_url = request.url_root.rstrip("/") data_source_oauth_base_path = "/console/api/oauth/data-source" @@ -249,7 +247,7 @@ class DataSourceNotionDatasetSyncApi(Resource): documents = DocumentService.get_document_by_dataset_id(dataset_id_str) for document in documents: document_indexing_sync_task.delay(dataset_id_str, document.id) - return 200 + return {"result": "success"}, 200 class DataSourceNotionDocumentSyncApi(Resource): @@ -267,7 +265,7 @@ class DataSourceNotionDocumentSyncApi(Resource): if document is None: raise NotFound("Document not found.") document_indexing_sync_task.delay(dataset_id_str, document_id_str) - return 200 + return {"result": "success"}, 200 api.add_resource(DataSourceApi, "/data-source/integrates", "/data-source/integrates//") diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 11b7b1fec0..9fb092607a 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -2,6 +2,7 @@ import flask_restx from flask import request from flask_login import current_user from flask_restx import Resource, marshal, marshal_with, reqparse +from sqlalchemy import select from werkzeug.exceptions import Forbidden, NotFound import services @@ -411,11 +412,11 @@ class DatasetIndexingEstimateApi(Resource): extract_settings = [] if args["info_list"]["data_source_type"] == "upload_file": file_ids = args["info_list"]["file_info_list"]["file_ids"] - file_details = ( - db.session.query(UploadFile) - .where(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id.in_(file_ids)) - .all() - ) + file_details = db.session.scalars( + select(UploadFile).where( + UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id.in_(file_ids) + ) + ).all() if file_details is None: raise NotFound("File not found.") @@ -518,11 +519,11 @@ class DatasetIndexingStatusApi(Resource): @account_initialization_required def get(self, dataset_id): dataset_id = str(dataset_id) - documents = ( - db.session.query(Document) - .where(Document.dataset_id == dataset_id, Document.tenant_id == current_user.current_tenant_id) - .all() - ) + documents = db.session.scalars( + select(Document).where( + Document.dataset_id == dataset_id, Document.tenant_id == current_user.current_tenant_id + ) + ).all() documents_status = [] for document in documents: completed_segments = ( @@ -569,11 +570,11 @@ class DatasetApiKeyApi(Resource): @account_initialization_required @marshal_with(api_key_list) def get(self): - keys = ( - db.session.query(ApiToken) - .where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id) - .all() - ) + keys = db.session.scalars( + select(ApiToken).where( + ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id + ) + ).all() return {"items": keys} @setup_required diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index c9c0b6a5ce..f943fb3ccb 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -1,5 +1,6 @@ import logging from argparse import ArgumentTypeError +from collections.abc import Sequence from typing import Literal, cast from flask import request @@ -79,7 +80,7 @@ class DocumentResource(Resource): return document - def get_batch_documents(self, dataset_id: str, batch: str) -> list[Document]: + def get_batch_documents(self, dataset_id: str, batch: str) -> Sequence[Document]: dataset = DatasetService.get_dataset(dataset_id) if not dataset: raise NotFound("Dataset not found.") diff --git a/api/controllers/console/datasets/metadata.py b/api/controllers/console/datasets/metadata.py index 6aa309f930..21ab5e4fe1 100644 --- a/api/controllers/console/datasets/metadata.py +++ b/api/controllers/console/datasets/metadata.py @@ -113,7 +113,7 @@ class DatasetMetadataBuiltInFieldActionApi(Resource): MetadataService.enable_built_in_field(dataset) elif action == "disable": MetadataService.disable_built_in_field(dataset) - return 200 + return {"result": "success"}, 200 class DocumentMetadataEditApi(Resource): @@ -135,7 +135,7 @@ class DocumentMetadataEditApi(Resource): MetadataService.update_documents_metadata(dataset, metadata_args) - return 200 + return {"result": "success"}, 200 api.add_resource(DatasetMetadataCreateApi, "/datasets//metadata") diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py index cc46f54ea3..a99708b7cd 100644 --- a/api/controllers/console/explore/completion.py +++ b/api/controllers/console/explore/completion.py @@ -1,6 +1,5 @@ import logging -from flask_login import current_user from flask_restx import reqparse from werkzeug.exceptions import InternalServerError, NotFound @@ -28,6 +27,8 @@ from extensions.ext_database import db from libs import helper from libs.datetime_utils import naive_utc_now from libs.helper import uuid_value +from libs.login import current_user +from models import Account from models.model import AppMode from services.app_generate_service import AppGenerateService from services.errors.llm import InvokeRateLimitError @@ -57,6 +58,8 @@ class CompletionApi(InstalledAppResource): db.session.commit() try: + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") response = AppGenerateService.generate( app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=streaming ) @@ -90,6 +93,8 @@ class CompletionStopApi(InstalledAppResource): if app_model.mode != "completion": raise NotCompletionAppError() + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) return {"result": "success"}, 200 @@ -117,6 +122,8 @@ class ChatApi(InstalledAppResource): db.session.commit() try: + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") response = AppGenerateService.generate( app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True ) @@ -153,6 +160,8 @@ class ChatStopApi(InstalledAppResource): if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) return {"result": "success"}, 200 diff --git a/api/controllers/console/explore/conversation.py b/api/controllers/console/explore/conversation.py index 43ad3ecfbd..1aef9c544d 100644 --- a/api/controllers/console/explore/conversation.py +++ b/api/controllers/console/explore/conversation.py @@ -1,4 +1,3 @@ -from flask_login import current_user from flask_restx import marshal_with, reqparse from flask_restx.inputs import int_range from sqlalchemy.orm import Session @@ -10,6 +9,8 @@ from core.app.entities.app_invoke_entities import InvokeFrom from extensions.ext_database import db from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields from libs.helper import uuid_value +from libs.login import current_user +from models import Account from models.model import AppMode from services.conversation_service import ConversationService from services.errors.conversation import ConversationNotExistsError, LastConversationNotExistsError @@ -35,6 +36,8 @@ class ConversationListApi(InstalledAppResource): pinned = args["pinned"] == "true" try: + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") with Session(db.engine) as session: return WebConversationService.pagination_by_last_id( session=session, @@ -58,6 +61,8 @@ class ConversationApi(InstalledAppResource): conversation_id = str(c_id) try: + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") ConversationService.delete(app_model, conversation_id, current_user) except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") @@ -81,6 +86,8 @@ class ConversationRenameApi(InstalledAppResource): args = parser.parse_args() try: + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") return ConversationService.rename( app_model, conversation_id, current_user, args["name"], args["auto_generate"] ) @@ -98,6 +105,8 @@ class ConversationPinApi(InstalledAppResource): conversation_id = str(c_id) try: + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") WebConversationService.pin(app_model, conversation_id, current_user) except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") @@ -113,6 +122,8 @@ class ConversationUnPinApi(InstalledAppResource): raise NotChatAppError() conversation_id = str(c_id) + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") WebConversationService.unpin(app_model, conversation_id, current_user) return {"result": "success"} diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py index 3ccedd654b..bdc3fb0dbd 100644 --- a/api/controllers/console/explore/installed_app.py +++ b/api/controllers/console/explore/installed_app.py @@ -2,9 +2,8 @@ import logging from typing import Any from flask import request -from flask_login import current_user from flask_restx import Resource, inputs, marshal_with, reqparse -from sqlalchemy import and_ +from sqlalchemy import and_, select from werkzeug.exceptions import BadRequest, Forbidden, NotFound from controllers.console import api @@ -13,8 +12,8 @@ from controllers.console.wraps import account_initialization_required, cloud_edi from extensions.ext_database import db from fields.installed_app_fields import installed_app_list_fields from libs.datetime_utils import naive_utc_now -from libs.login import login_required -from models import App, InstalledApp, RecommendedApp +from libs.login import current_user, login_required +from models import Account, App, InstalledApp, RecommendedApp from services.account_service import TenantService from services.app_service import AppService from services.enterprise.enterprise_service import EnterpriseService @@ -29,17 +28,23 @@ class InstalledAppsListApi(Resource): @marshal_with(installed_app_list_fields) def get(self): app_id = request.args.get("app_id", default=None, type=str) + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") current_tenant_id = current_user.current_tenant_id if app_id: - installed_apps = ( - db.session.query(InstalledApp) - .where(and_(InstalledApp.tenant_id == current_tenant_id, InstalledApp.app_id == app_id)) - .all() - ) + installed_apps = db.session.scalars( + select(InstalledApp).where( + and_(InstalledApp.tenant_id == current_tenant_id, InstalledApp.app_id == app_id) + ) + ).all() else: - installed_apps = db.session.query(InstalledApp).where(InstalledApp.tenant_id == current_tenant_id).all() + installed_apps = db.session.scalars( + select(InstalledApp).where(InstalledApp.tenant_id == current_tenant_id) + ).all() + if current_user.current_tenant is None: + raise ValueError("current_user.current_tenant must not be None") current_user.role = TenantService.get_user_role(current_user, current_user.current_tenant) installed_app_list: list[dict[str, Any]] = [ { @@ -115,6 +120,8 @@ class InstalledAppsListApi(Resource): if recommended_app is None: raise NotFound("App not found") + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") current_tenant_id = current_user.current_tenant_id app = db.session.query(App).where(App.id == args["app_id"]).first() @@ -154,6 +161,8 @@ class InstalledAppApi(InstalledAppResource): """ def delete(self, installed_app): + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") if installed_app.app_owner_tenant_id == current_user.current_tenant_id: raise BadRequest("You can't uninstall an app owned by the current tenant") diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py index 608bc6d007..c46c1c1f4f 100644 --- a/api/controllers/console/explore/message.py +++ b/api/controllers/console/explore/message.py @@ -1,6 +1,5 @@ import logging -from flask_login import current_user from flask_restx import marshal_with, reqparse from flask_restx.inputs import int_range from werkzeug.exceptions import InternalServerError, NotFound @@ -24,6 +23,8 @@ from core.model_runtime.errors.invoke import InvokeError from fields.message_fields import message_infinite_scroll_pagination_fields from libs import helper from libs.helper import uuid_value +from libs.login import current_user +from models import Account from models.model import AppMode from services.app_generate_service import AppGenerateService from services.errors.app import MoreLikeThisDisabledError @@ -54,6 +55,8 @@ class MessageListApi(InstalledAppResource): args = parser.parse_args() try: + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") return MessageService.pagination_by_first_id( app_model, current_user, args["conversation_id"], args["first_id"], args["limit"] ) @@ -75,6 +78,8 @@ class MessageFeedbackApi(InstalledAppResource): args = parser.parse_args() try: + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") MessageService.create_feedback( app_model=app_model, message_id=message_id, @@ -105,6 +110,8 @@ class MessageMoreLikeThisApi(InstalledAppResource): streaming = args["response_mode"] == "streaming" try: + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") response = AppGenerateService.generate_more_like_this( app_model=app_model, user=current_user, @@ -142,6 +149,8 @@ class MessageSuggestedQuestionApi(InstalledAppResource): message_id = str(message_id) try: + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") questions = MessageService.get_suggested_questions_after_answer( app_model=app_model, user=current_user, message_id=message_id, invoke_from=InvokeFrom.EXPLORE ) diff --git a/api/controllers/console/explore/parameter.py b/api/controllers/console/explore/parameter.py index d9afb5bab2..7742ea24a9 100644 --- a/api/controllers/console/explore/parameter.py +++ b/api/controllers/console/explore/parameter.py @@ -20,7 +20,7 @@ class AppParameterApi(InstalledAppResource): if app_model is None: raise AppUnavailableError() - if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: + if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: workflow = app_model.workflow if workflow is None: raise AppUnavailableError() diff --git a/api/controllers/console/explore/recommended_app.py b/api/controllers/console/explore/recommended_app.py index 62f9350b71..974222ddf7 100644 --- a/api/controllers/console/explore/recommended_app.py +++ b/api/controllers/console/explore/recommended_app.py @@ -1,11 +1,10 @@ -from flask_login import current_user from flask_restx import Resource, fields, marshal_with, reqparse from constants.languages import languages from controllers.console import api from controllers.console.wraps import account_initialization_required from libs.helper import AppIconUrlField -from libs.login import login_required +from libs.login import current_user, login_required from services.recommended_app_service import RecommendedAppService app_fields = { @@ -46,8 +45,9 @@ class RecommendedAppListApi(Resource): parser.add_argument("language", type=str, location="args") args = parser.parse_args() - if args.get("language") and args.get("language") in languages: - language_prefix = args.get("language") + language = args.get("language") + if language and language in languages: + language_prefix = language elif current_user and current_user.interface_language: language_prefix = current_user.interface_language else: diff --git a/api/controllers/console/explore/saved_message.py b/api/controllers/console/explore/saved_message.py index 5353dbcad5..6f05f898f9 100644 --- a/api/controllers/console/explore/saved_message.py +++ b/api/controllers/console/explore/saved_message.py @@ -1,4 +1,3 @@ -from flask_login import current_user from flask_restx import fields, marshal_with, reqparse from flask_restx.inputs import int_range from werkzeug.exceptions import NotFound @@ -8,6 +7,8 @@ from controllers.console.explore.error import NotCompletionAppError from controllers.console.explore.wraps import InstalledAppResource from fields.conversation_fields import message_file_fields from libs.helper import TimestampField, uuid_value +from libs.login import current_user +from models import Account from services.errors.message import MessageNotExistsError from services.saved_message_service import SavedMessageService @@ -42,6 +43,8 @@ class SavedMessageListApi(InstalledAppResource): parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") args = parser.parse_args() + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") return SavedMessageService.pagination_by_last_id(app_model, current_user, args["last_id"], args["limit"]) def post(self, installed_app): @@ -54,6 +57,8 @@ class SavedMessageListApi(InstalledAppResource): args = parser.parse_args() try: + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") SavedMessageService.save(app_model, current_user, args["message_id"]) except MessageNotExistsError: raise NotFound("Message Not Exists.") @@ -70,6 +75,8 @@ class SavedMessageApi(InstalledAppResource): if app_model.mode != "completion": raise NotCompletionAppError() + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") SavedMessageService.delete(app_model, current_user, message_id) return {"result": "success"}, 204 diff --git a/api/controllers/console/extension.py b/api/controllers/console/extension.py index e157041c35..57f5ab191e 100644 --- a/api/controllers/console/extension.py +++ b/api/controllers/console/extension.py @@ -1,8 +1,8 @@ from flask_login import current_user -from flask_restx import Resource, marshal_with, reqparse +from flask_restx import Resource, fields, marshal_with, reqparse from constants import HIDDEN_VALUE -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.wraps import account_initialization_required, setup_required from fields.api_based_extension_fields import api_based_extension_fields from libs.login import login_required @@ -11,7 +11,21 @@ from services.api_based_extension_service import APIBasedExtensionService from services.code_based_extension_service import CodeBasedExtensionService +@console_ns.route("/code-based-extension") class CodeBasedExtensionAPI(Resource): + @api.doc("get_code_based_extension") + @api.doc(description="Get code-based extension data by module name") + @api.expect( + api.parser().add_argument("module", type=str, required=True, location="args", help="Extension module name") + ) + @api.response( + 200, + "Success", + api.model( + "CodeBasedExtensionResponse", + {"module": fields.String(description="Module name"), "data": fields.Raw(description="Extension data")}, + ), + ) @setup_required @login_required @account_initialization_required @@ -23,7 +37,11 @@ class CodeBasedExtensionAPI(Resource): return {"module": args["module"], "data": CodeBasedExtensionService.get_code_based_extension(args["module"])} +@console_ns.route("/api-based-extension") class APIBasedExtensionAPI(Resource): + @api.doc("get_api_based_extensions") + @api.doc(description="Get all API-based extensions for current tenant") + @api.response(200, "Success", fields.List(fields.Nested(api_based_extension_fields))) @setup_required @login_required @account_initialization_required @@ -32,6 +50,19 @@ class APIBasedExtensionAPI(Resource): tenant_id = current_user.current_tenant_id return APIBasedExtensionService.get_all_by_tenant_id(tenant_id) + @api.doc("create_api_based_extension") + @api.doc(description="Create a new API-based extension") + @api.expect( + api.model( + "CreateAPIBasedExtensionRequest", + { + "name": fields.String(required=True, description="Extension name"), + "api_endpoint": fields.String(required=True, description="API endpoint URL"), + "api_key": fields.String(required=True, description="API key for authentication"), + }, + ) + ) + @api.response(201, "Extension created successfully", api_based_extension_fields) @setup_required @login_required @account_initialization_required @@ -53,7 +84,12 @@ class APIBasedExtensionAPI(Resource): return APIBasedExtensionService.save(extension_data) +@console_ns.route("/api-based-extension/") class APIBasedExtensionDetailAPI(Resource): + @api.doc("get_api_based_extension") + @api.doc(description="Get API-based extension by ID") + @api.doc(params={"id": "Extension ID"}) + @api.response(200, "Success", api_based_extension_fields) @setup_required @login_required @account_initialization_required @@ -64,6 +100,20 @@ class APIBasedExtensionDetailAPI(Resource): return APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id) + @api.doc("update_api_based_extension") + @api.doc(description="Update API-based extension") + @api.doc(params={"id": "Extension ID"}) + @api.expect( + api.model( + "UpdateAPIBasedExtensionRequest", + { + "name": fields.String(required=True, description="Extension name"), + "api_endpoint": fields.String(required=True, description="API endpoint URL"), + "api_key": fields.String(required=True, description="API key for authentication"), + }, + ) + ) + @api.response(200, "Extension updated successfully", api_based_extension_fields) @setup_required @login_required @account_initialization_required @@ -88,6 +138,10 @@ class APIBasedExtensionDetailAPI(Resource): return APIBasedExtensionService.save(extension_data_from_db) + @api.doc("delete_api_based_extension") + @api.doc(description="Delete API-based extension") + @api.doc(params={"id": "Extension ID"}) + @api.response(204, "Extension deleted successfully") @setup_required @login_required @account_initialization_required @@ -100,9 +154,3 @@ class APIBasedExtensionDetailAPI(Resource): APIBasedExtensionService.delete(extension_data_from_db) return {"result": "success"}, 204 - - -api.add_resource(CodeBasedExtensionAPI, "/code-based-extension") - -api.add_resource(APIBasedExtensionAPI, "/api-based-extension") -api.add_resource(APIBasedExtensionDetailAPI, "/api-based-extension/") diff --git a/api/controllers/console/feature.py b/api/controllers/console/feature.py index 6236832d39..d43b839291 100644 --- a/api/controllers/console/feature.py +++ b/api/controllers/console/feature.py @@ -1,26 +1,40 @@ from flask_login import current_user -from flask_restx import Resource +from flask_restx import Resource, fields from libs.login import login_required from services.feature_service import FeatureService -from . import api +from . import api, console_ns from .wraps import account_initialization_required, cloud_utm_record, setup_required +@console_ns.route("/features") class FeatureApi(Resource): + @api.doc("get_tenant_features") + @api.doc(description="Get feature configuration for current tenant") + @api.response( + 200, + "Success", + api.model("FeatureResponse", {"features": fields.Raw(description="Feature configuration object")}), + ) @setup_required @login_required @account_initialization_required @cloud_utm_record def get(self): + """Get feature configuration for current tenant""" return FeatureService.get_features(current_user.current_tenant_id).model_dump() +@console_ns.route("/system-features") class SystemFeatureApi(Resource): + @api.doc("get_system_features") + @api.doc(description="Get system-wide feature configuration") + @api.response( + 200, + "Success", + api.model("SystemFeatureResponse", {"features": fields.Raw(description="System feature configuration object")}), + ) def get(self): + """Get system-wide feature configuration""" return FeatureService.get_system_features().model_dump() - - -api.add_resource(FeatureApi, "/features") -api.add_resource(SystemFeatureApi, "/system-features") diff --git a/api/controllers/console/files.py b/api/controllers/console/files.py index 101a49a32e..5d11dec523 100644 --- a/api/controllers/console/files.py +++ b/api/controllers/console/files.py @@ -22,6 +22,7 @@ from controllers.console.wraps import ( ) from fields.file_fields import file_fields, upload_config_fields from libs.login import login_required +from models import Account from services.file_service import FileService PREVIEW_WORDS_LIMIT = 3000 @@ -68,6 +69,8 @@ class FileApi(Resource): source = None try: + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") upload_file = FileService.upload_file( filename=file.filename, content=file.read(), diff --git a/api/controllers/console/init_validate.py b/api/controllers/console/init_validate.py index 2a37b1708a..30b53458b2 100644 --- a/api/controllers/console/init_validate.py +++ b/api/controllers/console/init_validate.py @@ -1,7 +1,7 @@ import os from flask import session -from flask_restx import Resource, reqparse +from flask_restx import Resource, fields, reqparse from sqlalchemy import select from sqlalchemy.orm import Session @@ -11,20 +11,47 @@ from libs.helper import StrLen from models.model import DifySetup from services.account_service import TenantService -from . import api +from . import api, console_ns from .error import AlreadySetupError, InitValidateFailedError from .wraps import only_edition_self_hosted +@console_ns.route("/init") class InitValidateAPI(Resource): + @api.doc("get_init_status") + @api.doc(description="Get initialization validation status") + @api.response( + 200, + "Success", + model=api.model( + "InitStatusResponse", + {"status": fields.String(description="Initialization status", enum=["finished", "not_started"])}, + ), + ) def get(self): + """Get initialization validation status""" init_status = get_init_validate_status() if init_status: return {"status": "finished"} return {"status": "not_started"} + @api.doc("validate_init_password") + @api.doc(description="Validate initialization password for self-hosted edition") + @api.expect( + api.model( + "InitValidateRequest", + {"password": fields.String(required=True, description="Initialization password", max_length=30)}, + ) + ) + @api.response( + 201, + "Success", + model=api.model("InitValidateResponse", {"result": fields.String(description="Operation result")}), + ) + @api.response(400, "Already setup or validation failed") @only_edition_self_hosted def post(self): + """Validate initialization password""" # is tenant created tenant_count = TenantService.get_tenant_count() if tenant_count > 0: @@ -52,6 +79,3 @@ def get_init_validate_status(): return db_session.execute(select(DifySetup)).scalar_one_or_none() return True - - -api.add_resource(InitValidateAPI, "/init") diff --git a/api/controllers/console/ping.py b/api/controllers/console/ping.py index 1a53a2347e..29f49b99de 100644 --- a/api/controllers/console/ping.py +++ b/api/controllers/console/ping.py @@ -1,14 +1,17 @@ -from flask_restx import Resource +from flask_restx import Resource, fields -from controllers.console import api +from . import api, console_ns +@console_ns.route("/ping") class PingApi(Resource): + @api.doc("health_check") + @api.doc(description="Health check endpoint for connection testing") + @api.response( + 200, + "Success", + api.model("PingResponse", {"result": fields.String(description="Health check result", example="pong")}), + ) def get(self): - """ - For connection health check - """ + """Health check endpoint for connection testing""" return {"result": "pong"} - - -api.add_resource(PingApi, "/ping") diff --git a/api/controllers/console/setup.py b/api/controllers/console/setup.py index 8e230496f0..bff5fc1651 100644 --- a/api/controllers/console/setup.py +++ b/api/controllers/console/setup.py @@ -1,5 +1,5 @@ from flask import request -from flask_restx import Resource, reqparse +from flask_restx import Resource, fields, reqparse from configs import dify_config from libs.helper import StrLen, email, extract_remote_ip @@ -7,23 +7,56 @@ from libs.password import valid_password from models.model import DifySetup, db from services.account_service import RegisterService, TenantService -from . import api +from . import api, console_ns from .error import AlreadySetupError, NotInitValidateError from .init_validate import get_init_validate_status from .wraps import only_edition_self_hosted +@console_ns.route("/setup") class SetupApi(Resource): + @api.doc("get_setup_status") + @api.doc(description="Get system setup status") + @api.response( + 200, + "Success", + api.model( + "SetupStatusResponse", + { + "step": fields.String(description="Setup step status", enum=["not_started", "finished"]), + "setup_at": fields.String(description="Setup completion time (ISO format)", required=False), + }, + ), + ) def get(self): + """Get system setup status""" if dify_config.EDITION == "SELF_HOSTED": setup_status = get_setup_status() - if setup_status: + # Check if setup_status is a DifySetup object rather than a bool + if setup_status and not isinstance(setup_status, bool): return {"step": "finished", "setup_at": setup_status.setup_at.isoformat()} + elif setup_status: + return {"step": "finished"} return {"step": "not_started"} return {"step": "finished"} + @api.doc("setup_system") + @api.doc(description="Initialize system setup with admin account") + @api.expect( + api.model( + "SetupRequest", + { + "email": fields.String(required=True, description="Admin email address"), + "name": fields.String(required=True, description="Admin name (max 30 characters)"), + "password": fields.String(required=True, description="Admin password"), + }, + ) + ) + @api.response(201, "Success", api.model("SetupResponse", {"result": fields.String(description="Setup result")})) + @api.response(400, "Already setup or validation failed") @only_edition_self_hosted def post(self): + """Initialize system setup with admin account""" # is set up if get_setup_status(): raise AlreadySetupError() @@ -55,6 +88,3 @@ def get_setup_status(): return db.session.query(DifySetup).first() else: return True - - -api.add_resource(SetupApi, "/setup") diff --git a/api/controllers/console/tag/tags.py b/api/controllers/console/tag/tags.py index c45e7dbb26..da236ee5af 100644 --- a/api/controllers/console/tag/tags.py +++ b/api/controllers/console/tag/tags.py @@ -111,7 +111,7 @@ class TagBindingCreateApi(Resource): args = parser.parse_args() TagService.save_tag_binding(args) - return 200 + return {"result": "success"}, 200 class TagBindingDeleteApi(Resource): @@ -132,7 +132,7 @@ class TagBindingDeleteApi(Resource): args = parser.parse_args() TagService.delete_tag_binding(args) - return 200 + return {"result": "success"}, 200 api.add_resource(TagListApi, "/tags") diff --git a/api/controllers/console/version.py b/api/controllers/console/version.py index 95515c38f9..8d081ad995 100644 --- a/api/controllers/console/version.py +++ b/api/controllers/console/version.py @@ -2,18 +2,41 @@ import json import logging import requests -from flask_restx import Resource, reqparse +from flask_restx import Resource, fields, reqparse from packaging import version from configs import dify_config -from . import api +from . import api, console_ns logger = logging.getLogger(__name__) +@console_ns.route("/version") class VersionApi(Resource): + @api.doc("check_version_update") + @api.doc(description="Check for application version updates") + @api.expect( + api.parser().add_argument( + "current_version", type=str, required=True, location="args", help="Current application version" + ) + ) + @api.response( + 200, + "Success", + api.model( + "VersionResponse", + { + "version": fields.String(description="Latest version number"), + "release_date": fields.String(description="Release date of latest version"), + "release_notes": fields.String(description="Release notes for latest version"), + "can_auto_update": fields.Boolean(description="Whether auto-update is supported"), + "features": fields.Raw(description="Feature flags and capabilities"), + }, + ), + ) def get(self): + """Check for application version updates""" parser = reqparse.RequestParser() parser.add_argument("current_version", type=str, required=True, location="args") args = parser.parse_args() @@ -34,14 +57,14 @@ class VersionApi(Resource): return result try: - response = requests.get(check_update_url, {"current_version": args.get("current_version")}, timeout=(3, 10)) + response = requests.get(check_update_url, {"current_version": args["current_version"]}, timeout=(3, 10)) except Exception as error: logger.warning("Check update version error: %s.", str(error)) - result["version"] = args.get("current_version") + result["version"] = args["current_version"] return result content = json.loads(response.content) - if _has_new_version(latest_version=content["version"], current_version=f"{args.get('current_version')}"): + if _has_new_version(latest_version=content["version"], current_version=f"{args['current_version']}"): result["version"] = content["version"] result["release_date"] = content["releaseDate"] result["release_notes"] = content["releaseNotes"] @@ -59,6 +82,3 @@ def _has_new_version(*, latest_version: str, current_version: str) -> bool: except version.InvalidVersion: logger.warning("Invalid version format: latest=%s, current=%s", latest_version, current_version) return False - - -api.add_resource(VersionApi, "/version") diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index 5b2828dbab..7a41a8a5cc 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -49,6 +49,8 @@ class AccountInitApi(Resource): @setup_required @login_required def post(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") account = current_user if account.status == "active": @@ -102,6 +104,8 @@ class AccountProfileApi(Resource): @marshal_with(account_fields) @enterprise_license_required def get(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") return current_user @@ -111,6 +115,8 @@ class AccountNameApi(Resource): @account_initialization_required @marshal_with(account_fields) def post(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") parser = reqparse.RequestParser() parser.add_argument("name", type=str, required=True, location="json") args = parser.parse_args() @@ -130,6 +136,8 @@ class AccountAvatarApi(Resource): @account_initialization_required @marshal_with(account_fields) def post(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") parser = reqparse.RequestParser() parser.add_argument("avatar", type=str, required=True, location="json") args = parser.parse_args() @@ -145,6 +153,8 @@ class AccountInterfaceLanguageApi(Resource): @account_initialization_required @marshal_with(account_fields) def post(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") parser = reqparse.RequestParser() parser.add_argument("interface_language", type=supported_language, required=True, location="json") args = parser.parse_args() @@ -160,6 +170,8 @@ class AccountInterfaceThemeApi(Resource): @account_initialization_required @marshal_with(account_fields) def post(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") parser = reqparse.RequestParser() parser.add_argument("interface_theme", type=str, choices=["light", "dark"], required=True, location="json") args = parser.parse_args() @@ -175,6 +187,8 @@ class AccountTimezoneApi(Resource): @account_initialization_required @marshal_with(account_fields) def post(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") parser = reqparse.RequestParser() parser.add_argument("timezone", type=str, required=True, location="json") args = parser.parse_args() @@ -194,6 +208,8 @@ class AccountPasswordApi(Resource): @account_initialization_required @marshal_with(account_fields) def post(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") parser = reqparse.RequestParser() parser.add_argument("password", type=str, required=False, location="json") parser.add_argument("new_password", type=str, required=True, location="json") @@ -228,9 +244,13 @@ class AccountIntegrateApi(Resource): @account_initialization_required @marshal_with(integrate_list_fields) def get(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") account = current_user - account_integrates = db.session.query(AccountIntegrate).where(AccountIntegrate.account_id == account.id).all() + account_integrates = db.session.scalars( + select(AccountIntegrate).where(AccountIntegrate.account_id == account.id) + ).all() base_url = request.url_root.rstrip("/") oauth_base_path = "/console/api/oauth/login" @@ -268,6 +288,8 @@ class AccountDeleteVerifyApi(Resource): @login_required @account_initialization_required def get(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") account = current_user token, code = AccountService.generate_account_deletion_verification_code(account) @@ -281,6 +303,8 @@ class AccountDeleteApi(Resource): @login_required @account_initialization_required def post(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") account = current_user parser = reqparse.RequestParser() @@ -321,6 +345,8 @@ class EducationVerifyApi(Resource): @cloud_edition_billing_enabled @marshal_with(verify_fields) def get(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") account = current_user return BillingService.EducationIdentity.verify(account.id, account.email) @@ -340,6 +366,8 @@ class EducationApi(Resource): @only_edition_cloud @cloud_edition_billing_enabled def post(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") account = current_user parser = reqparse.RequestParser() @@ -357,6 +385,8 @@ class EducationApi(Resource): @cloud_edition_billing_enabled @marshal_with(status_fields) def get(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") account = current_user res = BillingService.EducationIdentity.status(account.id) @@ -421,6 +451,8 @@ class ChangeEmailSendEmailApi(Resource): raise InvalidTokenError() user_email = reset_data.get("email", "") + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") if user_email != current_user.email: raise InvalidEmailError() else: @@ -501,6 +533,8 @@ class ChangeEmailResetApi(Resource): AccountService.revoke_change_email_token(args["token"]) old_email = reset_data.get("old_email", "") + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") if current_user.email != old_email: raise AccountNotFound() diff --git a/api/controllers/console/workspace/agent_providers.py b/api/controllers/console/workspace/agent_providers.py index 08bab6fcb5..0a2c8fcfb4 100644 --- a/api/controllers/console/workspace/agent_providers.py +++ b/api/controllers/console/workspace/agent_providers.py @@ -1,14 +1,22 @@ from flask_login import current_user -from flask_restx import Resource +from flask_restx import Resource, fields -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.wraps import account_initialization_required, setup_required from core.model_runtime.utils.encoders import jsonable_encoder from libs.login import login_required from services.agent_service import AgentService +@console_ns.route("/workspaces/current/agent-providers") class AgentProviderListApi(Resource): + @api.doc("list_agent_providers") + @api.doc(description="Get list of available agent providers") + @api.response( + 200, + "Success", + fields.List(fields.Raw(description="Agent provider information")), + ) @setup_required @login_required @account_initialization_required @@ -21,7 +29,16 @@ class AgentProviderListApi(Resource): return jsonable_encoder(AgentService.list_agent_providers(user_id, tenant_id)) +@console_ns.route("/workspaces/current/agent-provider/") class AgentProviderApi(Resource): + @api.doc("get_agent_provider") + @api.doc(description="Get specific agent provider details") + @api.doc(params={"provider_name": "Agent provider name"}) + @api.response( + 200, + "Success", + fields.Raw(description="Agent provider details"), + ) @setup_required @login_required @account_initialization_required @@ -30,7 +47,3 @@ class AgentProviderApi(Resource): user_id = user.id tenant_id = user.current_tenant_id return jsonable_encoder(AgentService.get_agent_provider(user_id, tenant_id, provider_name)) - - -api.add_resource(AgentProviderListApi, "/workspaces/current/agent-providers") -api.add_resource(AgentProviderApi, "/workspaces/current/agent-provider/") diff --git a/api/controllers/console/workspace/endpoint.py b/api/controllers/console/workspace/endpoint.py index 96e873d42b..0657b764cc 100644 --- a/api/controllers/console/workspace/endpoint.py +++ b/api/controllers/console/workspace/endpoint.py @@ -1,8 +1,8 @@ from flask_login import current_user -from flask_restx import Resource, reqparse +from flask_restx import Resource, fields, reqparse from werkzeug.exceptions import Forbidden -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.wraps import account_initialization_required, setup_required from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.impl.exc import PluginPermissionDeniedError @@ -10,7 +10,26 @@ from libs.login import login_required from services.plugin.endpoint_service import EndpointService +@console_ns.route("/workspaces/current/endpoints/create") class EndpointCreateApi(Resource): + @api.doc("create_endpoint") + @api.doc(description="Create a new plugin endpoint") + @api.expect( + api.model( + "EndpointCreateRequest", + { + "plugin_unique_identifier": fields.String(required=True, description="Plugin unique identifier"), + "settings": fields.Raw(required=True, description="Endpoint settings"), + "name": fields.String(required=True, description="Endpoint name"), + }, + ) + ) + @api.response( + 200, + "Endpoint created successfully", + api.model("EndpointCreateResponse", {"success": fields.Boolean(description="Operation success")}), + ) + @api.response(403, "Admin privileges required") @setup_required @login_required @account_initialization_required @@ -43,7 +62,20 @@ class EndpointCreateApi(Resource): raise ValueError(e.description) from e +@console_ns.route("/workspaces/current/endpoints/list") class EndpointListApi(Resource): + @api.doc("list_endpoints") + @api.doc(description="List plugin endpoints with pagination") + @api.expect( + api.parser() + .add_argument("page", type=int, required=True, location="args", help="Page number") + .add_argument("page_size", type=int, required=True, location="args", help="Page size") + ) + @api.response( + 200, + "Success", + api.model("EndpointListResponse", {"endpoints": fields.List(fields.Raw(description="Endpoint information"))}), + ) @setup_required @login_required @account_initialization_required @@ -70,7 +102,23 @@ class EndpointListApi(Resource): ) +@console_ns.route("/workspaces/current/endpoints/list/plugin") class EndpointListForSinglePluginApi(Resource): + @api.doc("list_plugin_endpoints") + @api.doc(description="List endpoints for a specific plugin") + @api.expect( + api.parser() + .add_argument("page", type=int, required=True, location="args", help="Page number") + .add_argument("page_size", type=int, required=True, location="args", help="Page size") + .add_argument("plugin_id", type=str, required=True, location="args", help="Plugin ID") + ) + @api.response( + 200, + "Success", + api.model( + "PluginEndpointListResponse", {"endpoints": fields.List(fields.Raw(description="Endpoint information"))} + ), + ) @setup_required @login_required @account_initialization_required @@ -100,7 +148,19 @@ class EndpointListForSinglePluginApi(Resource): ) +@console_ns.route("/workspaces/current/endpoints/delete") class EndpointDeleteApi(Resource): + @api.doc("delete_endpoint") + @api.doc(description="Delete a plugin endpoint") + @api.expect( + api.model("EndpointDeleteRequest", {"endpoint_id": fields.String(required=True, description="Endpoint ID")}) + ) + @api.response( + 200, + "Endpoint deleted successfully", + api.model("EndpointDeleteResponse", {"success": fields.Boolean(description="Operation success")}), + ) + @api.response(403, "Admin privileges required") @setup_required @login_required @account_initialization_required @@ -123,7 +183,26 @@ class EndpointDeleteApi(Resource): } +@console_ns.route("/workspaces/current/endpoints/update") class EndpointUpdateApi(Resource): + @api.doc("update_endpoint") + @api.doc(description="Update a plugin endpoint") + @api.expect( + api.model( + "EndpointUpdateRequest", + { + "endpoint_id": fields.String(required=True, description="Endpoint ID"), + "settings": fields.Raw(required=True, description="Updated settings"), + "name": fields.String(required=True, description="Updated name"), + }, + ) + ) + @api.response( + 200, + "Endpoint updated successfully", + api.model("EndpointUpdateResponse", {"success": fields.Boolean(description="Operation success")}), + ) + @api.response(403, "Admin privileges required") @setup_required @login_required @account_initialization_required @@ -154,7 +233,19 @@ class EndpointUpdateApi(Resource): } +@console_ns.route("/workspaces/current/endpoints/enable") class EndpointEnableApi(Resource): + @api.doc("enable_endpoint") + @api.doc(description="Enable a plugin endpoint") + @api.expect( + api.model("EndpointEnableRequest", {"endpoint_id": fields.String(required=True, description="Endpoint ID")}) + ) + @api.response( + 200, + "Endpoint enabled successfully", + api.model("EndpointEnableResponse", {"success": fields.Boolean(description="Operation success")}), + ) + @api.response(403, "Admin privileges required") @setup_required @login_required @account_initialization_required @@ -177,7 +268,19 @@ class EndpointEnableApi(Resource): } +@console_ns.route("/workspaces/current/endpoints/disable") class EndpointDisableApi(Resource): + @api.doc("disable_endpoint") + @api.doc(description="Disable a plugin endpoint") + @api.expect( + api.model("EndpointDisableRequest", {"endpoint_id": fields.String(required=True, description="Endpoint ID")}) + ) + @api.response( + 200, + "Endpoint disabled successfully", + api.model("EndpointDisableResponse", {"success": fields.Boolean(description="Operation success")}), + ) + @api.response(403, "Admin privileges required") @setup_required @login_required @account_initialization_required @@ -198,12 +301,3 @@ class EndpointDisableApi(Resource): tenant_id=user.current_tenant_id, user_id=user.id, endpoint_id=endpoint_id ) } - - -api.add_resource(EndpointCreateApi, "/workspaces/current/endpoints/create") -api.add_resource(EndpointListApi, "/workspaces/current/endpoints/list") -api.add_resource(EndpointListForSinglePluginApi, "/workspaces/current/endpoints/list/plugin") -api.add_resource(EndpointDeleteApi, "/workspaces/current/endpoints/delete") -api.add_resource(EndpointUpdateApi, "/workspaces/current/endpoints/update") -api.add_resource(EndpointEnableApi, "/workspaces/current/endpoints/enable") -api.add_resource(EndpointDisableApi, "/workspaces/current/endpoints/disable") diff --git a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py index cf2a10f453..77f0c9a735 100644 --- a/api/controllers/console/workspace/members.py +++ b/api/controllers/console/workspace/members.py @@ -1,8 +1,8 @@ from urllib import parse -from flask import request +from flask import abort, request from flask_login import current_user -from flask_restx import Resource, abort, marshal_with, reqparse +from flask_restx import Resource, marshal_with, reqparse import services from configs import dify_config @@ -41,6 +41,10 @@ class MemberListApi(Resource): @account_initialization_required @marshal_with(account_with_role_list_fields) def get(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") + if not current_user.current_tenant: + raise ValueError("No current tenant") members = TenantService.get_tenant_members(current_user.current_tenant) return {"result": "success", "accounts": members}, 200 @@ -65,7 +69,11 @@ class MemberInviteEmailApi(Resource): if not TenantAccountRole.is_non_owner_role(invitee_role): return {"code": "invalid-role", "message": "Invalid role"}, 400 + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") inviter = current_user + if not inviter.current_tenant: + raise ValueError("No current tenant") invitation_results = [] console_web_url = dify_config.CONSOLE_WEB_URL @@ -76,6 +84,8 @@ class MemberInviteEmailApi(Resource): for invitee_email in invitee_emails: try: + if not inviter.current_tenant: + raise ValueError("No current tenant") token = RegisterService.invite_new_member( inviter.current_tenant, invitee_email, interface_language, role=invitee_role, inviter=inviter ) @@ -97,7 +107,7 @@ class MemberInviteEmailApi(Resource): return { "result": "success", "invitation_results": invitation_results, - "tenant_id": str(current_user.current_tenant.id), + "tenant_id": str(inviter.current_tenant.id) if inviter.current_tenant else "", }, 201 @@ -108,6 +118,10 @@ class MemberCancelInviteApi(Resource): @login_required @account_initialization_required def delete(self, member_id): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") + if not current_user.current_tenant: + raise ValueError("No current tenant") member = db.session.query(Account).where(Account.id == str(member_id)).first() if member is None: abort(404) @@ -123,7 +137,10 @@ class MemberCancelInviteApi(Resource): except Exception as e: raise ValueError(str(e)) - return {"result": "success", "tenant_id": str(current_user.current_tenant.id)}, 200 + return { + "result": "success", + "tenant_id": str(current_user.current_tenant.id) if current_user.current_tenant else "", + }, 200 class MemberUpdateRoleApi(Resource): @@ -141,6 +158,10 @@ class MemberUpdateRoleApi(Resource): if not TenantAccountRole.is_valid_role(new_role): return {"code": "invalid-role", "message": "Invalid role"}, 400 + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") + if not current_user.current_tenant: + raise ValueError("No current tenant") member = db.session.get(Account, str(member_id)) if not member: abort(404) @@ -164,6 +185,10 @@ class DatasetOperatorMemberListApi(Resource): @account_initialization_required @marshal_with(account_with_role_list_fields) def get(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") + if not current_user.current_tenant: + raise ValueError("No current tenant") members = TenantService.get_dataset_operator_members(current_user.current_tenant) return {"result": "success", "accounts": members}, 200 @@ -184,6 +209,10 @@ class SendOwnerTransferEmailApi(Resource): raise EmailSendIpLimitError() # check if the current user is the owner of the workspace + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") + if not current_user.current_tenant: + raise ValueError("No current tenant") if not TenantService.is_owner(current_user, current_user.current_tenant): raise NotOwnerError() @@ -198,7 +227,7 @@ class SendOwnerTransferEmailApi(Resource): account=current_user, email=email, language=language, - workspace_name=current_user.current_tenant.name, + workspace_name=current_user.current_tenant.name if current_user.current_tenant else "", ) return {"result": "success", "data": token} @@ -215,6 +244,10 @@ class OwnerTransferCheckApi(Resource): parser.add_argument("token", type=str, required=True, nullable=False, location="json") args = parser.parse_args() # check if the current user is the owner of the workspace + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") + if not current_user.current_tenant: + raise ValueError("No current tenant") if not TenantService.is_owner(current_user, current_user.current_tenant): raise NotOwnerError() @@ -256,6 +289,10 @@ class OwnerTransfer(Resource): args = parser.parse_args() # check if the current user is the owner of the workspace + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") + if not current_user.current_tenant: + raise ValueError("No current tenant") if not TenantService.is_owner(current_user, current_user.current_tenant): raise NotOwnerError() @@ -274,9 +311,11 @@ class OwnerTransfer(Resource): member = db.session.get(Account, str(member_id)) if not member: abort(404) - else: - member_account = member - if not TenantService.is_member(member_account, current_user.current_tenant): + return # Never reached, but helps type checker + + if not current_user.current_tenant: + raise ValueError("No current tenant") + if not TenantService.is_member(member, current_user.current_tenant): raise MemberNotInTenantError() try: @@ -286,13 +325,13 @@ class OwnerTransfer(Resource): AccountService.send_new_owner_transfer_notify_email( account=member, email=member.email, - workspace_name=current_user.current_tenant.name, + workspace_name=current_user.current_tenant.name if current_user.current_tenant else "", ) AccountService.send_old_owner_transfer_notify_email( account=current_user, email=current_user.email, - workspace_name=current_user.current_tenant.name, + workspace_name=current_user.current_tenant.name if current_user.current_tenant else "", new_owner_email=member.email, ) diff --git a/api/controllers/console/workspace/model_providers.py b/api/controllers/console/workspace/model_providers.py index bfcc9a7f0a..0c9db660aa 100644 --- a/api/controllers/console/workspace/model_providers.py +++ b/api/controllers/console/workspace/model_providers.py @@ -12,6 +12,7 @@ from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.utils.encoders import jsonable_encoder from libs.helper import StrLen, uuid_value from libs.login import login_required +from models.account import Account from services.billing_service import BillingService from services.model_provider_service import ModelProviderService @@ -21,6 +22,10 @@ class ModelProviderListApi(Resource): @login_required @account_initialization_required def get(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") + if not current_user.current_tenant_id: + raise ValueError("No current tenant") tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() @@ -45,6 +50,10 @@ class ModelProviderCredentialApi(Resource): @login_required @account_initialization_required def get(self, provider: str): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") + if not current_user.current_tenant_id: + raise ValueError("No current tenant") tenant_id = current_user.current_tenant_id # if credential_id is not provided, return current used credential parser = reqparse.RequestParser() @@ -62,6 +71,8 @@ class ModelProviderCredentialApi(Resource): @login_required @account_initialization_required def post(self, provider: str): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") if not current_user.is_admin_or_owner: raise Forbidden() @@ -72,6 +83,8 @@ class ModelProviderCredentialApi(Resource): model_provider_service = ModelProviderService() + if not current_user.current_tenant_id: + raise ValueError("No current tenant") try: model_provider_service.create_provider_credential( tenant_id=current_user.current_tenant_id, @@ -88,6 +101,8 @@ class ModelProviderCredentialApi(Resource): @login_required @account_initialization_required def put(self, provider: str): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") if not current_user.is_admin_or_owner: raise Forbidden() @@ -99,6 +114,8 @@ class ModelProviderCredentialApi(Resource): model_provider_service = ModelProviderService() + if not current_user.current_tenant_id: + raise ValueError("No current tenant") try: model_provider_service.update_provider_credential( tenant_id=current_user.current_tenant_id, @@ -116,12 +133,16 @@ class ModelProviderCredentialApi(Resource): @login_required @account_initialization_required def delete(self, provider: str): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") if not current_user.is_admin_or_owner: raise Forbidden() parser = reqparse.RequestParser() parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json") args = parser.parse_args() + if not current_user.current_tenant_id: + raise ValueError("No current tenant") model_provider_service = ModelProviderService() model_provider_service.remove_provider_credential( tenant_id=current_user.current_tenant_id, provider=provider, credential_id=args["credential_id"] @@ -135,12 +156,16 @@ class ModelProviderCredentialSwitchApi(Resource): @login_required @account_initialization_required def post(self, provider: str): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") if not current_user.is_admin_or_owner: raise Forbidden() parser = reqparse.RequestParser() parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json") args = parser.parse_args() + if not current_user.current_tenant_id: + raise ValueError("No current tenant") service = ModelProviderService() service.switch_active_provider_credential( tenant_id=current_user.current_tenant_id, @@ -155,10 +180,14 @@ class ModelProviderValidateApi(Resource): @login_required @account_initialization_required def post(self, provider: str): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") parser = reqparse.RequestParser() parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") args = parser.parse_args() + if not current_user.current_tenant_id: + raise ValueError("No current tenant") tenant_id = current_user.current_tenant_id model_provider_service = ModelProviderService() @@ -205,9 +234,13 @@ class PreferredProviderTypeUpdateApi(Resource): @login_required @account_initialization_required def post(self, provider: str): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") if not current_user.is_admin_or_owner: raise Forbidden() + if not current_user.current_tenant_id: + raise ValueError("No current tenant") tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() @@ -236,7 +269,11 @@ class ModelProviderPaymentCheckoutUrlApi(Resource): def get(self, provider: str): if provider != "anthropic": raise ValueError(f"provider name {provider} is invalid") + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") BillingService.is_tenant_owner_or_admin(current_user) + if not current_user.current_tenant_id: + raise ValueError("No current tenant") data = BillingService.get_model_provider_payment_link( provider_name=provider, tenant_id=current_user.current_tenant_id, diff --git a/api/controllers/console/workspace/workspace.py b/api/controllers/console/workspace/workspace.py index e7a3aca66c..655afbe73f 100644 --- a/api/controllers/console/workspace/workspace.py +++ b/api/controllers/console/workspace/workspace.py @@ -25,7 +25,7 @@ from controllers.console.wraps import ( from extensions.ext_database import db from libs.helper import TimestampField from libs.login import login_required -from models.account import Tenant, TenantStatus +from models.account import Account, Tenant, TenantStatus from services.account_service import TenantService from services.feature_service import FeatureService from services.file_service import FileService @@ -70,6 +70,8 @@ class TenantListApi(Resource): @login_required @account_initialization_required def get(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") tenants = TenantService.get_join_tenants(current_user) tenant_dicts = [] @@ -83,7 +85,7 @@ class TenantListApi(Resource): "status": tenant.status, "created_at": tenant.created_at, "plan": features.billing.subscription.plan if features.billing.enabled else "sandbox", - "current": tenant.id == current_user.current_tenant_id, + "current": tenant.id == current_user.current_tenant_id if current_user.current_tenant_id else False, } tenant_dicts.append(tenant_dict) @@ -125,7 +127,11 @@ class TenantApi(Resource): if request.path == "/info": logger.warning("Deprecated URL /info was used.") + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") tenant = current_user.current_tenant + if not tenant: + raise ValueError("No current tenant") if tenant.status == TenantStatus.ARCHIVE: tenants = TenantService.get_join_tenants(current_user) @@ -137,6 +143,8 @@ class TenantApi(Resource): else: raise Unauthorized("workspace is archived") + if not tenant: + raise ValueError("No tenant available") return WorkspaceService.get_tenant_info(tenant), 200 @@ -145,6 +153,8 @@ class SwitchWorkspaceApi(Resource): @login_required @account_initialization_required def post(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") parser = reqparse.RequestParser() parser.add_argument("tenant_id", type=str, required=True, location="json") args = parser.parse_args() @@ -168,11 +178,15 @@ class CustomConfigWorkspaceApi(Resource): @account_initialization_required @cloud_edition_billing_resource_check("workspace_custom") def post(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") parser = reqparse.RequestParser() parser.add_argument("remove_webapp_brand", type=bool, location="json") parser.add_argument("replace_webapp_logo", type=str, location="json") args = parser.parse_args() + if not current_user.current_tenant_id: + raise ValueError("No current tenant") tenant = db.get_or_404(Tenant, current_user.current_tenant_id) custom_config_dict = { @@ -194,6 +208,8 @@ class WebappLogoWorkspaceApi(Resource): @account_initialization_required @cloud_edition_billing_resource_check("workspace_custom") def post(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") # check file if "file" not in request.files: raise NoFileUploadedError() @@ -232,10 +248,14 @@ class WorkspaceInfoApi(Resource): @account_initialization_required # Change workspace name def post(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") parser = reqparse.RequestParser() parser.add_argument("name", type=str, required=True, location="json") args = parser.parse_args() + if not current_user.current_tenant_id: + raise ValueError("No current tenant") tenant = db.get_or_404(Tenant, current_user.current_tenant_id) tenant.name = args["name"] db.session.commit() diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py index e375fe285b..092071481e 100644 --- a/api/controllers/console/wraps.py +++ b/api/controllers/console/wraps.py @@ -242,6 +242,19 @@ def email_password_login_enabled(view: Callable[P, R]): return decorated +def email_register_enabled(view): + @wraps(view) + def decorated(*args, **kwargs): + features = FeatureService.get_system_features() + if features.is_allow_register: + return view(*args, **kwargs) + + # otherwise, return 403 + abort(403) + + return decorated + + def enable_change_email(view: Callable[P, R]): @wraps(view) def decorated(*args: P.args, **kwargs: P.kwargs): diff --git a/api/controllers/files/__init__.py b/api/controllers/files/__init__.py index 821ad220a2..f8976b86b9 100644 --- a/api/controllers/files/__init__.py +++ b/api/controllers/files/__init__.py @@ -10,7 +10,6 @@ api = ExternalApi( version="1.0", title="Files API", description="API for file operations including upload and preview", - doc="/docs", # Enable Swagger UI at /files/docs ) files_ns = Namespace("files", description="File operations", path="/") @@ -18,3 +17,12 @@ files_ns = Namespace("files", description="File operations", path="/") from . import image_preview, tool_files, upload api.add_namespace(files_ns) + +__all__ = [ + "api", + "bp", + "files_ns", + "image_preview", + "tool_files", + "upload", +] diff --git a/api/controllers/files/upload.py b/api/controllers/files/upload.py index 7a2b3b0428..21cbf56074 100644 --- a/api/controllers/files/upload.py +++ b/api/controllers/files/upload.py @@ -86,7 +86,7 @@ class PluginUploadFileApi(Resource): filename=filename, mimetype=mimetype, tenant_id=tenant_id, - user_id=user_id, + user_id=user.id, timestamp=timestamp, nonce=nonce, sign=sign, diff --git a/api/controllers/inner_api/__init__.py b/api/controllers/inner_api/__init__.py index d29a7be139..74005217ef 100644 --- a/api/controllers/inner_api/__init__.py +++ b/api/controllers/inner_api/__init__.py @@ -10,14 +10,22 @@ api = ExternalApi( version="1.0", title="Inner API", description="Internal APIs for enterprise features, billing, and plugin communication", - doc="/docs", # Enable Swagger UI at /inner/api/docs ) # Create namespace inner_api_ns = Namespace("inner_api", description="Internal API operations", path="/") -from . import mail -from .plugin import plugin -from .workspace import workspace +from . import mail as _mail +from .plugin import plugin as _plugin +from .workspace import workspace as _workspace api.add_namespace(inner_api_ns) + +__all__ = [ + "_mail", + "_plugin", + "_workspace", + "api", + "bp", + "inner_api_ns", +] diff --git a/api/controllers/inner_api/plugin/plugin.py b/api/controllers/inner_api/plugin/plugin.py index 170a794d89..c5bb2f2545 100644 --- a/api/controllers/inner_api/plugin/plugin.py +++ b/api/controllers/inner_api/plugin/plugin.py @@ -37,9 +37,9 @@ from models.model import EndUser @inner_api_ns.route("/invoke/llm") class PluginInvokeLLMApi(Resource): + @get_user_tenant @setup_required @plugin_inner_api_only - @get_user_tenant @plugin_data(payload_type=RequestInvokeLLM) @inner_api_ns.doc("plugin_invoke_llm") @inner_api_ns.doc(description="Invoke LLM models through plugin interface") @@ -60,9 +60,9 @@ class PluginInvokeLLMApi(Resource): @inner_api_ns.route("/invoke/llm/structured-output") class PluginInvokeLLMWithStructuredOutputApi(Resource): + @get_user_tenant @setup_required @plugin_inner_api_only - @get_user_tenant @plugin_data(payload_type=RequestInvokeLLMWithStructuredOutput) @inner_api_ns.doc("plugin_invoke_llm_structured") @inner_api_ns.doc(description="Invoke LLM models with structured output through plugin interface") @@ -85,9 +85,9 @@ class PluginInvokeLLMWithStructuredOutputApi(Resource): @inner_api_ns.route("/invoke/text-embedding") class PluginInvokeTextEmbeddingApi(Resource): + @get_user_tenant @setup_required @plugin_inner_api_only - @get_user_tenant @plugin_data(payload_type=RequestInvokeTextEmbedding) @inner_api_ns.doc("plugin_invoke_text_embedding") @inner_api_ns.doc(description="Invoke text embedding models through plugin interface") @@ -115,9 +115,9 @@ class PluginInvokeTextEmbeddingApi(Resource): @inner_api_ns.route("/invoke/rerank") class PluginInvokeRerankApi(Resource): + @get_user_tenant @setup_required @plugin_inner_api_only - @get_user_tenant @plugin_data(payload_type=RequestInvokeRerank) @inner_api_ns.doc("plugin_invoke_rerank") @inner_api_ns.doc(description="Invoke rerank models through plugin interface") @@ -141,9 +141,9 @@ class PluginInvokeRerankApi(Resource): @inner_api_ns.route("/invoke/tts") class PluginInvokeTTSApi(Resource): + @get_user_tenant @setup_required @plugin_inner_api_only - @get_user_tenant @plugin_data(payload_type=RequestInvokeTTS) @inner_api_ns.doc("plugin_invoke_tts") @inner_api_ns.doc(description="Invoke text-to-speech models through plugin interface") @@ -168,9 +168,9 @@ class PluginInvokeTTSApi(Resource): @inner_api_ns.route("/invoke/speech2text") class PluginInvokeSpeech2TextApi(Resource): + @get_user_tenant @setup_required @plugin_inner_api_only - @get_user_tenant @plugin_data(payload_type=RequestInvokeSpeech2Text) @inner_api_ns.doc("plugin_invoke_speech2text") @inner_api_ns.doc(description="Invoke speech-to-text models through plugin interface") @@ -194,9 +194,9 @@ class PluginInvokeSpeech2TextApi(Resource): @inner_api_ns.route("/invoke/moderation") class PluginInvokeModerationApi(Resource): + @get_user_tenant @setup_required @plugin_inner_api_only - @get_user_tenant @plugin_data(payload_type=RequestInvokeModeration) @inner_api_ns.doc("plugin_invoke_moderation") @inner_api_ns.doc(description="Invoke moderation models through plugin interface") @@ -220,9 +220,9 @@ class PluginInvokeModerationApi(Resource): @inner_api_ns.route("/invoke/tool") class PluginInvokeToolApi(Resource): + @get_user_tenant @setup_required @plugin_inner_api_only - @get_user_tenant @plugin_data(payload_type=RequestInvokeTool) @inner_api_ns.doc("plugin_invoke_tool") @inner_api_ns.doc(description="Invoke tools through plugin interface") @@ -252,9 +252,9 @@ class PluginInvokeToolApi(Resource): @inner_api_ns.route("/invoke/parameter-extractor") class PluginInvokeParameterExtractorNodeApi(Resource): + @get_user_tenant @setup_required @plugin_inner_api_only - @get_user_tenant @plugin_data(payload_type=RequestInvokeParameterExtractorNode) @inner_api_ns.doc("plugin_invoke_parameter_extractor") @inner_api_ns.doc(description="Invoke parameter extractor node through plugin interface") @@ -285,9 +285,9 @@ class PluginInvokeParameterExtractorNodeApi(Resource): @inner_api_ns.route("/invoke/question-classifier") class PluginInvokeQuestionClassifierNodeApi(Resource): + @get_user_tenant @setup_required @plugin_inner_api_only - @get_user_tenant @plugin_data(payload_type=RequestInvokeQuestionClassifierNode) @inner_api_ns.doc("plugin_invoke_question_classifier") @inner_api_ns.doc(description="Invoke question classifier node through plugin interface") @@ -318,9 +318,9 @@ class PluginInvokeQuestionClassifierNodeApi(Resource): @inner_api_ns.route("/invoke/app") class PluginInvokeAppApi(Resource): + @get_user_tenant @setup_required @plugin_inner_api_only - @get_user_tenant @plugin_data(payload_type=RequestInvokeApp) @inner_api_ns.doc("plugin_invoke_app") @inner_api_ns.doc(description="Invoke application through plugin interface") @@ -348,9 +348,9 @@ class PluginInvokeAppApi(Resource): @inner_api_ns.route("/invoke/encrypt") class PluginInvokeEncryptApi(Resource): + @get_user_tenant @setup_required @plugin_inner_api_only - @get_user_tenant @plugin_data(payload_type=RequestInvokeEncrypt) @inner_api_ns.doc("plugin_invoke_encrypt") @inner_api_ns.doc(description="Encrypt or decrypt data through plugin interface") @@ -375,9 +375,9 @@ class PluginInvokeEncryptApi(Resource): @inner_api_ns.route("/invoke/summary") class PluginInvokeSummaryApi(Resource): + @get_user_tenant @setup_required @plugin_inner_api_only - @get_user_tenant @plugin_data(payload_type=RequestInvokeSummary) @inner_api_ns.doc("plugin_invoke_summary") @inner_api_ns.doc(description="Invoke summary functionality through plugin interface") @@ -405,9 +405,9 @@ class PluginInvokeSummaryApi(Resource): @inner_api_ns.route("/upload/file/request") class PluginUploadFileRequestApi(Resource): + @get_user_tenant @setup_required @plugin_inner_api_only - @get_user_tenant @plugin_data(payload_type=RequestRequestUploadFile) @inner_api_ns.doc("plugin_upload_file_request") @inner_api_ns.doc(description="Request signed URL for file upload through plugin interface") @@ -426,9 +426,9 @@ class PluginUploadFileRequestApi(Resource): @inner_api_ns.route("/fetch/app/info") class PluginFetchAppInfoApi(Resource): + @get_user_tenant @setup_required @plugin_inner_api_only - @get_user_tenant @plugin_data(payload_type=RequestFetchAppInfo) @inner_api_ns.doc("plugin_fetch_app_info") @inner_api_ns.doc(description="Fetch application information through plugin interface") diff --git a/api/controllers/inner_api/plugin/wraps.py b/api/controllers/inner_api/plugin/wraps.py index f751e06ddf..369e421599 100644 --- a/api/controllers/inner_api/plugin/wraps.py +++ b/api/controllers/inner_api/plugin/wraps.py @@ -1,6 +1,6 @@ from collections.abc import Callable from functools import wraps -from typing import Optional +from typing import Optional, ParamSpec, TypeVar, cast from flask import current_app, request from flask_login import user_logged_in @@ -8,11 +8,13 @@ from flask_restx import reqparse from pydantic import BaseModel from sqlalchemy.orm import Session -from core.file.constants import DEFAULT_SERVICE_API_USER_ID from extensions.ext_database import db -from libs.login import _get_user +from libs.login import current_user from models.account import Tenant -from models.model import EndUser +from models.model import DefaultEndUserSessionID, EndUser + +P = ParamSpec("P") +R = TypeVar("R") def get_user(tenant_id: str, user_id: str | None) -> EndUser: @@ -25,7 +27,7 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser: try: with Session(db.engine) as session: if not user_id: - user_id = DEFAULT_SERVICE_API_USER_ID + user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID.value user_model = ( session.query(EndUser) @@ -39,7 +41,7 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser: user_model = EndUser( tenant_id=tenant_id, type="service_api", - is_anonymous=user_id == DEFAULT_SERVICE_API_USER_ID, + is_anonymous=user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID.value, session_id=user_id, ) session.add(user_model) @@ -52,28 +54,25 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser: return user_model -def get_user_tenant(view: Optional[Callable] = None): - def decorator(view_func): +def get_user_tenant(view: Optional[Callable[P, R]] = None): + def decorator(view_func: Callable[P, R]): @wraps(view_func) - def decorated_view(*args, **kwargs): + def decorated_view(*args: P.args, **kwargs: P.kwargs): # fetch json body parser = reqparse.RequestParser() parser.add_argument("tenant_id", type=str, required=True, location="json") parser.add_argument("user_id", type=str, required=True, location="json") - kwargs = parser.parse_args() + p = parser.parse_args() - user_id = kwargs.get("user_id") - tenant_id = kwargs.get("tenant_id") + user_id = cast(str, p.get("user_id")) + tenant_id = cast(str, p.get("tenant_id")) if not tenant_id: raise ValueError("tenant_id is required") if not user_id: - user_id = DEFAULT_SERVICE_API_USER_ID - - del kwargs["tenant_id"] - del kwargs["user_id"] + user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID.value try: tenant_model = ( @@ -95,7 +94,7 @@ def get_user_tenant(view: Optional[Callable] = None): kwargs["user_model"] = user current_app.login_manager._update_request_context_with_user(user) # type: ignore - user_logged_in.send(current_app._get_current_object(), user=_get_user()) # type: ignore + user_logged_in.send(current_app._get_current_object(), user=current_user) # type: ignore return view_func(*args, **kwargs) @@ -107,9 +106,9 @@ def get_user_tenant(view: Optional[Callable] = None): return decorator(view) -def plugin_data(view: Optional[Callable] = None, *, payload_type: type[BaseModel]): - def decorator(view_func): - def decorated_view(*args, **kwargs): +def plugin_data(view: Optional[Callable[P, R]] = None, *, payload_type: type[BaseModel]): + def decorator(view_func: Callable[P, R]): + def decorated_view(*args: P.args, **kwargs: P.kwargs): try: data = request.get_json() except Exception: diff --git a/api/controllers/inner_api/wraps.py b/api/controllers/inner_api/wraps.py index de4f1da801..4bdcc6832a 100644 --- a/api/controllers/inner_api/wraps.py +++ b/api/controllers/inner_api/wraps.py @@ -46,9 +46,9 @@ def enterprise_inner_api_only(view: Callable[P, R]): return decorated -def enterprise_inner_api_user_auth(view): +def enterprise_inner_api_user_auth(view: Callable[P, R]): @wraps(view) - def decorated(*args, **kwargs): + def decorated(*args: P.args, **kwargs: P.kwargs): if not dify_config.INNER_API: return view(*args, **kwargs) diff --git a/api/controllers/mcp/__init__.py b/api/controllers/mcp/__init__.py index c344ffad08..d6fb2981e4 100644 --- a/api/controllers/mcp/__init__.py +++ b/api/controllers/mcp/__init__.py @@ -10,7 +10,6 @@ api = ExternalApi( version="1.0", title="MCP API", description="API for Model Context Protocol operations", - doc="/docs", # Enable Swagger UI at /mcp/docs ) mcp_ns = Namespace("mcp", description="MCP operations", path="/") @@ -18,3 +17,10 @@ mcp_ns = Namespace("mcp", description="MCP operations", path="/") from . import mcp api.add_namespace(mcp_ns) + +__all__ = [ + "api", + "bp", + "mcp", + "mcp_ns", +] diff --git a/api/controllers/mcp/mcp.py b/api/controllers/mcp/mcp.py index 43b59d5334..80e6037cc7 100644 --- a/api/controllers/mcp/mcp.py +++ b/api/controllers/mcp/mcp.py @@ -150,7 +150,7 @@ class MCPAppApi(Resource): def _get_user_input_form(self, app: App) -> list[VariableEntity]: """Get and convert user input form""" # Get raw user input form based on app mode - if app.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: + if app.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: if not app.workflow: raise MCPRequestError(mcp_types.INVALID_REQUEST, "App is unavailable") raw_user_input_form = app.workflow.user_input_form(to_old_structure=True) diff --git a/api/controllers/service_api/__init__.py b/api/controllers/service_api/__init__.py index 763345d723..9032733e2c 100644 --- a/api/controllers/service_api/__init__.py +++ b/api/controllers/service_api/__init__.py @@ -10,14 +10,50 @@ api = ExternalApi( version="1.0", title="Service API", description="API for application services", - doc="/docs", # Enable Swagger UI at /v1/docs ) service_api_ns = Namespace("service_api", description="Service operations", path="/") from . import index -from .app import annotation, app, audio, completion, conversation, file, file_preview, message, site, workflow -from .dataset import dataset, document, hit_testing, metadata, segment, upload_file +from .app import ( + annotation, + app, + audio, + completion, + conversation, + file, + file_preview, + message, + site, + workflow, +) +from .dataset import ( + dataset, + document, + hit_testing, + metadata, + segment, +) from .workspace import models +__all__ = [ + "annotation", + "app", + "audio", + "completion", + "conversation", + "dataset", + "document", + "file", + "file_preview", + "hit_testing", + "index", + "message", + "metadata", + "models", + "segment", + "site", + "workflow", +] + api.add_namespace(service_api_ns) diff --git a/api/controllers/service_api/app/annotation.py b/api/controllers/service_api/app/annotation.py index 9038bda11a..ad1bdc7334 100644 --- a/api/controllers/service_api/app/annotation.py +++ b/api/controllers/service_api/app/annotation.py @@ -165,7 +165,7 @@ class AnnotationUpdateDeleteApi(Resource): def put(self, app_model: App, annotation_id): """Update an existing annotation.""" assert isinstance(current_user, Account) - if not current_user.is_editor: + if not current_user.has_edit_permission: raise Forbidden() annotation_id = str(annotation_id) @@ -189,7 +189,7 @@ class AnnotationUpdateDeleteApi(Resource): """Delete an annotation.""" assert isinstance(current_user, Account) - if not current_user.is_editor: + if not current_user.has_edit_permission: raise Forbidden() annotation_id = str(annotation_id) diff --git a/api/controllers/service_api/app/app.py b/api/controllers/service_api/app/app.py index 2dbeed1d68..25d7ccccec 100644 --- a/api/controllers/service_api/app/app.py +++ b/api/controllers/service_api/app/app.py @@ -29,7 +29,7 @@ class AppParameterApi(Resource): Returns the input form parameters and configuration for the application. """ - if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: + if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: workflow = app_model.workflow if workflow is None: raise AppUnavailableError() diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py index 4860bf3a79..711dd5704c 100644 --- a/api/controllers/service_api/app/conversation.py +++ b/api/controllers/service_api/app/conversation.py @@ -1,4 +1,5 @@ from flask_restx import Resource, reqparse +from flask_restx._http import HTTPStatus from flask_restx.inputs import int_range from sqlalchemy.orm import Session from werkzeug.exceptions import BadRequest, NotFound @@ -121,7 +122,7 @@ class ConversationDetailApi(Resource): } ) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON)) - @service_api_ns.marshal_with(build_conversation_delete_model(service_api_ns), code=204) + @service_api_ns.marshal_with(build_conversation_delete_model(service_api_ns), code=HTTPStatus.NO_CONTENT) def delete(self, app_model: App, end_user: EndUser, c_id): """Delete a specific conversation.""" app_mode = AppMode.value_of(app_model.mode) diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index 580b08b9f0..99fde12e34 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -340,6 +340,9 @@ class DatasetApi(DatasetApiResource): else: data["embedding_available"] = True + # force update search method to keyword_search if indexing_technique is economic + data["retrieval_model_dict"]["search_method"] = "keyword_search" + if data.get("permission") == "partial_members": part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str) data.update({"partial_member_list": part_users_list}) @@ -559,7 +562,7 @@ class DatasetTagsApi(DatasetApiResource): def post(self, _, dataset_id): """Add a knowledge type tag.""" assert isinstance(current_user, Account) - if not (current_user.is_editor or current_user.is_dataset_editor): + if not (current_user.has_edit_permission or current_user.is_dataset_editor): raise Forbidden() args = tag_create_parser.parse_args() @@ -583,7 +586,7 @@ class DatasetTagsApi(DatasetApiResource): @validate_dataset_token def patch(self, _, dataset_id): assert isinstance(current_user, Account) - if not (current_user.is_editor or current_user.is_dataset_editor): + if not (current_user.has_edit_permission or current_user.is_dataset_editor): raise Forbidden() args = tag_update_parser.parse_args() @@ -610,7 +613,7 @@ class DatasetTagsApi(DatasetApiResource): def delete(self, _, dataset_id): """Delete a knowledge type tag.""" assert isinstance(current_user, Account) - if not current_user.is_editor: + if not current_user.has_edit_permission: raise Forbidden() args = tag_delete_parser.parse_args() TagService.delete_tag(args.get("tag_id")) @@ -634,7 +637,7 @@ class DatasetTagBindingApi(DatasetApiResource): def post(self, _, dataset_id): # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator assert isinstance(current_user, Account) - if not (current_user.is_editor or current_user.is_dataset_editor): + if not (current_user.has_edit_permission or current_user.is_dataset_editor): raise Forbidden() args = tag_binding_parser.parse_args() @@ -660,7 +663,7 @@ class DatasetTagUnbindingApi(DatasetApiResource): def post(self, _, dataset_id): # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator assert isinstance(current_user, Account) - if not (current_user.is_editor or current_user.is_dataset_editor): + if not (current_user.has_edit_permission or current_user.is_dataset_editor): raise Forbidden() args = tag_unbinding_parser.parse_args() diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index de41384270..721cf530c3 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -30,6 +30,7 @@ from extensions.ext_database import db from fields.document_fields import document_fields, document_status_fields from libs.login import current_user from models.dataset import Dataset, Document, DocumentSegment +from models.model import EndUser from services.dataset_service import DatasetService, DocumentService from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig from services.file_service import FileService @@ -298,6 +299,9 @@ class DocumentAddByFileApi(DatasetApiResource): if not file.filename: raise FilenameNotExistsError + if not isinstance(current_user, EndUser): + raise ValueError("Invalid user account") + upload_file = FileService.upload_file( filename=file.filename, content=file.read(), @@ -387,6 +391,8 @@ class DocumentUpdateByFileApi(DatasetApiResource): raise FilenameNotExistsError try: + if not isinstance(current_user, EndUser): + raise ValueError("Invalid user account") upload_file = FileService.upload_file( filename=file.filename, content=file.read(), diff --git a/api/controllers/service_api/dataset/metadata.py b/api/controllers/service_api/dataset/metadata.py index 444a791c01..c2df97eaec 100644 --- a/api/controllers/service_api/dataset/metadata.py +++ b/api/controllers/service_api/dataset/metadata.py @@ -174,7 +174,7 @@ class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource): MetadataService.enable_built_in_field(dataset) elif action == "disable": MetadataService.disable_built_in_field(dataset) - return 200 + return {"result": "success"}, 200 @service_api_ns.route("/datasets//documents/metadata") @@ -204,4 +204,4 @@ class DocumentMetadataEditServiceApi(DatasetApiResource): MetadataService.update_documents_metadata(dataset, metadata_args) - return 200 + return {"result": "success"}, 200 diff --git a/api/controllers/service_api/dataset/upload_file.py b/api/controllers/service_api/dataset/upload_file.py deleted file mode 100644 index 27b36a6402..0000000000 --- a/api/controllers/service_api/dataset/upload_file.py +++ /dev/null @@ -1,65 +0,0 @@ -from werkzeug.exceptions import NotFound - -from controllers.service_api import service_api_ns -from controllers.service_api.wraps import ( - DatasetApiResource, -) -from core.file import helpers as file_helpers -from extensions.ext_database import db -from models.dataset import Dataset -from models.model import UploadFile -from services.dataset_service import DocumentService - - -@service_api_ns.route("/datasets//documents//upload-file") -class UploadFileApi(DatasetApiResource): - @service_api_ns.doc("get_upload_file") - @service_api_ns.doc(description="Get upload file information and download URL") - @service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"}) - @service_api_ns.doc( - responses={ - 200: "Upload file information retrieved successfully", - 401: "Unauthorized - invalid API token", - 404: "Dataset, document, or upload file not found", - } - ) - def get(self, tenant_id, dataset_id, document_id): - """Get upload file information and download URL. - - Returns information about an uploaded file including its download URL. - """ - # check dataset - dataset_id = str(dataset_id) - tenant_id = str(tenant_id) - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() - if not dataset: - raise NotFound("Dataset not found.") - # check document - document_id = str(document_id) - document = DocumentService.get_document(dataset.id, document_id) - if not document: - raise NotFound("Document not found.") - # check upload file - if document.data_source_type != "upload_file": - raise ValueError(f"Document data source type ({document.data_source_type}) is not upload_file.") - data_source_info = document.data_source_info_dict - if data_source_info and "upload_file_id" in data_source_info: - file_id = data_source_info["upload_file_id"] - upload_file = db.session.query(UploadFile).where(UploadFile.id == file_id).first() - if not upload_file: - raise NotFound("UploadFile not found.") - else: - raise ValueError("Upload file id not found in document data source info.") - - url = file_helpers.get_signed_file_url(upload_file_id=upload_file.id) - return { - "id": upload_file.id, - "name": upload_file.name, - "size": upload_file.size, - "extension": upload_file.extension, - "url": url, - "download_url": f"{url}&as_attachment=true", - "mime_type": upload_file.mime_type, - "created_by": upload_file.created_by, - "created_at": upload_file.created_at.timestamp(), - }, 200 diff --git a/api/controllers/service_api/workspace/models.py b/api/controllers/service_api/workspace/models.py index 536cf81a2f..fffcb47bd4 100644 --- a/api/controllers/service_api/workspace/models.py +++ b/api/controllers/service_api/workspace/models.py @@ -19,7 +19,7 @@ class ModelProviderAvailableModelApi(Resource): } ) @validate_dataset_token - def get(self, _, model_type): + def get(self, _, model_type: str): """Get available models by model type. Returns a list of available models for the specified model type. diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py index 14291578d5..c38465c98e 100644 --- a/api/controllers/service_api/wraps.py +++ b/api/controllers/service_api/wraps.py @@ -3,7 +3,7 @@ from collections.abc import Callable from datetime import timedelta from enum import StrEnum, auto from functools import wraps -from typing import Optional, ParamSpec, TypeVar +from typing import Concatenate, Optional, ParamSpec, TypeVar from flask import current_app, request from flask_login import user_logged_in @@ -13,18 +13,18 @@ from sqlalchemy import select, update from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden, NotFound, Unauthorized -from core.file.constants import DEFAULT_SERVICE_API_USER_ID from extensions.ext_database import db from extensions.ext_redis import redis_client from libs.datetime_utils import naive_utc_now -from libs.login import _get_user +from libs.login import current_user from models.account import Account, Tenant, TenantAccountJoin, TenantStatus from models.dataset import Dataset, RateLimitLog -from models.model import ApiToken, App, EndUser +from models.model import ApiToken, App, DefaultEndUserSessionID, EndUser from services.feature_service import FeatureService P = ParamSpec("P") R = TypeVar("R") +T = TypeVar("T") class WhereisUserArg(StrEnum): @@ -42,10 +42,10 @@ class FetchUserArg(BaseModel): required: bool = False -def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optional[FetchUserArg] = None): - def decorator(view_func): +def validate_app_token(view: Optional[Callable[P, R]] = None, *, fetch_user_arg: Optional[FetchUserArg] = None): + def decorator(view_func: Callable[P, R]): @wraps(view_func) - def decorated_view(*args, **kwargs): + def decorated_view(*args: P.args, **kwargs: P.kwargs): api_token = validate_and_get_api_token("app") app_model = db.session.query(App).where(App.id == api_token.app_id).first() @@ -189,10 +189,10 @@ def cloud_edition_billing_rate_limit_check(resource: str, api_token_type: str): return interceptor -def validate_dataset_token(view=None): - def decorator(view): +def validate_dataset_token(view: Optional[Callable[Concatenate[T, P], R]] = None): + def decorator(view: Callable[Concatenate[T, P], R]): @wraps(view) - def decorated(*args, **kwargs): + def decorated(*args: P.args, **kwargs: P.kwargs): api_token = validate_and_get_api_token("dataset") tenant_account_join = ( db.session.query(Tenant, TenantAccountJoin) @@ -209,7 +209,7 @@ def validate_dataset_token(view=None): if account: account.current_tenant = tenant current_app.login_manager._update_request_context_with_user(account) # type: ignore - user_logged_in.send(current_app._get_current_object(), user=_get_user()) # type: ignore + user_logged_in.send(current_app._get_current_object(), user=current_user) # type: ignore else: raise Unauthorized("Tenant owner account does not exist.") else: @@ -272,7 +272,7 @@ def create_or_update_end_user_for_user_id(app_model: App, user_id: Optional[str] Create or update session terminal based on user ID. """ if not user_id: - user_id = DEFAULT_SERVICE_API_USER_ID + user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID.value with Session(db.engine, expire_on_commit=False) as session: end_user = ( @@ -291,7 +291,7 @@ def create_or_update_end_user_for_user_id(app_model: App, user_id: Optional[str] tenant_id=app_model.tenant_id, app_id=app_model.id, type="service_api", - is_anonymous=user_id == DEFAULT_SERVICE_API_USER_ID, + is_anonymous=user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID.value, session_id=user_id, ) session.add(end_user) diff --git a/api/controllers/web/__init__.py b/api/controllers/web/__init__.py index 3b0a9e341a..1d22954308 100644 --- a/api/controllers/web/__init__.py +++ b/api/controllers/web/__init__.py @@ -10,7 +10,6 @@ api = ExternalApi( version="1.0", title="Web API", description="Public APIs for web applications including file uploads, chat interactions, and app management", - doc="/docs", # Enable Swagger UI at /api/docs ) # Create namespace @@ -34,3 +33,23 @@ from . import ( ) api.add_namespace(web_ns) + +__all__ = [ + "api", + "app", + "audio", + "bp", + "completion", + "conversation", + "feature", + "files", + "forgot_password", + "login", + "message", + "passport", + "remote_files", + "saved_message", + "site", + "web_ns", + "workflow", +] diff --git a/api/controllers/web/app.py b/api/controllers/web/app.py index e0c3e997ce..2bc068ec75 100644 --- a/api/controllers/web/app.py +++ b/api/controllers/web/app.py @@ -38,7 +38,7 @@ class AppParameterApi(WebApiResource): @marshal_with(fields.parameters_fields) def get(self, app_model: App, end_user): """Retrieve app parameters.""" - if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: + if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: workflow = app_model.workflow if workflow is None: raise AppUnavailableError() diff --git a/api/controllers/web/audio.py b/api/controllers/web/audio.py index 2c0f6c9759..c1c46891b6 100644 --- a/api/controllers/web/audio.py +++ b/api/controllers/web/audio.py @@ -5,7 +5,7 @@ from flask_restx import fields, marshal_with, reqparse from werkzeug.exceptions import InternalServerError import services -from controllers.web import api +from controllers.web import web_ns from controllers.web.error import ( AppUnavailableError, AudioTooLargeError, @@ -32,15 +32,16 @@ from services.errors.audio import ( logger = logging.getLogger(__name__) +@web_ns.route("/audio-to-text") class AudioApi(WebApiResource): audio_to_text_response_fields = { "text": fields.String, } @marshal_with(audio_to_text_response_fields) - @api.doc("Audio to Text") - @api.doc(description="Convert audio file to text using speech-to-text service.") - @api.doc( + @web_ns.doc("Audio to Text") + @web_ns.doc(description="Convert audio file to text using speech-to-text service.") + @web_ns.doc( responses={ 200: "Success", 400: "Bad Request", @@ -85,6 +86,7 @@ class AudioApi(WebApiResource): raise InternalServerError() +@web_ns.route("/text-to-audio") class TextApi(WebApiResource): text_to_audio_response_fields = { "audio_url": fields.String, @@ -92,9 +94,9 @@ class TextApi(WebApiResource): } @marshal_with(text_to_audio_response_fields) - @api.doc("Text to Audio") - @api.doc(description="Convert text to audio using text-to-speech service.") - @api.doc( + @web_ns.doc("Text to Audio") + @web_ns.doc(description="Convert text to audio using text-to-speech service.") + @web_ns.doc( responses={ 200: "Success", 400: "Bad Request", @@ -145,7 +147,3 @@ class TextApi(WebApiResource): except Exception as e: logger.exception("Failed to handle post request to TextApi") raise InternalServerError() - - -api.add_resource(AudioApi, "/audio-to-text") -api.add_resource(TextApi, "/text-to-audio") diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py index a42bf5fc6e..67ae970388 100644 --- a/api/controllers/web/completion.py +++ b/api/controllers/web/completion.py @@ -4,7 +4,7 @@ from flask_restx import reqparse from werkzeug.exceptions import InternalServerError, NotFound import services -from controllers.web import api +from controllers.web import web_ns from controllers.web.error import ( AppUnavailableError, CompletionRequestError, @@ -35,10 +35,11 @@ logger = logging.getLogger(__name__) # define completion api for user +@web_ns.route("/completion-messages") class CompletionApi(WebApiResource): - @api.doc("Create Completion Message") - @api.doc(description="Create a completion message for text generation applications.") - @api.doc( + @web_ns.doc("Create Completion Message") + @web_ns.doc(description="Create a completion message for text generation applications.") + @web_ns.doc( params={ "inputs": {"description": "Input variables for the completion", "type": "object", "required": True}, "query": {"description": "Query text for completion", "type": "string", "required": False}, @@ -52,7 +53,7 @@ class CompletionApi(WebApiResource): "retriever_from": {"description": "Source of retriever", "type": "string", "required": False}, } ) - @api.doc( + @web_ns.doc( responses={ 200: "Success", 400: "Bad Request", @@ -106,11 +107,12 @@ class CompletionApi(WebApiResource): raise InternalServerError() +@web_ns.route("/completion-messages//stop") class CompletionStopApi(WebApiResource): - @api.doc("Stop Completion Message") - @api.doc(description="Stop a running completion message task.") - @api.doc(params={"task_id": {"description": "Task ID to stop", "type": "string", "required": True}}) - @api.doc( + @web_ns.doc("Stop Completion Message") + @web_ns.doc(description="Stop a running completion message task.") + @web_ns.doc(params={"task_id": {"description": "Task ID to stop", "type": "string", "required": True}}) + @web_ns.doc( responses={ 200: "Success", 400: "Bad Request", @@ -129,10 +131,11 @@ class CompletionStopApi(WebApiResource): return {"result": "success"}, 200 +@web_ns.route("/chat-messages") class ChatApi(WebApiResource): - @api.doc("Create Chat Message") - @api.doc(description="Create a chat message for conversational applications.") - @api.doc( + @web_ns.doc("Create Chat Message") + @web_ns.doc(description="Create a chat message for conversational applications.") + @web_ns.doc( params={ "inputs": {"description": "Input variables for the chat", "type": "object", "required": True}, "query": {"description": "User query/message", "type": "string", "required": True}, @@ -148,7 +151,7 @@ class ChatApi(WebApiResource): "retriever_from": {"description": "Source of retriever", "type": "string", "required": False}, } ) - @api.doc( + @web_ns.doc( responses={ 200: "Success", 400: "Bad Request", @@ -207,11 +210,12 @@ class ChatApi(WebApiResource): raise InternalServerError() +@web_ns.route("/chat-messages//stop") class ChatStopApi(WebApiResource): - @api.doc("Stop Chat Message") - @api.doc(description="Stop a running chat message task.") - @api.doc(params={"task_id": {"description": "Task ID to stop", "type": "string", "required": True}}) - @api.doc( + @web_ns.doc("Stop Chat Message") + @web_ns.doc(description="Stop a running chat message task.") + @web_ns.doc(params={"task_id": {"description": "Task ID to stop", "type": "string", "required": True}}) + @web_ns.doc( responses={ 200: "Success", 400: "Bad Request", @@ -229,9 +233,3 @@ class ChatStopApi(WebApiResource): AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id) return {"result": "success"}, 200 - - -api.add_resource(CompletionApi, "/completion-messages") -api.add_resource(CompletionStopApi, "/completion-messages//stop") -api.add_resource(ChatApi, "/chat-messages") -api.add_resource(ChatStopApi, "/chat-messages//stop") diff --git a/api/controllers/web/conversation.py b/api/controllers/web/conversation.py index 24de4f3f2e..03dd986aed 100644 --- a/api/controllers/web/conversation.py +++ b/api/controllers/web/conversation.py @@ -3,7 +3,7 @@ from flask_restx.inputs import int_range from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound -from controllers.web import api +from controllers.web import web_ns from controllers.web.error import NotChatAppError from controllers.web.wraps import WebApiResource from core.app.entities.app_invoke_entities import InvokeFrom @@ -16,7 +16,44 @@ from services.errors.conversation import ConversationNotExistsError, LastConvers from services.web_conversation_service import WebConversationService +@web_ns.route("/conversations") class ConversationListApi(WebApiResource): + @web_ns.doc("Get Conversation List") + @web_ns.doc(description="Retrieve paginated list of conversations for a chat application.") + @web_ns.doc( + params={ + "last_id": {"description": "Last conversation ID for pagination", "type": "string", "required": False}, + "limit": { + "description": "Number of conversations to return (1-100)", + "type": "integer", + "required": False, + "default": 20, + }, + "pinned": { + "description": "Filter by pinned status", + "type": "string", + "enum": ["true", "false"], + "required": False, + }, + "sort_by": { + "description": "Sort order", + "type": "string", + "enum": ["created_at", "-created_at", "updated_at", "-updated_at"], + "required": False, + "default": "-updated_at", + }, + } + ) + @web_ns.doc( + responses={ + 200: "Success", + 400: "Bad Request", + 401: "Unauthorized", + 403: "Forbidden", + 404: "App Not Found or Not a Chat App", + 500: "Internal Server Error", + } + ) @marshal_with(conversation_infinite_scroll_pagination_fields) def get(self, app_model, end_user): app_mode = AppMode.value_of(app_model.mode) @@ -57,11 +94,25 @@ class ConversationListApi(WebApiResource): raise NotFound("Last Conversation Not Exists.") +@web_ns.route("/conversations/") class ConversationApi(WebApiResource): delete_response_fields = { "result": fields.String, } + @web_ns.doc("Delete Conversation") + @web_ns.doc(description="Delete a specific conversation.") + @web_ns.doc(params={"c_id": {"description": "Conversation UUID", "type": "string", "required": True}}) + @web_ns.doc( + responses={ + 204: "Conversation deleted successfully", + 400: "Bad Request", + 401: "Unauthorized", + 403: "Forbidden", + 404: "Conversation Not Found or Not a Chat App", + 500: "Internal Server Error", + } + ) @marshal_with(delete_response_fields) def delete(self, app_model, end_user, c_id): app_mode = AppMode.value_of(app_model.mode) @@ -76,7 +127,32 @@ class ConversationApi(WebApiResource): return {"result": "success"}, 204 +@web_ns.route("/conversations//name") class ConversationRenameApi(WebApiResource): + @web_ns.doc("Rename Conversation") + @web_ns.doc(description="Rename a specific conversation with a custom name or auto-generate one.") + @web_ns.doc(params={"c_id": {"description": "Conversation UUID", "type": "string", "required": True}}) + @web_ns.doc( + params={ + "name": {"description": "New conversation name", "type": "string", "required": False}, + "auto_generate": { + "description": "Auto-generate conversation name", + "type": "boolean", + "required": False, + "default": False, + }, + } + ) + @web_ns.doc( + responses={ + 200: "Conversation renamed successfully", + 400: "Bad Request", + 401: "Unauthorized", + 403: "Forbidden", + 404: "Conversation Not Found or Not a Chat App", + 500: "Internal Server Error", + } + ) @marshal_with(simple_conversation_fields) def post(self, app_model, end_user, c_id): app_mode = AppMode.value_of(app_model.mode) @@ -96,11 +172,25 @@ class ConversationRenameApi(WebApiResource): raise NotFound("Conversation Not Exists.") +@web_ns.route("/conversations//pin") class ConversationPinApi(WebApiResource): pin_response_fields = { "result": fields.String, } + @web_ns.doc("Pin Conversation") + @web_ns.doc(description="Pin a specific conversation to keep it at the top of the list.") + @web_ns.doc(params={"c_id": {"description": "Conversation UUID", "type": "string", "required": True}}) + @web_ns.doc( + responses={ + 200: "Conversation pinned successfully", + 400: "Bad Request", + 401: "Unauthorized", + 403: "Forbidden", + 404: "Conversation Not Found or Not a Chat App", + 500: "Internal Server Error", + } + ) @marshal_with(pin_response_fields) def patch(self, app_model, end_user, c_id): app_mode = AppMode.value_of(app_model.mode) @@ -117,11 +207,25 @@ class ConversationPinApi(WebApiResource): return {"result": "success"} +@web_ns.route("/conversations//unpin") class ConversationUnPinApi(WebApiResource): unpin_response_fields = { "result": fields.String, } + @web_ns.doc("Unpin Conversation") + @web_ns.doc(description="Unpin a specific conversation to remove it from the top of the list.") + @web_ns.doc(params={"c_id": {"description": "Conversation UUID", "type": "string", "required": True}}) + @web_ns.doc( + responses={ + 200: "Conversation unpinned successfully", + 400: "Bad Request", + 401: "Unauthorized", + 403: "Forbidden", + 404: "Conversation Not Found or Not a Chat App", + 500: "Internal Server Error", + } + ) @marshal_with(unpin_response_fields) def patch(self, app_model, end_user, c_id): app_mode = AppMode.value_of(app_model.mode) @@ -132,10 +236,3 @@ class ConversationUnPinApi(WebApiResource): WebConversationService.unpin(app_model, conversation_id, end_user) return {"result": "success"} - - -api.add_resource(ConversationRenameApi, "/conversations//name", endpoint="web_conversation_name") -api.add_resource(ConversationListApi, "/conversations") -api.add_resource(ConversationApi, "/conversations/") -api.add_resource(ConversationPinApi, "/conversations//pin") -api.add_resource(ConversationUnPinApi, "/conversations//unpin") diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py index 17e06e8856..26c0b133d9 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -4,7 +4,7 @@ from flask_restx import fields, marshal_with, reqparse from flask_restx.inputs import int_range from werkzeug.exceptions import InternalServerError, NotFound -from controllers.web import api +from controllers.web import web_ns from controllers.web.error import ( AppMoreLikeThisDisabledError, AppSuggestedQuestionsAfterAnswerDisabledError, @@ -38,6 +38,7 @@ from services.message_service import MessageService logger = logging.getLogger(__name__) +@web_ns.route("/messages") class MessageListApi(WebApiResource): message_fields = { "id": fields.String, @@ -62,6 +63,30 @@ class MessageListApi(WebApiResource): "data": fields.List(fields.Nested(message_fields)), } + @web_ns.doc("Get Message List") + @web_ns.doc(description="Retrieve paginated list of messages from a conversation in a chat application.") + @web_ns.doc( + params={ + "conversation_id": {"description": "Conversation UUID", "type": "string", "required": True}, + "first_id": {"description": "First message ID for pagination", "type": "string", "required": False}, + "limit": { + "description": "Number of messages to return (1-100)", + "type": "integer", + "required": False, + "default": 20, + }, + } + ) + @web_ns.doc( + responses={ + 200: "Success", + 400: "Bad Request", + 401: "Unauthorized", + 403: "Forbidden", + 404: "Conversation Not Found or Not a Chat App", + 500: "Internal Server Error", + } + ) @marshal_with(message_infinite_scroll_pagination_fields) def get(self, app_model, end_user): app_mode = AppMode.value_of(app_model.mode) @@ -84,11 +109,36 @@ class MessageListApi(WebApiResource): raise NotFound("First Message Not Exists.") +@web_ns.route("/messages//feedbacks") class MessageFeedbackApi(WebApiResource): feedback_response_fields = { "result": fields.String, } + @web_ns.doc("Create Message Feedback") + @web_ns.doc(description="Submit feedback (like/dislike) for a specific message.") + @web_ns.doc(params={"message_id": {"description": "Message UUID", "type": "string", "required": True}}) + @web_ns.doc( + params={ + "rating": { + "description": "Feedback rating", + "type": "string", + "enum": ["like", "dislike"], + "required": False, + }, + "content": {"description": "Feedback content/comment", "type": "string", "required": False}, + } + ) + @web_ns.doc( + responses={ + 200: "Feedback submitted successfully", + 400: "Bad Request", + 401: "Unauthorized", + 403: "Forbidden", + 404: "Message Not Found", + 500: "Internal Server Error", + } + ) @marshal_with(feedback_response_fields) def post(self, app_model, end_user, message_id): message_id = str(message_id) @@ -112,7 +162,31 @@ class MessageFeedbackApi(WebApiResource): return {"result": "success"} +@web_ns.route("/messages//more-like-this") class MessageMoreLikeThisApi(WebApiResource): + @web_ns.doc("Generate More Like This") + @web_ns.doc(description="Generate a new completion similar to an existing message (completion apps only).") + @web_ns.doc( + params={ + "message_id": {"description": "Message UUID", "type": "string", "required": True}, + "response_mode": { + "description": "Response mode", + "type": "string", + "enum": ["blocking", "streaming"], + "required": True, + }, + } + ) + @web_ns.doc( + responses={ + 200: "Success", + 400: "Bad Request - Not a completion app or feature disabled", + 401: "Unauthorized", + 403: "Forbidden", + 404: "Message Not Found", + 500: "Internal Server Error", + } + ) def get(self, app_model, end_user, message_id): if app_model.mode != "completion": raise NotCompletionAppError() @@ -156,11 +230,25 @@ class MessageMoreLikeThisApi(WebApiResource): raise InternalServerError() +@web_ns.route("/messages//suggested-questions") class MessageSuggestedQuestionApi(WebApiResource): suggested_questions_response_fields = { "data": fields.List(fields.String), } + @web_ns.doc("Get Suggested Questions") + @web_ns.doc(description="Get suggested follow-up questions after a message (chat apps only).") + @web_ns.doc(params={"message_id": {"description": "Message UUID", "type": "string", "required": True}}) + @web_ns.doc( + responses={ + 200: "Success", + 400: "Bad Request - Not a chat app or feature disabled", + 401: "Unauthorized", + 403: "Forbidden", + 404: "Message Not Found or Conversation Not Found", + 500: "Internal Server Error", + } + ) @marshal_with(suggested_questions_response_fields) def get(self, app_model, end_user, message_id): app_mode = AppMode.value_of(app_model.mode) @@ -192,9 +280,3 @@ class MessageSuggestedQuestionApi(WebApiResource): raise InternalServerError() return {"data": questions} - - -api.add_resource(MessageListApi, "/messages") -api.add_resource(MessageFeedbackApi, "/messages//feedbacks") -api.add_resource(MessageMoreLikeThisApi, "/messages//more-like-this") -api.add_resource(MessageSuggestedQuestionApi, "/messages//suggested-questions") diff --git a/api/controllers/web/saved_message.py b/api/controllers/web/saved_message.py index 7a9d24114e..96f09c8d3c 100644 --- a/api/controllers/web/saved_message.py +++ b/api/controllers/web/saved_message.py @@ -2,7 +2,7 @@ from flask_restx import fields, marshal_with, reqparse from flask_restx.inputs import int_range from werkzeug.exceptions import NotFound -from controllers.web import api +from controllers.web import web_ns from controllers.web.error import NotCompletionAppError from controllers.web.wraps import WebApiResource from fields.conversation_fields import message_file_fields @@ -23,6 +23,7 @@ message_fields = { } +@web_ns.route("/saved-messages") class SavedMessageListApi(WebApiResource): saved_message_infinite_scroll_pagination_fields = { "limit": fields.Integer, @@ -34,6 +35,29 @@ class SavedMessageListApi(WebApiResource): "result": fields.String, } + @web_ns.doc("Get Saved Messages") + @web_ns.doc(description="Retrieve paginated list of saved messages for a completion application.") + @web_ns.doc( + params={ + "last_id": {"description": "Last message ID for pagination", "type": "string", "required": False}, + "limit": { + "description": "Number of messages to return (1-100)", + "type": "integer", + "required": False, + "default": 20, + }, + } + ) + @web_ns.doc( + responses={ + 200: "Success", + 400: "Bad Request - Not a completion app", + 401: "Unauthorized", + 403: "Forbidden", + 404: "App Not Found", + 500: "Internal Server Error", + } + ) @marshal_with(saved_message_infinite_scroll_pagination_fields) def get(self, app_model, end_user): if app_model.mode != "completion": @@ -46,6 +70,23 @@ class SavedMessageListApi(WebApiResource): return SavedMessageService.pagination_by_last_id(app_model, end_user, args["last_id"], args["limit"]) + @web_ns.doc("Save Message") + @web_ns.doc(description="Save a specific message for later reference.") + @web_ns.doc( + params={ + "message_id": {"description": "Message UUID to save", "type": "string", "required": True}, + } + ) + @web_ns.doc( + responses={ + 200: "Message saved successfully", + 400: "Bad Request - Not a completion app", + 401: "Unauthorized", + 403: "Forbidden", + 404: "Message Not Found", + 500: "Internal Server Error", + } + ) @marshal_with(post_response_fields) def post(self, app_model, end_user): if app_model.mode != "completion": @@ -63,11 +104,25 @@ class SavedMessageListApi(WebApiResource): return {"result": "success"} +@web_ns.route("/saved-messages/") class SavedMessageApi(WebApiResource): delete_response_fields = { "result": fields.String, } + @web_ns.doc("Delete Saved Message") + @web_ns.doc(description="Remove a message from saved messages.") + @web_ns.doc(params={"message_id": {"description": "Message UUID to delete", "type": "string", "required": True}}) + @web_ns.doc( + responses={ + 204: "Message removed successfully", + 400: "Bad Request - Not a completion app", + 401: "Unauthorized", + 403: "Forbidden", + 404: "Message Not Found", + 500: "Internal Server Error", + } + ) @marshal_with(delete_response_fields) def delete(self, app_model, end_user, message_id): message_id = str(message_id) @@ -78,7 +133,3 @@ class SavedMessageApi(WebApiResource): SavedMessageService.delete(app_model, end_user, message_id) return {"result": "success"}, 204 - - -api.add_resource(SavedMessageListApi, "/saved-messages") -api.add_resource(SavedMessageApi, "/saved-messages/") diff --git a/api/controllers/web/site.py b/api/controllers/web/site.py index 91d67bf9d8..b01aaba357 100644 --- a/api/controllers/web/site.py +++ b/api/controllers/web/site.py @@ -2,7 +2,7 @@ from flask_restx import fields, marshal_with from werkzeug.exceptions import Forbidden from configs import dify_config -from controllers.web import api +from controllers.web import web_ns from controllers.web.wraps import WebApiResource from extensions.ext_database import db from libs.helper import AppIconUrlField @@ -11,6 +11,7 @@ from models.model import Site from services.feature_service import FeatureService +@web_ns.route("/site") class AppSiteApi(WebApiResource): """Resource for app sites.""" @@ -53,9 +54,9 @@ class AppSiteApi(WebApiResource): "custom_config": fields.Raw(attribute="custom_config"), } - @api.doc("Get App Site Info") - @api.doc(description="Retrieve app site information and configuration.") - @api.doc( + @web_ns.doc("Get App Site Info") + @web_ns.doc(description="Retrieve app site information and configuration.") + @web_ns.doc( responses={ 200: "Success", 400: "Bad Request", @@ -82,9 +83,6 @@ class AppSiteApi(WebApiResource): return AppSiteInfo(app_model.tenant, app_model, site, end_user.id, can_replace_logo) -api.add_resource(AppSiteApi, "/site") - - class AppSiteInfo: """Class to store site information.""" diff --git a/api/controllers/web/workflow.py b/api/controllers/web/workflow.py index 3566cfae38..490dce8f05 100644 --- a/api/controllers/web/workflow.py +++ b/api/controllers/web/workflow.py @@ -3,7 +3,7 @@ import logging from flask_restx import reqparse from werkzeug.exceptions import InternalServerError -from controllers.web import api +from controllers.web import web_ns from controllers.web.error import ( CompletionRequestError, NotWorkflowAppError, @@ -29,16 +29,17 @@ from services.errors.llm import InvokeRateLimitError logger = logging.getLogger(__name__) +@web_ns.route("/workflows/run") class WorkflowRunApi(WebApiResource): - @api.doc("Run Workflow") - @api.doc(description="Execute a workflow with provided inputs and files.") - @api.doc( + @web_ns.doc("Run Workflow") + @web_ns.doc(description="Execute a workflow with provided inputs and files.") + @web_ns.doc( params={ "inputs": {"description": "Input variables for the workflow", "type": "object", "required": True}, "files": {"description": "Files to be processed by the workflow", "type": "array", "required": False}, } ) - @api.doc( + @web_ns.doc( responses={ 200: "Success", 400: "Bad Request", @@ -84,15 +85,16 @@ class WorkflowRunApi(WebApiResource): raise InternalServerError() +@web_ns.route("/workflows/tasks//stop") class WorkflowTaskStopApi(WebApiResource): - @api.doc("Stop Workflow Task") - @api.doc(description="Stop a running workflow task.") - @api.doc( + @web_ns.doc("Stop Workflow Task") + @web_ns.doc(description="Stop a running workflow task.") + @web_ns.doc( params={ "task_id": {"description": "Task ID to stop", "type": "string", "required": True}, } ) - @api.doc( + @web_ns.doc( responses={ 200: "Success", 400: "Bad Request", @@ -113,7 +115,3 @@ class WorkflowTaskStopApi(WebApiResource): AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id) return {"result": "success"} - - -api.add_resource(WorkflowRunApi, "/workflows/run") -api.add_resource(WorkflowTaskStopApi, "/workflows/tasks//stop") diff --git a/api/controllers/web/wraps.py b/api/controllers/web/wraps.py index 1fbb2c165f..e79456535a 100644 --- a/api/controllers/web/wraps.py +++ b/api/controllers/web/wraps.py @@ -1,6 +1,7 @@ +from collections.abc import Callable from datetime import UTC, datetime from functools import wraps -from typing import ParamSpec, TypeVar +from typing import Concatenate, Optional, ParamSpec, TypeVar from flask import request from flask_restx import Resource @@ -20,12 +21,11 @@ P = ParamSpec("P") R = TypeVar("R") -def validate_jwt_token(view=None): - def decorator(view): +def validate_jwt_token(view: Optional[Callable[Concatenate[App, EndUser, P], R]] = None): + def decorator(view: Callable[Concatenate[App, EndUser, P], R]): @wraps(view) - def decorated(*args, **kwargs): + def decorated(*args: P.args, **kwargs: P.kwargs): app_model, end_user = decode_jwt_token() - return view(app_model, end_user, *args, **kwargs) return decorated diff --git a/api/core/__init__.py b/api/core/__init__.py index 6eaea7b1c8..e69de29bb2 100644 --- a/api/core/__init__.py +++ b/api/core/__init__.py @@ -1 +0,0 @@ -import core.moderation.base diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py index b94a60c40a..d1d5a011e0 100644 --- a/api/core/agent/cot_agent_runner.py +++ b/api/core/agent/cot_agent_runner.py @@ -72,6 +72,8 @@ class CotAgentRunner(BaseAgentRunner, ABC): function_call_state = True llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None} final_answer = "" + prompt_messages: list = [] # Initialize prompt_messages + agent_thought_id = "" # Initialize agent_thought_id def increase_usage(final_llm_usage_dict: dict[str, Optional[LLMUsage]], usage: LLMUsage): if not final_llm_usage_dict["usage"]: diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py index 9eb853aa74..5236266908 100644 --- a/api/core/agent/fc_agent_runner.py +++ b/api/core/agent/fc_agent_runner.py @@ -54,6 +54,7 @@ class FunctionCallAgentRunner(BaseAgentRunner): function_call_state = True llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None} final_answer = "" + prompt_messages: list = [] # Initialize prompt_messages # get tracing instance trace_manager = app_generate_entity.trace_manager diff --git a/api/core/agent/plugin_entities.py b/api/core/agent/plugin_entities.py index a3438fc2c7..1133ecf66c 100644 --- a/api/core/agent/plugin_entities.py +++ b/api/core/agent/plugin_entities.py @@ -1,4 +1,4 @@ -import enum +from enum import StrEnum from typing import Any, Optional from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator @@ -26,25 +26,25 @@ class AgentStrategyProviderIdentity(ToolProviderIdentity): class AgentStrategyParameter(PluginParameter): - class AgentStrategyParameterType(enum.StrEnum): + class AgentStrategyParameterType(StrEnum): """ Keep all the types from PluginParameterType """ - STRING = CommonParameterType.STRING.value - NUMBER = CommonParameterType.NUMBER.value - BOOLEAN = CommonParameterType.BOOLEAN.value - SELECT = CommonParameterType.SELECT.value - SECRET_INPUT = CommonParameterType.SECRET_INPUT.value - FILE = CommonParameterType.FILE.value - FILES = CommonParameterType.FILES.value - APP_SELECTOR = CommonParameterType.APP_SELECTOR.value - MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value - TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR.value - ANY = CommonParameterType.ANY.value + STRING = CommonParameterType.STRING + NUMBER = CommonParameterType.NUMBER + BOOLEAN = CommonParameterType.BOOLEAN + SELECT = CommonParameterType.SELECT + SECRET_INPUT = CommonParameterType.SECRET_INPUT + FILE = CommonParameterType.FILE + FILES = CommonParameterType.FILES + APP_SELECTOR = CommonParameterType.APP_SELECTOR + MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR + TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR + ANY = CommonParameterType.ANY # deprecated, should not use. - SYSTEM_FILES = CommonParameterType.SYSTEM_FILES.value + SYSTEM_FILES = CommonParameterType.SYSTEM_FILES def as_normal_type(self): return as_normal_type(self) @@ -72,7 +72,7 @@ class AgentStrategyIdentity(ToolIdentity): pass -class AgentFeature(enum.StrEnum): +class AgentFeature(StrEnum): """ Agent Feature, used to describe the features of the agent strategy. """ diff --git a/api/core/app/app_config/common/sensitive_word_avoidance/manager.py b/api/core/app/app_config/common/sensitive_word_avoidance/manager.py index 037037e6ca..97ede178c7 100644 --- a/api/core/app/app_config/common/sensitive_word_avoidance/manager.py +++ b/api/core/app/app_config/common/sensitive_word_avoidance/manager.py @@ -21,7 +21,7 @@ class SensitiveWordAvoidanceConfigManager: @classmethod def validate_and_set_defaults( - cls, tenant_id, config: dict, only_structure_validate: bool = False + cls, tenant_id: str, config: dict, only_structure_validate: bool = False ) -> tuple[dict, list[str]]: if not config.get("sensitive_word_avoidance"): config["sensitive_word_avoidance"] = {"enabled": False} @@ -38,7 +38,14 @@ class SensitiveWordAvoidanceConfigManager: if not only_structure_validate: typ = config["sensitive_word_avoidance"]["type"] - sensitive_word_avoidance_config = config["sensitive_word_avoidance"]["config"] + if not isinstance(typ, str): + raise ValueError("sensitive_word_avoidance.type must be a string") + + sensitive_word_avoidance_config = config["sensitive_word_avoidance"].get("config") + if sensitive_word_avoidance_config is None: + sensitive_word_avoidance_config = {} + if not isinstance(sensitive_word_avoidance_config, dict): + raise ValueError("sensitive_word_avoidance.config must be a dict") ModerationFactory.validate_config(name=typ, tenant_id=tenant_id, config=sensitive_word_avoidance_config) diff --git a/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py index e6ab31e586..ec4f6074ab 100644 --- a/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py @@ -25,10 +25,14 @@ class PromptTemplateConfigManager: if chat_prompt_config: chat_prompt_messages = [] for message in chat_prompt_config.get("prompt", []): + text = message.get("text") + if not isinstance(text, str): + raise ValueError("message text must be a string") + role = message.get("role") + if not isinstance(role, str): + raise ValueError("message role must be a string") chat_prompt_messages.append( - AdvancedChatMessageEntity( - **{"text": message["text"], "role": PromptMessageRole.value_of(message["role"])} - ) + AdvancedChatMessageEntity(text=text, role=PromptMessageRole.value_of(role)) ) advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity(messages=chat_prompt_messages) @@ -66,7 +70,7 @@ class PromptTemplateConfigManager: :param config: app model config args """ if not config.get("prompt_type"): - config["prompt_type"] = PromptTemplateEntity.PromptType.SIMPLE.value + config["prompt_type"] = PromptTemplateEntity.PromptType.SIMPLE prompt_type_vals = [typ.value for typ in PromptTemplateEntity.PromptType] if config["prompt_type"] not in prompt_type_vals: @@ -86,7 +90,7 @@ class PromptTemplateConfigManager: if not isinstance(config["completion_prompt_config"], dict): raise ValueError("completion_prompt_config must be of object type") - if config["prompt_type"] == PromptTemplateEntity.PromptType.ADVANCED.value: + if config["prompt_type"] == PromptTemplateEntity.PromptType.ADVANCED: if not config["chat_prompt_config"] and not config["completion_prompt_config"]: raise ValueError( "chat_prompt_config or completion_prompt_config is required when prompt_type is advanced" diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py index df2074df2c..37745506e9 100644 --- a/api/core/app/app_config/entities.py +++ b/api/core/app/app_config/entities.py @@ -1,5 +1,5 @@ from collections.abc import Sequence -from enum import Enum, StrEnum +from enum import StrEnum, auto from typing import Any, Literal, Optional from pydantic import BaseModel, Field, field_validator @@ -61,14 +61,14 @@ class PromptTemplateEntity(BaseModel): Prompt Template Entity. """ - class PromptType(Enum): + class PromptType(StrEnum): """ Prompt Type. 'simple', 'advanced' """ - SIMPLE = "simple" - ADVANCED = "advanced" + SIMPLE = auto() + ADVANCED = auto() @classmethod def value_of(cls, value: str): @@ -195,14 +195,14 @@ class DatasetRetrieveConfigEntity(BaseModel): Dataset Retrieve Config Entity. """ - class RetrieveStrategy(Enum): + class RetrieveStrategy(StrEnum): """ Dataset Retrieve Strategy. 'single' or 'multiple' """ - SINGLE = "single" - MULTIPLE = "multiple" + SINGLE = auto() + MULTIPLE = auto() @classmethod def value_of(cls, value: str): @@ -293,12 +293,12 @@ class AppConfig(BaseModel): sensitive_word_avoidance: Optional[SensitiveWordAvoidanceEntity] = None -class EasyUIBasedAppModelConfigFrom(Enum): +class EasyUIBasedAppModelConfigFrom(StrEnum): """ App Model Config From. """ - ARGS = "args" + ARGS = auto() APP_LATEST_CONFIG = "app-latest-config" CONVERSATION_SPECIFIC_CONFIG = "conversation-specific-config" diff --git a/api/core/app/apps/advanced_chat/generate_response_converter.py b/api/core/app/apps/advanced_chat/generate_response_converter.py index 627f6b47ce..02ec96f209 100644 --- a/api/core/app/apps/advanced_chat/generate_response_converter.py +++ b/api/core/app/apps/advanced_chat/generate_response_converter.py @@ -71,7 +71,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter): yield "ping" continue - response_chunk = { + response_chunk: dict[str, Any] = { "event": sub_stream_response.event.value, "conversation_id": chunk.conversation_id, "message_id": chunk.message_id, @@ -82,7 +82,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter): data = cls._error_to_stream_response(sub_stream_response.err) response_chunk.update(data) else: - response_chunk.update(sub_stream_response.to_dict()) + response_chunk.update(sub_stream_response.model_dump(mode="json")) yield response_chunk @classmethod @@ -102,7 +102,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter): yield "ping" continue - response_chunk = { + response_chunk: dict[str, Any] = { "event": sub_stream_response.event.value, "conversation_id": chunk.conversation_id, "message_id": chunk.message_id, @@ -110,7 +110,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter): } if isinstance(sub_stream_response, MessageEndStreamResponse): - sub_stream_response_dict = sub_stream_response.to_dict() + sub_stream_response_dict = sub_stream_response.model_dump(mode="json") metadata = sub_stream_response_dict.get("metadata", {}) sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata) response_chunk.update(sub_stream_response_dict) @@ -118,8 +118,8 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter): data = cls._error_to_stream_response(sub_stream_response.err) response_chunk.update(data) elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse): - response_chunk.update(sub_stream_response.to_ignore_detail_dict()) # ty: ignore [unresolved-attribute] + response_chunk.update(sub_stream_response.to_ignore_detail_dict()) else: - response_chunk.update(sub_stream_response.to_dict()) + response_chunk.update(sub_stream_response.model_dump(mode="json")) yield response_chunk diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 8207b70f9e..cec3b83674 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -174,7 +174,7 @@ class AdvancedChatAppGenerateTaskPipeline: generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager) - if self._base_task_pipeline._stream: + if self._base_task_pipeline.stream: return self._to_stream_response(generator) else: return self._to_blocking_response(generator) @@ -302,13 +302,13 @@ class AdvancedChatAppGenerateTaskPipeline: def _handle_ping_event(self, event: QueuePingEvent, **kwargs) -> Generator[PingStreamResponse, None, None]: """Handle ping events.""" - yield self._base_task_pipeline._ping_stream_response() + yield self._base_task_pipeline.ping_stream_response() def _handle_error_event(self, event: QueueErrorEvent, **kwargs) -> Generator[ErrorStreamResponse, None, None]: """Handle error events.""" with self._database_session() as session: - err = self._base_task_pipeline._handle_error(event=event, session=session, message_id=self._message_id) - yield self._base_task_pipeline._error_to_stream_response(err) + err = self._base_task_pipeline.handle_error(event=event, session=session, message_id=self._message_id) + yield self._base_task_pipeline.error_to_stream_response(err) def _handle_workflow_started_event(self, *args, **kwargs) -> Generator[StreamResponse, None, None]: """Handle workflow started events.""" @@ -627,10 +627,10 @@ class AdvancedChatAppGenerateTaskPipeline: workflow_execution=workflow_execution, ) err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_execution.error_message}")) - err = self._base_task_pipeline._handle_error(event=err_event, session=session, message_id=self._message_id) + err = self._base_task_pipeline.handle_error(event=err_event, session=session, message_id=self._message_id) yield workflow_finish_resp - yield self._base_task_pipeline._error_to_stream_response(err) + yield self._base_task_pipeline.error_to_stream_response(err) def _handle_stop_event( self, @@ -683,7 +683,7 @@ class AdvancedChatAppGenerateTaskPipeline: """Handle advanced chat message end events.""" self._ensure_graph_runtime_initialized(graph_runtime_state) - output_moderation_answer = self._base_task_pipeline._handle_output_moderation_when_task_finished( + output_moderation_answer = self._base_task_pipeline.handle_output_moderation_when_task_finished( self._task_state.answer ) if output_moderation_answer: @@ -899,7 +899,7 @@ class AdvancedChatAppGenerateTaskPipeline: message.answer = answer_text message.updated_at = naive_utc_now() - message.provider_response_latency = time.perf_counter() - self._base_task_pipeline._start_at + message.provider_response_latency = time.perf_counter() - self._base_task_pipeline.start_at message.message_metadata = self._task_state.metadata.model_dump_json() message_files = [ MessageFile( @@ -955,9 +955,9 @@ class AdvancedChatAppGenerateTaskPipeline: :param text: text :return: True if output moderation should direct output, otherwise False """ - if self._base_task_pipeline._output_moderation_handler: - if self._base_task_pipeline._output_moderation_handler.should_direct_output(): - self._task_state.answer = self._base_task_pipeline._output_moderation_handler.get_final_output() + if self._base_task_pipeline.output_moderation_handler: + if self._base_task_pipeline.output_moderation_handler.should_direct_output(): + self._task_state.answer = self._base_task_pipeline.output_moderation_handler.get_final_output() self._base_task_pipeline.queue_manager.publish( QueueTextChunkEvent(text=self._task_state.answer), PublishFrom.TASK_PIPELINE ) @@ -967,7 +967,7 @@ class AdvancedChatAppGenerateTaskPipeline: ) return True else: - self._base_task_pipeline._output_moderation_handler.append_new_token(text) + self._base_task_pipeline.output_moderation_handler.append_new_token(text) return False diff --git a/api/core/app/apps/agent_chat/app_config_manager.py b/api/core/app/apps/agent_chat/app_config_manager.py index 349b583833..54d1a9595f 100644 --- a/api/core/app/apps/agent_chat/app_config_manager.py +++ b/api/core/app/apps/agent_chat/app_config_manager.py @@ -1,6 +1,6 @@ import uuid from collections.abc import Mapping -from typing import Any, Optional +from typing import Any, Optional, cast from core.agent.entities import AgentEntity from core.app.app_config.base_app_config_manager import BaseAppConfigManager @@ -160,7 +160,9 @@ class AgentChatAppConfigManager(BaseAppConfigManager): return filtered_config @classmethod - def validate_agent_mode_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]: + def validate_agent_mode_and_set_defaults( + cls, tenant_id: str, config: dict[str, Any] + ) -> tuple[dict[str, Any], list[str]]: """ Validate agent_mode and set defaults for agent feature @@ -170,30 +172,32 @@ class AgentChatAppConfigManager(BaseAppConfigManager): if not config.get("agent_mode"): config["agent_mode"] = {"enabled": False, "tools": []} - if not isinstance(config["agent_mode"], dict): + agent_mode = config["agent_mode"] + if not isinstance(agent_mode, dict): raise ValueError("agent_mode must be of object type") - if "enabled" not in config["agent_mode"] or not config["agent_mode"]["enabled"]: - config["agent_mode"]["enabled"] = False + # FIXME(-LAN-): Cast needed due to basedpyright limitation with dict type narrowing + agent_mode = cast(dict[str, Any], agent_mode) - if not isinstance(config["agent_mode"]["enabled"], bool): + if "enabled" not in agent_mode or not agent_mode["enabled"]: + agent_mode["enabled"] = False + + if not isinstance(agent_mode["enabled"], bool): raise ValueError("enabled in agent_mode must be of boolean type") - if not config["agent_mode"].get("strategy"): - config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value + if not agent_mode.get("strategy"): + agent_mode["strategy"] = PlanningStrategy.ROUTER.value - if config["agent_mode"]["strategy"] not in [ - member.value for member in list(PlanningStrategy.__members__.values()) - ]: + if agent_mode["strategy"] not in [member.value for member in list(PlanningStrategy.__members__.values())]: raise ValueError("strategy in agent_mode must be in the specified strategy list") - if not config["agent_mode"].get("tools"): - config["agent_mode"]["tools"] = [] + if not agent_mode.get("tools"): + agent_mode["tools"] = [] - if not isinstance(config["agent_mode"]["tools"], list): + if not isinstance(agent_mode["tools"], list): raise ValueError("tools in agent_mode must be a list of objects") - for tool in config["agent_mode"]["tools"]: + for tool in agent_mode["tools"]: key = list(tool.keys())[0] if key in OLD_TOOLS: # old style, use tool name as key diff --git a/api/core/app/apps/agent_chat/generate_response_converter.py b/api/core/app/apps/agent_chat/generate_response_converter.py index 89a5b8e3b5..e35e9d9408 100644 --- a/api/core/app/apps/agent_chat/generate_response_converter.py +++ b/api/core/app/apps/agent_chat/generate_response_converter.py @@ -46,7 +46,10 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter): response = cls.convert_blocking_full_response(blocking_response) metadata = response.get("metadata", {}) - response["metadata"] = cls._get_simple_metadata(metadata) + if isinstance(metadata, dict): + response["metadata"] = cls._get_simple_metadata(metadata) + else: + response["metadata"] = {} return response @@ -78,7 +81,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter): data = cls._error_to_stream_response(sub_stream_response.err) response_chunk.update(data) else: - response_chunk.update(sub_stream_response.to_dict()) + response_chunk.update(sub_stream_response.model_dump(mode="json")) yield response_chunk @classmethod @@ -106,7 +109,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter): } if isinstance(sub_stream_response, MessageEndStreamResponse): - sub_stream_response_dict = sub_stream_response.to_dict() + sub_stream_response_dict = sub_stream_response.model_dump(mode="json") metadata = sub_stream_response_dict.get("metadata", {}) sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata) response_chunk.update(sub_stream_response_dict) @@ -114,6 +117,6 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter): data = cls._error_to_stream_response(sub_stream_response.err) response_chunk.update(data) else: - response_chunk.update(sub_stream_response.to_dict()) + response_chunk.update(sub_stream_response.model_dump(mode="json")) yield response_chunk diff --git a/api/core/app/apps/base_app_queue_manager.py b/api/core/app/apps/base_app_queue_manager.py index 795a7befff..2a7fe7902b 100644 --- a/api/core/app/apps/base_app_queue_manager.py +++ b/api/core/app/apps/base_app_queue_manager.py @@ -32,6 +32,7 @@ class AppQueueManager: self._task_id = task_id self._user_id = user_id self._invoke_from = invoke_from + self.invoke_from = invoke_from # Public accessor for invoke_from user_prefix = "account" if self._invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end-user" redis_client.setex( diff --git a/api/core/app/apps/chat/generate_response_converter.py b/api/core/app/apps/chat/generate_response_converter.py index 816d6d79a9..3aa1161fd8 100644 --- a/api/core/app/apps/chat/generate_response_converter.py +++ b/api/core/app/apps/chat/generate_response_converter.py @@ -46,7 +46,10 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter): response = cls.convert_blocking_full_response(blocking_response) metadata = response.get("metadata", {}) - response["metadata"] = cls._get_simple_metadata(metadata) + if isinstance(metadata, dict): + response["metadata"] = cls._get_simple_metadata(metadata) + else: + response["metadata"] = {} return response @@ -78,7 +81,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter): data = cls._error_to_stream_response(sub_stream_response.err) response_chunk.update(data) else: - response_chunk.update(sub_stream_response.to_dict()) + response_chunk.update(sub_stream_response.model_dump(mode="json")) yield response_chunk @classmethod @@ -106,7 +109,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter): } if isinstance(sub_stream_response, MessageEndStreamResponse): - sub_stream_response_dict = sub_stream_response.to_dict() + sub_stream_response_dict = sub_stream_response.model_dump(mode="json") metadata = sub_stream_response_dict.get("metadata", {}) sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata) response_chunk.update(sub_stream_response_dict) @@ -114,6 +117,6 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter): data = cls._error_to_stream_response(sub_stream_response.err) response_chunk.update(data) else: - response_chunk.update(sub_stream_response.to_dict()) + response_chunk.update(sub_stream_response.model_dump(mode="json")) yield response_chunk diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py index 8485ce7519..843328f904 100644 --- a/api/core/app/apps/completion/app_generator.py +++ b/api/core/app/apps/completion/app_generator.py @@ -271,6 +271,8 @@ class CompletionAppGenerator(MessageBasedAppGenerator): raise MoreLikeThisDisabledError() app_model_config = message.app_model_config + if not app_model_config: + raise ValueError("Message app_model_config is None") override_model_config_dict = app_model_config.to_dict() model_dict = override_model_config_dict["model"] completion_params = model_dict.get("completion_params") diff --git a/api/core/app/apps/completion/generate_response_converter.py b/api/core/app/apps/completion/generate_response_converter.py index 4d45c61145..d7e9ebdf24 100644 --- a/api/core/app/apps/completion/generate_response_converter.py +++ b/api/core/app/apps/completion/generate_response_converter.py @@ -45,7 +45,10 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter): response = cls.convert_blocking_full_response(blocking_response) metadata = response.get("metadata", {}) - response["metadata"] = cls._get_simple_metadata(metadata) + if isinstance(metadata, dict): + response["metadata"] = cls._get_simple_metadata(metadata) + else: + response["metadata"] = {} return response @@ -76,7 +79,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter): data = cls._error_to_stream_response(sub_stream_response.err) response_chunk.update(data) else: - response_chunk.update(sub_stream_response.to_dict()) + response_chunk.update(sub_stream_response.model_dump(mode="json")) yield response_chunk @classmethod @@ -103,14 +106,16 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter): } if isinstance(sub_stream_response, MessageEndStreamResponse): - sub_stream_response_dict = sub_stream_response.to_dict() + sub_stream_response_dict = sub_stream_response.model_dump(mode="json") metadata = sub_stream_response_dict.get("metadata", {}) + if not isinstance(metadata, dict): + metadata = {} sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata) response_chunk.update(sub_stream_response_dict) if isinstance(sub_stream_response, ErrorStreamResponse): data = cls._error_to_stream_response(sub_stream_response.err) response_chunk.update(data) else: - response_chunk.update(sub_stream_response.to_dict()) + response_chunk.update(sub_stream_response.model_dump(mode="json")) yield response_chunk diff --git a/api/core/app/apps/workflow/generate_response_converter.py b/api/core/app/apps/workflow/generate_response_converter.py index 210f6110b1..01ecf0298f 100644 --- a/api/core/app/apps/workflow/generate_response_converter.py +++ b/api/core/app/apps/workflow/generate_response_converter.py @@ -23,7 +23,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): :param blocking_response: blocking response :return: """ - return dict(blocking_response.to_dict()) + return blocking_response.model_dump() @classmethod def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse): # type: ignore[override] @@ -51,7 +51,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): yield "ping" continue - response_chunk = { + response_chunk: dict[str, object] = { "event": sub_stream_response.event.value, "workflow_run_id": chunk.workflow_run_id, } @@ -60,7 +60,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): data = cls._error_to_stream_response(sub_stream_response.err) response_chunk.update(data) else: - response_chunk.update(sub_stream_response.to_dict()) + response_chunk.update(sub_stream_response.model_dump(mode="json")) yield response_chunk @classmethod @@ -80,7 +80,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): yield "ping" continue - response_chunk = { + response_chunk: dict[str, object] = { "event": sub_stream_response.event.value, "workflow_run_id": chunk.workflow_run_id, } @@ -91,5 +91,5 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse): response_chunk.update(sub_stream_response.to_ignore_detail_dict()) # ty: ignore [unresolved-attribute] else: - response_chunk.update(sub_stream_response.to_dict()) + response_chunk.update(sub_stream_response.model_dump(mode="json")) yield response_chunk diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 6ab89dbd61..1c950063dd 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -137,7 +137,7 @@ class WorkflowAppGenerateTaskPipeline: self._application_generate_entity = application_generate_entity self._workflow_features_dict = workflow.features_dict self._workflow_run_id = "" - self._invoke_from = queue_manager._invoke_from + self._invoke_from = queue_manager.invoke_from self._draft_var_saver_factory = draft_var_saver_factory def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]: @@ -146,7 +146,7 @@ class WorkflowAppGenerateTaskPipeline: :return: """ generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager) - if self._base_task_pipeline._stream: + if self._base_task_pipeline.stream: return self._to_stream_response(generator) else: return self._to_blocking_response(generator) @@ -276,12 +276,12 @@ class WorkflowAppGenerateTaskPipeline: def _handle_ping_event(self, event: QueuePingEvent, **kwargs) -> Generator[PingStreamResponse, None, None]: """Handle ping events.""" - yield self._base_task_pipeline._ping_stream_response() + yield self._base_task_pipeline.ping_stream_response() def _handle_error_event(self, event: QueueErrorEvent, **kwargs) -> Generator[ErrorStreamResponse, None, None]: """Handle error events.""" - err = self._base_task_pipeline._handle_error(event=event) - yield self._base_task_pipeline._error_to_stream_response(err) + err = self._base_task_pipeline.handle_error(event=event) + yield self._base_task_pipeline.error_to_stream_response(err) def _handle_workflow_started_event( self, event: QueueWorkflowStartedEvent, **kwargs diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index 72b62eb67c..1d5ebabaf7 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -95,7 +95,7 @@ class AppGenerateEntity(BaseModel): task_id: str # app config - app_config: Any + app_config: Any = None file_upload_config: Optional[FileUploadConfig] = None inputs: Mapping[str, Any] @@ -123,7 +123,7 @@ class EasyUIBasedAppGenerateEntity(AppGenerateEntity): """ # app config - app_config: EasyUIBasedAppConfig + app_config: EasyUIBasedAppConfig = None # type: ignore model_conf: ModelConfigWithCredentialsEntity query: Optional[str] = None @@ -186,7 +186,7 @@ class AdvancedChatAppGenerateEntity(ConversationAppGenerateEntity): """ # app config - app_config: WorkflowUIBasedAppConfig + app_config: WorkflowUIBasedAppConfig = None # type: ignore workflow_run_id: Optional[str] = None query: str @@ -218,7 +218,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity): """ # app config - app_config: WorkflowUIBasedAppConfig + app_config: WorkflowUIBasedAppConfig = None # type: ignore workflow_execution_id: str class SingleIterationRunEntity(BaseModel): diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py index db0297c352..8d20b6bc1b 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -1,6 +1,6 @@ from collections.abc import Mapping, Sequence from datetime import datetime -from enum import Enum, StrEnum +from enum import StrEnum, auto from typing import Any, Optional from pydantic import BaseModel @@ -432,8 +432,8 @@ class QueueAgentLogEvent(AppQueueEvent): id: str label: str node_execution_id: str - parent_id: str | None - error: str | None + parent_id: str | None = None + error: str | None = None status: str data: Mapping[str, Any] metadata: Optional[Mapping[str, Any]] = None @@ -626,15 +626,15 @@ class QueueStopEvent(AppQueueEvent): QueueStopEvent entity """ - class StopBy(Enum): + class StopBy(StrEnum): """ Stop by enum """ - USER_MANUAL = "user-manual" - ANNOTATION_REPLY = "annotation-reply" - OUTPUT_MODERATION = "output-moderation" - INPUT_MODERATION = "input-moderation" + USER_MANUAL = auto() + ANNOTATION_REPLY = auto() + OUTPUT_MODERATION = auto() + INPUT_MODERATION = auto() event: QueueEvent = QueueEvent.STOP stopped_by: StopBy diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index a1c0368354..717e5d715a 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -1,11 +1,10 @@ from collections.abc import Mapping, Sequence -from enum import Enum +from enum import StrEnum, auto from typing import Any, Optional from pydantic import BaseModel, ConfigDict, Field from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage -from core.model_runtime.utils.encoders import jsonable_encoder from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.workflow.entities.node_entities import AgentNodeStrategyInit from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus @@ -51,37 +50,37 @@ class WorkflowTaskState(TaskState): answer: str = "" -class StreamEvent(Enum): +class StreamEvent(StrEnum): """ Stream event """ - PING = "ping" - ERROR = "error" - MESSAGE = "message" - MESSAGE_END = "message_end" - TTS_MESSAGE = "tts_message" - TTS_MESSAGE_END = "tts_message_end" - MESSAGE_FILE = "message_file" - MESSAGE_REPLACE = "message_replace" - AGENT_THOUGHT = "agent_thought" - AGENT_MESSAGE = "agent_message" - WORKFLOW_STARTED = "workflow_started" - WORKFLOW_FINISHED = "workflow_finished" - NODE_STARTED = "node_started" - NODE_FINISHED = "node_finished" - NODE_RETRY = "node_retry" - PARALLEL_BRANCH_STARTED = "parallel_branch_started" - PARALLEL_BRANCH_FINISHED = "parallel_branch_finished" - ITERATION_STARTED = "iteration_started" - ITERATION_NEXT = "iteration_next" - ITERATION_COMPLETED = "iteration_completed" - LOOP_STARTED = "loop_started" - LOOP_NEXT = "loop_next" - LOOP_COMPLETED = "loop_completed" - TEXT_CHUNK = "text_chunk" - TEXT_REPLACE = "text_replace" - AGENT_LOG = "agent_log" + PING = auto() + ERROR = auto() + MESSAGE = auto() + MESSAGE_END = auto() + TTS_MESSAGE = auto() + TTS_MESSAGE_END = auto() + MESSAGE_FILE = auto() + MESSAGE_REPLACE = auto() + AGENT_THOUGHT = auto() + AGENT_MESSAGE = auto() + WORKFLOW_STARTED = auto() + WORKFLOW_FINISHED = auto() + NODE_STARTED = auto() + NODE_FINISHED = auto() + NODE_RETRY = auto() + PARALLEL_BRANCH_STARTED = auto() + PARALLEL_BRANCH_FINISHED = auto() + ITERATION_STARTED = auto() + ITERATION_NEXT = auto() + ITERATION_COMPLETED = auto() + LOOP_STARTED = auto() + LOOP_NEXT = auto() + LOOP_COMPLETED = auto() + TEXT_CHUNK = auto() + TEXT_REPLACE = auto() + AGENT_LOG = auto() class StreamResponse(BaseModel): @@ -92,9 +91,6 @@ class StreamResponse(BaseModel): event: StreamEvent task_id: str - def to_dict(self): - return jsonable_encoder(self) - class ErrorStreamResponse(StreamResponse): """ @@ -745,9 +741,6 @@ class AppBlockingResponse(BaseModel): task_id: str - def to_dict(self): - return jsonable_encoder(self) - class ChatbotAppBlockingResponse(AppBlockingResponse): """ @@ -828,8 +821,8 @@ class AgentLogStreamResponse(StreamResponse): node_execution_id: str id: str label: str - parent_id: str | None - error: str | None + parent_id: str | None = None + error: str | None = None status: str data: Mapping[str, Any] metadata: Optional[Mapping[str, Any]] = None diff --git a/api/core/app/features/annotation_reply/annotation_reply.py b/api/core/app/features/annotation_reply/annotation_reply.py index be183e2086..3853dccdc5 100644 --- a/api/core/app/features/annotation_reply/annotation_reply.py +++ b/api/core/app/features/annotation_reply/annotation_reply.py @@ -35,6 +35,9 @@ class AnnotationReplyFeature: collection_binding_detail = annotation_setting.collection_binding_detail + if not collection_binding_detail: + return None + try: score_threshold = annotation_setting.score_threshold or 1 embedding_provider_name = collection_binding_detail.provider_name diff --git a/api/core/app/features/rate_limiting/__init__.py b/api/core/app/features/rate_limiting/__init__.py index 6624f6ad9d..4ad33acd0f 100644 --- a/api/core/app/features/rate_limiting/__init__.py +++ b/api/core/app/features/rate_limiting/__init__.py @@ -1 +1,3 @@ from .rate_limit import RateLimit + +__all__ = ["RateLimit"] diff --git a/api/core/app/features/rate_limiting/rate_limit.py b/api/core/app/features/rate_limiting/rate_limit.py index f526d2a16a..6f13f11da0 100644 --- a/api/core/app/features/rate_limiting/rate_limit.py +++ b/api/core/app/features/rate_limiting/rate_limit.py @@ -19,7 +19,7 @@ class RateLimit: _ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL = 5 * 60 # recalculate request_count from request_detail every 5 minutes _instance_dict: dict[str, "RateLimit"] = {} - def __new__(cls: type["RateLimit"], client_id: str, max_active_requests: int): + def __new__(cls, client_id: str, max_active_requests: int): if client_id not in cls._instance_dict: instance = super().__new__(cls) cls._instance_dict[client_id] = instance diff --git a/api/core/app/task_pipeline/based_generate_task_pipeline.py b/api/core/app/task_pipeline/based_generate_task_pipeline.py index 7d98cceb1a..4931300901 100644 --- a/api/core/app/task_pipeline/based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/based_generate_task_pipeline.py @@ -38,11 +38,11 @@ class BasedGenerateTaskPipeline: ): self._application_generate_entity = application_generate_entity self.queue_manager = queue_manager - self._start_at = time.perf_counter() - self._output_moderation_handler = self._init_output_moderation() - self._stream = stream + self.start_at = time.perf_counter() + self.output_moderation_handler = self._init_output_moderation() + self.stream = stream - def _handle_error(self, *, event: QueueErrorEvent, session: Session | None = None, message_id: str = ""): + def handle_error(self, *, event: QueueErrorEvent, session: Session | None = None, message_id: str = ""): logger.debug("error: %s", event.error) e = event.error err: Exception @@ -86,7 +86,7 @@ class BasedGenerateTaskPipeline: return message - def _error_to_stream_response(self, e: Exception): + def error_to_stream_response(self, e: Exception): """ Error to stream response. :param e: exception @@ -94,7 +94,7 @@ class BasedGenerateTaskPipeline: """ return ErrorStreamResponse(task_id=self._application_generate_entity.task_id, err=e) - def _ping_stream_response(self) -> PingStreamResponse: + def ping_stream_response(self) -> PingStreamResponse: """ Ping stream response. :return: @@ -118,21 +118,21 @@ class BasedGenerateTaskPipeline: ) return None - def _handle_output_moderation_when_task_finished(self, completion: str) -> Optional[str]: + def handle_output_moderation_when_task_finished(self, completion: str) -> Optional[str]: """ Handle output moderation when task finished. :param completion: completion :return: """ # response moderation - if self._output_moderation_handler: - self._output_moderation_handler.stop_thread() + if self.output_moderation_handler: + self.output_moderation_handler.stop_thread() - completion, flagged = self._output_moderation_handler.moderation_completion( + completion, flagged = self.output_moderation_handler.moderation_completion( completion=completion, public_event=False ) - self._output_moderation_handler = None + self.output_moderation_handler = None if flagged: return completion diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index 0dad0a5a9d..d4d5c9f7d2 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -125,7 +125,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): ) generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager) - if self._stream: + if self.stream: return self._to_stream_response(generator) else: return self._to_blocking_response(generator) @@ -145,7 +145,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): if self._task_state.metadata: extras["metadata"] = self._task_state.metadata.model_dump() response: Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse] - if self._conversation_mode == AppMode.COMPLETION.value: + if self._conversation_mode == AppMode.COMPLETION: response = CompletionAppBlockingResponse( task_id=self._application_generate_entity.task_id, data=CompletionAppBlockingResponse.Data( @@ -265,9 +265,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): if isinstance(event, QueueErrorEvent): with Session(db.engine) as session: - err = self._handle_error(event=event, session=session, message_id=self._message_id) + err = self.handle_error(event=event, session=session, message_id=self._message_id) session.commit() - yield self._error_to_stream_response(err) + yield self.error_to_stream_response(err) break elif isinstance(event, QueueStopEvent | QueueMessageEndEvent): if isinstance(event, QueueMessageEndEvent): @@ -277,7 +277,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): self._handle_stop(event) # handle output moderation - output_moderation_answer = self._handle_output_moderation_when_task_finished( + output_moderation_answer = self.handle_output_moderation_when_task_finished( cast(str, self._task_state.llm_result.message.content) ) if output_moderation_answer: @@ -354,7 +354,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): elif isinstance(event, QueueMessageReplaceEvent): yield self._message_cycle_manager.message_replace_to_stream_response(answer=event.text) elif isinstance(event, QueuePingEvent): - yield self._ping_stream_response() + yield self.ping_stream_response() else: continue if publisher: @@ -394,7 +394,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): message.answer_tokens = usage.completion_tokens message.answer_unit_price = usage.completion_unit_price message.answer_price_unit = usage.completion_price_unit - message.provider_response_latency = time.perf_counter() - self._start_at + message.provider_response_latency = time.perf_counter() - self.start_at message.total_price = usage.total_price message.currency = usage.currency self._task_state.llm_result.usage.latency = message.provider_response_latency @@ -438,7 +438,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): # transform usage model_type_instance = model_config.provider_model_bundle.model_type_instance model_type_instance = cast(LargeLanguageModel, model_type_instance) - self._task_state.llm_result.usage = model_type_instance._calc_response_usage( + self._task_state.llm_result.usage = model_type_instance.calc_response_usage( model, credentials, prompt_tokens, completion_tokens ) @@ -498,10 +498,10 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): :param text: text :return: True if output moderation should direct output, otherwise False """ - if self._output_moderation_handler: - if self._output_moderation_handler.should_direct_output(): + if self.output_moderation_handler: + if self.output_moderation_handler.should_direct_output(): # stop subscribe new token when output moderation should direct output - self._task_state.llm_result.message.content = self._output_moderation_handler.get_final_output() + self._task_state.llm_result.message.content = self.output_moderation_handler.get_final_output() self.queue_manager.publish( QueueLLMChunkEvent( chunk=LLMResultChunk( @@ -521,6 +521,6 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): ) return True else: - self._output_moderation_handler.append_new_token(text) + self.output_moderation_handler.append_new_token(text) return False diff --git a/api/core/app/task_pipeline/message_cycle_manager.py b/api/core/app/task_pipeline/message_cycle_manager.py index e865ba9d60..2cefb61df3 100644 --- a/api/core/app/task_pipeline/message_cycle_manager.py +++ b/api/core/app/task_pipeline/message_cycle_manager.py @@ -92,7 +92,7 @@ class MessageCycleManager: if not conversation: return - if conversation.mode != AppMode.COMPLETION.value: + if conversation.mode != AppMode.COMPLETION: app_model = conversation.app if not app_model: return diff --git a/api/core/base/tts/app_generator_tts_publisher.py b/api/core/base/tts/app_generator_tts_publisher.py index 4e6422e2df..89190c36cc 100644 --- a/api/core/base/tts/app_generator_tts_publisher.py +++ b/api/core/base/tts/app_generator_tts_publisher.py @@ -72,7 +72,7 @@ class AppGeneratorTTSPublisher: self.voice = voice if not voice or voice not in values: self.voice = self.voices[0].get("value") - self.MAX_SENTENCE = 2 + self.max_sentence = 2 self._last_audio_event: Optional[AudioTrunk] = None # FIXME better way to handle this threading.start threading.Thread(target=self._runtime).start() @@ -113,8 +113,8 @@ class AppGeneratorTTSPublisher: self.msg_text += message.event.outputs.get("output", "") self.last_message = message sentence_arr, text_tmp = self._extract_sentence(self.msg_text) - if len(sentence_arr) >= min(self.MAX_SENTENCE, 7): - self.MAX_SENTENCE += 1 + if len(sentence_arr) >= min(self.max_sentence, 7): + self.max_sentence += 1 text_content = "".join(sentence_arr) futures_result = self.executor.submit( _invoice_tts, text_content, self.model_instance, self.tenant_id, self.voice diff --git a/api/core/entities/agent_entities.py b/api/core/entities/agent_entities.py index 656bf4aa72..cf958b91d2 100644 --- a/api/core/entities/agent_entities.py +++ b/api/core/entities/agent_entities.py @@ -1,8 +1,8 @@ -from enum import Enum +from enum import StrEnum, auto -class PlanningStrategy(Enum): - ROUTER = "router" - REACT_ROUTER = "react_router" - REACT = "react" - FUNCTION_CALL = "function_call" +class PlanningStrategy(StrEnum): + ROUTER = auto() + REACT_ROUTER = auto() + REACT = auto() + FUNCTION_CALL = auto() diff --git a/api/core/entities/embedding_type.py b/api/core/entities/embedding_type.py index 9b4934646b..89b48fd2ef 100644 --- a/api/core/entities/embedding_type.py +++ b/api/core/entities/embedding_type.py @@ -1,10 +1,10 @@ -from enum import Enum +from enum import StrEnum, auto -class EmbeddingInputType(Enum): +class EmbeddingInputType(StrEnum): """ Enum for embedding input type. """ - DOCUMENT = "document" - QUERY = "query" + DOCUMENT = auto() + QUERY = auto() diff --git a/api/core/entities/model_entities.py b/api/core/entities/model_entities.py index 0fd49b059c..4794a691bd 100644 --- a/api/core/entities/model_entities.py +++ b/api/core/entities/model_entities.py @@ -1,5 +1,5 @@ from collections.abc import Sequence -from enum import Enum +from enum import StrEnum, auto from typing import Optional from pydantic import BaseModel, ConfigDict @@ -9,16 +9,16 @@ from core.model_runtime.entities.model_entities import ModelType, ProviderModel from core.model_runtime.entities.provider_entities import ProviderEntity -class ModelStatus(Enum): +class ModelStatus(StrEnum): """ Enum class for model status. """ - ACTIVE = "active" + ACTIVE = auto() NO_CONFIGURE = "no-configure" QUOTA_EXCEEDED = "quota-exceeded" NO_PERMISSION = "no-permission" - DISABLED = "disabled" + DISABLED = auto() CREDENTIAL_REMOVED = "credential-removed" diff --git a/api/core/entities/parameter_entities.py b/api/core/entities/parameter_entities.py index fbd62437e6..0afb51edce 100644 --- a/api/core/entities/parameter_entities.py +++ b/api/core/entities/parameter_entities.py @@ -1,20 +1,20 @@ -from enum import StrEnum +from enum import StrEnum, auto class CommonParameterType(StrEnum): SECRET_INPUT = "secret-input" TEXT_INPUT = "text-input" - SELECT = "select" - STRING = "string" - NUMBER = "number" - FILE = "file" - FILES = "files" + SELECT = auto() + STRING = auto() + NUMBER = auto() + FILE = auto() + FILES = auto() SYSTEM_FILES = "system-files" - BOOLEAN = "boolean" + BOOLEAN = auto() APP_SELECTOR = "app-selector" MODEL_SELECTOR = "model-selector" TOOLS_SELECTOR = "array[tools]" - ANY = "any" + ANY = auto() # Dynamic select parameter # Once you are not sure about the available options until authorization is done @@ -23,29 +23,29 @@ class CommonParameterType(StrEnum): # TOOL_SELECTOR = "tool-selector" # MCP object and array type parameters - ARRAY = "array" - OBJECT = "object" + ARRAY = auto() + OBJECT = auto() class AppSelectorScope(StrEnum): - ALL = "all" - CHAT = "chat" - WORKFLOW = "workflow" - COMPLETION = "completion" + ALL = auto() + CHAT = auto() + WORKFLOW = auto() + COMPLETION = auto() class ModelSelectorScope(StrEnum): - LLM = "llm" + LLM = auto() TEXT_EMBEDDING = "text-embedding" - RERANK = "rerank" - TTS = "tts" - SPEECH2TEXT = "speech2text" - MODERATION = "moderation" - VISION = "vision" + RERANK = auto() + TTS = auto() + SPEECH2TEXT = auto() + MODERATION = auto() + VISION = auto() class ToolSelectorScope(StrEnum): - ALL = "all" - CUSTOM = "custom" - BUILTIN = "builtin" - WORKFLOW = "workflow" + ALL = auto() + CUSTOM = auto() + BUILTIN = auto() + WORKFLOW = auto() diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index 9cf35e559d..5309e4e638 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -1840,8 +1840,14 @@ class ProviderConfigurations(BaseModel): def __setitem__(self, key, value): self.configurations[key] = value + def __contains__(self, key): + if "/" not in key: + key = str(ModelProviderID(key)) + return key in self.configurations + def __iter__(self): - return iter(self.configurations) + # Return an iterator of (key, value) tuples to match BaseModel's __iter__ + yield from self.configurations.items() def values(self) -> Iterator[ProviderConfiguration]: return iter(self.configurations.values()) diff --git a/api/core/entities/provider_entities.py b/api/core/entities/provider_entities.py index 9b8baf1973..ad23f4381d 100644 --- a/api/core/entities/provider_entities.py +++ b/api/core/entities/provider_entities.py @@ -1,4 +1,4 @@ -from enum import Enum +from enum import StrEnum, auto from typing import Optional, Union from pydantic import BaseModel, ConfigDict, Field @@ -13,14 +13,14 @@ from core.model_runtime.entities.model_entities import ModelType from core.tools.entities.common_entities import I18nObject -class ProviderQuotaType(Enum): - PAID = "paid" +class ProviderQuotaType(StrEnum): + PAID = auto() """hosted paid quota""" - FREE = "free" + FREE = auto() """third-party free quota""" - TRIAL = "trial" + TRIAL = auto() """hosted trial quota""" @staticmethod @@ -31,20 +31,20 @@ class ProviderQuotaType(Enum): raise ValueError(f"No matching enum found for value '{value}'") -class QuotaUnit(Enum): - TIMES = "times" - TOKENS = "tokens" - CREDITS = "credits" +class QuotaUnit(StrEnum): + TIMES = auto() + TOKENS = auto() + CREDITS = auto() -class SystemConfigurationStatus(Enum): +class SystemConfigurationStatus(StrEnum): """ Enum class for system configuration status. """ - ACTIVE = "active" + ACTIVE = auto() QUOTA_EXCEEDED = "quota-exceeded" - UNSUPPORTED = "unsupported" + UNSUPPORTED = auto() class RestrictModel(BaseModel): @@ -107,7 +107,7 @@ class CustomModelConfiguration(BaseModel): model: str model_type: ModelType - credentials: dict | None + credentials: dict | None = None current_credential_id: Optional[str] = None current_credential_name: Optional[str] = None available_model_credentials: list[CredentialConfiguration] = [] @@ -168,14 +168,14 @@ class BasicProviderConfig(BaseModel): Base model class for common provider settings like credentials """ - class Type(Enum): - SECRET_INPUT = CommonParameterType.SECRET_INPUT.value - TEXT_INPUT = CommonParameterType.TEXT_INPUT.value - SELECT = CommonParameterType.SELECT.value - BOOLEAN = CommonParameterType.BOOLEAN.value - APP_SELECTOR = CommonParameterType.APP_SELECTOR.value - MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value - TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR.value + class Type(StrEnum): + SECRET_INPUT = CommonParameterType.SECRET_INPUT + TEXT_INPUT = CommonParameterType.TEXT_INPUT + SELECT = CommonParameterType.SELECT + BOOLEAN = CommonParameterType.BOOLEAN + APP_SELECTOR = CommonParameterType.APP_SELECTOR + MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR + TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR @classmethod def value_of(cls, value: str) -> "ProviderConfig.Type": diff --git a/api/core/extension/extensible.py b/api/core/extension/extensible.py index eee914a529..8cb9b4ac58 100644 --- a/api/core/extension/extensible.py +++ b/api/core/extension/extensible.py @@ -1,8 +1,8 @@ -import enum import importlib.util import json import logging import os +from enum import StrEnum, auto from pathlib import Path from typing import Any, Optional @@ -13,9 +13,9 @@ from core.helper.position_helper import sort_to_dict_by_position_map logger = logging.getLogger(__name__) -class ExtensionModule(enum.Enum): - MODERATION = "moderation" - EXTERNAL_DATA_TOOL = "external_data_tool" +class ExtensionModule(StrEnum): + MODERATION = auto() + EXTERNAL_DATA_TOOL = auto() class ModuleExtension(BaseModel): diff --git a/api/core/file/constants.py b/api/core/file/constants.py index ed1779fd13..0665ed7e0d 100644 --- a/api/core/file/constants.py +++ b/api/core/file/constants.py @@ -9,7 +9,3 @@ FILE_MODEL_IDENTITY = "__dify__file__" def maybe_file_object(o: Any) -> bool: return isinstance(o, dict) and o.get("dify_model_identity") == FILE_MODEL_IDENTITY - - -# The default user ID for service API calls. -DEFAULT_SERVICE_API_USER_ID = "DEFAULT-USER" diff --git a/api/core/file/file_manager.py b/api/core/file/file_manager.py index e3fd175d95..2a5f6c3dc7 100644 --- a/api/core/file/file_manager.py +++ b/api/core/file/file_manager.py @@ -98,7 +98,7 @@ def to_prompt_message_content( def download(f: File, /): if f.transfer_method in (FileTransferMethod.TOOL_FILE, FileTransferMethod.LOCAL_FILE): - return _download_file_content(f._storage_key) + return _download_file_content(f.storage_key) elif f.transfer_method == FileTransferMethod.REMOTE_URL: response = ssrf_proxy.get(f.remote_url, follow_redirects=True) response.raise_for_status() @@ -134,9 +134,9 @@ def _get_encoded_string(f: File, /): response.raise_for_status() data = response.content case FileTransferMethod.LOCAL_FILE: - data = _download_file_content(f._storage_key) + data = _download_file_content(f.storage_key) case FileTransferMethod.TOOL_FILE: - data = _download_file_content(f._storage_key) + data = _download_file_content(f.storage_key) encoded_string = base64.b64encode(data).decode("utf-8") return encoded_string diff --git a/api/core/file/helpers.py b/api/core/file/helpers.py index 3ec29fe23d..bf06dbd1ec 100644 --- a/api/core/file/helpers.py +++ b/api/core/file/helpers.py @@ -5,7 +5,6 @@ import os import time from configs import dify_config -from core.file.constants import DEFAULT_SERVICE_API_USER_ID def get_signed_file_url(upload_file_id: str) -> str: @@ -25,10 +24,6 @@ def get_signed_file_url_for_plugin(filename: str, mimetype: str, tenant_id: str, # Plugin access should use internal URL for Docker network communication base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL url = f"{base_url}/files/upload/for-plugin" - - if user_id is None: - user_id = DEFAULT_SERVICE_API_USER_ID - timestamp = str(int(time.time())) nonce = os.urandom(16).hex() key = dify_config.SECRET_KEY.encode() @@ -40,11 +35,8 @@ def get_signed_file_url_for_plugin(filename: str, mimetype: str, tenant_id: str, def verify_plugin_file_signature( - *, filename: str, mimetype: str, tenant_id: str, user_id: str | None, timestamp: str, nonce: str, sign: str + *, filename: str, mimetype: str, tenant_id: str, user_id: str, timestamp: str, nonce: str, sign: str ) -> bool: - if user_id is None: - user_id = DEFAULT_SERVICE_API_USER_ID - data_to_sign = f"upload|{filename}|{mimetype}|{tenant_id}|{user_id}|{timestamp}|{nonce}" secret_key = dify_config.SECRET_KEY.encode() recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() diff --git a/api/core/file/models.py b/api/core/file/models.py index f61334e7bc..9b74fa387f 100644 --- a/api/core/file/models.py +++ b/api/core/file/models.py @@ -146,3 +146,11 @@ class File(BaseModel): if not self.related_id: raise ValueError("Missing file related_id") return self + + @property + def storage_key(self) -> str: + return self._storage_key + + @storage_key.setter + def storage_key(self, value: str): + self._storage_key = value diff --git a/api/core/helper/model_provider_cache.py b/api/core/helper/model_provider_cache.py index 1c112007cb..2f160628ca 100644 --- a/api/core/helper/model_provider_cache.py +++ b/api/core/helper/model_provider_cache.py @@ -1,12 +1,12 @@ import json -from enum import Enum +from enum import StrEnum from json import JSONDecodeError from typing import Optional from extensions.ext_redis import redis_client -class ProviderCredentialsCacheType(Enum): +class ProviderCredentialsCacheType(StrEnum): PROVIDER = "provider" MODEL = "provider_model" LOAD_BALANCING_MODEL = "load_balancing_provider_model" @@ -14,7 +14,7 @@ class ProviderCredentialsCacheType(Enum): class ProviderCredentialsCache: def __init__(self, tenant_id: str, identity_id: str, cache_type: ProviderCredentialsCacheType): - self.cache_key = f"{cache_type.value}_credentials:tenant_id:{tenant_id}:id:{identity_id}" + self.cache_key = f"{cache_type}_credentials:tenant_id:{tenant_id}:id:{identity_id}" def get(self) -> Optional[dict]: """ diff --git a/api/core/helper/position_helper.py b/api/core/helper/position_helper.py index 314f052832..2fc8fbf885 100644 --- a/api/core/helper/position_helper.py +++ b/api/core/helper/position_helper.py @@ -1,12 +1,14 @@ import os from collections import OrderedDict from collections.abc import Callable +from functools import lru_cache from typing import TypeVar from configs import dify_config -from core.tools.utils.yaml_utils import load_yaml_file +from core.tools.utils.yaml_utils import load_yaml_file_cached +@lru_cache(maxsize=128) def get_position_map(folder_path: str, *, file_name: str = "_position.yaml") -> dict[str, int]: """ Get the mapping from name to index from a YAML file @@ -14,12 +16,17 @@ def get_position_map(folder_path: str, *, file_name: str = "_position.yaml") -> :param file_name: the YAML file name, default to '_position.yaml' :return: a dict with name as key and index as value """ + # FIXME(-LAN-): Cache position maps to prevent file descriptor exhaustion during high-load benchmarks position_file_path = os.path.join(folder_path, file_name) - yaml_content = load_yaml_file(file_path=position_file_path, default_value=[]) + try: + yaml_content = load_yaml_file_cached(file_path=position_file_path) + except Exception: + yaml_content = [] positions = [item.strip() for item in yaml_content if item and isinstance(item, str) and item.strip()] return {name: index for index, name in enumerate(positions)} +@lru_cache(maxsize=128) def get_tool_position_map(folder_path: str, file_name: str = "_position.yaml") -> dict[str, int]: """ Get the mapping for tools from name to index from a YAML file. @@ -35,20 +42,6 @@ def get_tool_position_map(folder_path: str, file_name: str = "_position.yaml") - ) -def get_provider_position_map(folder_path: str, file_name: str = "_position.yaml") -> dict[str, int]: - """ - Get the mapping for providers from name to index from a YAML file. - :param folder_path: - :param file_name: the YAML file name, default to '_position.yaml' - :return: a dict with name as key and index as value - """ - position_map = get_position_map(folder_path, file_name=file_name) - return pin_position_map( - position_map, - pin_list=dify_config.POSITION_PROVIDER_PINS_LIST, - ) - - def pin_position_map(original_position_map: dict[str, int], pin_list: list[str]) -> dict[str, int]: """ Pin the items in the pin list to the beginning of the position map. diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py index efeba9e5ee..cbb78939d2 100644 --- a/api/core/helper/ssrf_proxy.py +++ b/api/core/helper/ssrf_proxy.py @@ -13,18 +13,18 @@ logger = logging.getLogger(__name__) SSRF_DEFAULT_MAX_RETRIES = dify_config.SSRF_DEFAULT_MAX_RETRIES -HTTP_REQUEST_NODE_SSL_VERIFY = True # Default value for HTTP_REQUEST_NODE_SSL_VERIFY is True +http_request_node_ssl_verify = True # Default value for http_request_node_ssl_verify is True try: - HTTP_REQUEST_NODE_SSL_VERIFY = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY - http_request_node_ssl_verify_lower = str(HTTP_REQUEST_NODE_SSL_VERIFY).lower() + config_value = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY + http_request_node_ssl_verify_lower = str(config_value).lower() if http_request_node_ssl_verify_lower == "true": - HTTP_REQUEST_NODE_SSL_VERIFY = True + http_request_node_ssl_verify = True elif http_request_node_ssl_verify_lower == "false": - HTTP_REQUEST_NODE_SSL_VERIFY = False + http_request_node_ssl_verify = False else: raise ValueError("Invalid value. HTTP_REQUEST_NODE_SSL_VERIFY should be 'True' or 'False'") except NameError: - HTTP_REQUEST_NODE_SSL_VERIFY = True + http_request_node_ssl_verify = True BACKOFF_FACTOR = 0.5 STATUS_FORCELIST = [429, 500, 502, 503, 504] @@ -51,7 +51,7 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): ) if "ssl_verify" not in kwargs: - kwargs["ssl_verify"] = HTTP_REQUEST_NODE_SSL_VERIFY + kwargs["ssl_verify"] = http_request_node_ssl_verify ssl_verify = kwargs.pop("ssl_verify") diff --git a/api/core/helper/tool_parameter_cache.py b/api/core/helper/tool_parameter_cache.py index 95a1086ca8..c2cc2e4db0 100644 --- a/api/core/helper/tool_parameter_cache.py +++ b/api/core/helper/tool_parameter_cache.py @@ -1,12 +1,12 @@ import json -from enum import Enum +from enum import StrEnum from json import JSONDecodeError from typing import Optional from extensions.ext_redis import redis_client -class ToolParameterCacheType(Enum): +class ToolParameterCacheType(StrEnum): PARAMETER = "tool_parameter" @@ -15,7 +15,7 @@ class ToolParameterCache: self, tenant_id: str, provider: str, tool_name: str, cache_type: ToolParameterCacheType, identity_id: str ): self.cache_key = ( - f"{cache_type.value}_secret:tenant_id:{tenant_id}:provider:{provider}:tool_name:{tool_name}" + f"{cache_type}_secret:tenant_id:{tenant_id}:provider:{provider}:tool_name:{tool_name}" f":identity_id:{identity_id}" ) diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index 89a05e02c8..ed02b70b03 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -529,6 +529,7 @@ class IndexingRunner: # chunk nodes by chunk size indexing_start_at = time.perf_counter() tokens = 0 + create_keyword_thread = None if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX and dataset.indexing_technique == "economy": # create keyword index create_keyword_thread = threading.Thread( @@ -567,7 +568,11 @@ class IndexingRunner: for future in futures: tokens += future.result() - if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX and dataset.indexing_technique == "economy": + if ( + dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX + and dataset.indexing_technique == "economy" + and create_keyword_thread is not None + ): create_keyword_thread.join() indexing_end_at = time.perf_counter() diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py index 94b8258e9c..d4c4f10a12 100644 --- a/api/core/llm_generator/llm_generator.py +++ b/api/core/llm_generator/llm_generator.py @@ -20,7 +20,7 @@ from core.llm_generator.prompts import ( ) from core.model_manager import ModelManager from core.model_runtime.entities.llm_entities import LLMResult -from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage +from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.ops.entities.trace_entity import TraceTaskName @@ -313,14 +313,20 @@ class LLMGenerator: model_type=ModelType.LLM, ) - prompt_messages = [SystemPromptMessage(content=prompt), UserPromptMessage(content=query)] + prompt_messages: list[PromptMessage] = [SystemPromptMessage(content=prompt), UserPromptMessage(content=query)] - response: LLMResult = model_instance.invoke_llm( + # Explicitly use the non-streaming overload + result = model_instance.invoke_llm( prompt_messages=prompt_messages, model_parameters={"temperature": 0.01, "max_tokens": 2000}, stream=False, ) + # Runtime type check since pyright has issues with the overload + if not isinstance(result, LLMResult): + raise TypeError("Expected LLMResult when stream=False") + response = result + answer = cast(str, response.message.content) return answer.strip() diff --git a/api/core/llm_generator/output_parser/structured_output.py b/api/core/llm_generator/output_parser/structured_output.py index 28833fe8e8..e0b70c132f 100644 --- a/api/core/llm_generator/output_parser/structured_output.py +++ b/api/core/llm_generator/output_parser/structured_output.py @@ -45,6 +45,7 @@ class SpecialModelType(StrEnum): @overload def invoke_llm_with_structured_output( + *, provider: str, model_schema: AIModelEntity, model_instance: ModelInstance, @@ -53,14 +54,13 @@ def invoke_llm_with_structured_output( model_parameters: Optional[Mapping] = None, tools: Sequence[PromptMessageTool] | None = None, stop: Optional[list[str]] = None, - stream: Literal[True] = True, + stream: Literal[True], user: Optional[str] = None, callbacks: Optional[list[Callback]] = None, ) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ... - - @overload def invoke_llm_with_structured_output( + *, provider: str, model_schema: AIModelEntity, model_instance: ModelInstance, @@ -69,14 +69,13 @@ def invoke_llm_with_structured_output( model_parameters: Optional[Mapping] = None, tools: Sequence[PromptMessageTool] | None = None, stop: Optional[list[str]] = None, - stream: Literal[False] = False, + stream: Literal[False], user: Optional[str] = None, callbacks: Optional[list[Callback]] = None, ) -> LLMResultWithStructuredOutput: ... - - @overload def invoke_llm_with_structured_output( + *, provider: str, model_schema: AIModelEntity, model_instance: ModelInstance, @@ -89,9 +88,8 @@ def invoke_llm_with_structured_output( user: Optional[str] = None, callbacks: Optional[list[Callback]] = None, ) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]: ... - - def invoke_llm_with_structured_output( + *, provider: str, model_schema: AIModelEntity, model_instance: ModelInstance, diff --git a/api/core/mcp/client/sse_client.py b/api/core/mcp/client/sse_client.py index cc4263c0aa..6db22a09e0 100644 --- a/api/core/mcp/client/sse_client.py +++ b/api/core/mcp/client/sse_client.py @@ -23,13 +23,13 @@ DEFAULT_QUEUE_READ_TIMEOUT = 3 @final class _StatusReady: def __init__(self, endpoint_url: str): - self._endpoint_url = endpoint_url + self.endpoint_url = endpoint_url @final class _StatusError: def __init__(self, exc: Exception): - self._exc = exc + self.exc = exc # Type aliases for better readability @@ -211,9 +211,9 @@ class SSETransport: raise ValueError("failed to get endpoint URL") if isinstance(status, _StatusReady): - return status._endpoint_url + return status.endpoint_url elif isinstance(status, _StatusError): - raise status._exc + raise status.exc else: raise ValueError("failed to get endpoint URL") diff --git a/api/core/mcp/server/streamable_http.py b/api/core/mcp/server/streamable_http.py index 3d51ac2333..212c2eb073 100644 --- a/api/core/mcp/server/streamable_http.py +++ b/api/core/mcp/server/streamable_http.py @@ -38,6 +38,7 @@ def handle_mcp_request( """ request_type = type(request.root) + request_root = request.root def create_success_response(result_data: mcp_types.Result) -> mcp_types.JSONRPCResponse: """Create success response with business result data""" @@ -58,21 +59,20 @@ def handle_mcp_request( error=error_data, ) - # Request handler mapping using functional approach - request_handlers = { - mcp_types.InitializeRequest: lambda: handle_initialize(mcp_server.description), - mcp_types.ListToolsRequest: lambda: handle_list_tools( - app.name, app.mode, user_input_form, mcp_server.description, mcp_server.parameters_dict - ), - mcp_types.CallToolRequest: lambda: handle_call_tool(app, request, user_input_form, end_user), - mcp_types.PingRequest: lambda: handle_ping(), - } - try: - # Dispatch request to appropriate handler - handler = request_handlers.get(request_type) - if handler: - return create_success_response(handler()) + # Dispatch request to appropriate handler based on instance type + if isinstance(request_root, mcp_types.InitializeRequest): + return create_success_response(handle_initialize(mcp_server.description)) + elif isinstance(request_root, mcp_types.ListToolsRequest): + return create_success_response( + handle_list_tools( + app.name, app.mode, user_input_form, mcp_server.description, mcp_server.parameters_dict + ) + ) + elif isinstance(request_root, mcp_types.CallToolRequest): + return create_success_response(handle_call_tool(app, request, user_input_form, end_user)) + elif isinstance(request_root, mcp_types.PingRequest): + return create_success_response(handle_ping()) else: return create_error_response(mcp_types.METHOD_NOT_FOUND, f"Method not found: {request_type.__name__}") @@ -142,7 +142,7 @@ def handle_call_tool( end_user, args, InvokeFrom.SERVICE_API, - streaming=app.mode == AppMode.AGENT_CHAT.value, + streaming=app.mode == AppMode.AGENT_CHAT, ) answer = extract_answer_from_response(app, response) @@ -157,7 +157,7 @@ def build_parameter_schema( """Build parameter schema for the tool""" parameters, required = convert_input_form_to_parameters(user_input_form, parameters_dict) - if app_mode in {AppMode.COMPLETION.value, AppMode.WORKFLOW.value}: + if app_mode in {AppMode.COMPLETION, AppMode.WORKFLOW}: return { "type": "object", "properties": parameters, @@ -175,9 +175,9 @@ def build_parameter_schema( def prepare_tool_arguments(app: App, arguments: dict[str, Any]) -> dict[str, Any]: """Prepare arguments based on app mode""" - if app.mode == AppMode.WORKFLOW.value: + if app.mode == AppMode.WORKFLOW: return {"inputs": arguments} - elif app.mode == AppMode.COMPLETION.value: + elif app.mode == AppMode.COMPLETION: return {"query": "", "inputs": arguments} else: # Chat modes - create a copy to avoid modifying original dict @@ -218,13 +218,13 @@ def process_streaming_response(response: RateLimitGenerator) -> str: def process_mapping_response(app: App, response: Mapping) -> str: """Process mapping response based on app mode""" if app.mode in { - AppMode.ADVANCED_CHAT.value, - AppMode.COMPLETION.value, - AppMode.CHAT.value, - AppMode.AGENT_CHAT.value, + AppMode.ADVANCED_CHAT, + AppMode.COMPLETION, + AppMode.CHAT, + AppMode.AGENT_CHAT, }: return response.get("answer", "") - elif app.mode == AppMode.WORKFLOW.value: + elif app.mode == AppMode.WORKFLOW: return json.dumps(response["data"]["outputs"], ensure_ascii=False) else: raise ValueError("Invalid app mode: " + str(app.mode)) diff --git a/api/core/mcp/session/base_session.py b/api/core/mcp/session/base_session.py index 96c48034c7..fbad5576aa 100644 --- a/api/core/mcp/session/base_session.py +++ b/api/core/mcp/session/base_session.py @@ -81,7 +81,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]): self.request_meta = request_meta self.request = request self._session = session - self._completed = False + self.completed = False self._on_complete = on_complete self._entered = False # Track if we're in a context manager @@ -98,7 +98,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]): ): """Exit the context manager, performing cleanup and notifying completion.""" try: - if self._completed: + if self.completed: self._on_complete(self) finally: self._entered = False @@ -113,9 +113,9 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]): """ if not self._entered: raise RuntimeError("RequestResponder must be used as a context manager") - assert not self._completed, "Request already responded to" + assert not self.completed, "Request already responded to" - self._completed = True + self.completed = True self._session._send_response(request_id=self.request_id, response=response) @@ -124,7 +124,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]): if not self._entered: raise RuntimeError("RequestResponder must be used as a context manager") - self._completed = True # Mark as completed so it's removed from in_flight + self.completed = True # Mark as completed so it's removed from in_flight # Send an error response to indicate cancellation self._session._send_response( request_id=self.request_id, @@ -351,7 +351,7 @@ class BaseSession( self._in_flight[responder.request_id] = responder self._received_request(responder) - if not responder._completed: + if not responder.completed: self._handle_incoming(responder) elif isinstance(message.message.root, JSONRPCNotification): diff --git a/api/core/mcp/types.py b/api/core/mcp/types.py index 49aa8e4498..a2c3157b3b 100644 --- a/api/core/mcp/types.py +++ b/api/core/mcp/types.py @@ -809,7 +809,7 @@ class LoggingMessageNotificationParams(NotificationParams): """The severity of this log message.""" logger: str | None = None """An optional name of the logger issuing this message.""" - data: Any + data: Any = None """ The data to be logged, such as a string message or an object. Any JSON serializable type is allowed here. diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py index 7be695812a..1f2525cfed 100644 --- a/api/core/memory/token_buffer_memory.py +++ b/api/core/memory/token_buffer_memory.py @@ -32,11 +32,16 @@ class TokenBufferMemory: self.model_instance = model_instance def _build_prompt_message_with_files( - self, message_files: list[MessageFile], text_content: str, message: Message, app_record, is_user_message: bool + self, + message_files: Sequence[MessageFile], + text_content: str, + message: Message, + app_record, + is_user_message: bool, ) -> PromptMessage: """ Build prompt message with files. - :param message_files: list of MessageFile objects + :param message_files: Sequence of MessageFile objects :param text_content: text content of the message :param message: Message object :param app_record: app record @@ -128,14 +133,12 @@ class TokenBufferMemory: prompt_messages: list[PromptMessage] = [] for message in messages: # Process user message with files - user_files = ( - db.session.query(MessageFile) - .where( + user_files = db.session.scalars( + select(MessageFile).where( MessageFile.message_id == message.id, (MessageFile.belongs_to == "user") | (MessageFile.belongs_to.is_(None)), ) - .all() - ) + ).all() if user_files: user_prompt_message = self._build_prompt_message_with_files( @@ -150,11 +153,9 @@ class TokenBufferMemory: prompt_messages.append(UserPromptMessage(content=message.query)) # Process assistant message with files - assistant_files = ( - db.session.query(MessageFile) - .where(MessageFile.message_id == message.id, MessageFile.belongs_to == "assistant") - .all() - ) + assistant_files = db.session.scalars( + select(MessageFile).where(MessageFile.message_id == message.id, MessageFile.belongs_to == "assistant") + ).all() if assistant_files: assistant_prompt_message = self._build_prompt_message_with_files( diff --git a/api/core/model_runtime/entities/message_entities.py b/api/core/model_runtime/entities/message_entities.py index 7cd2e6a3d1..75ffe7c32f 100644 --- a/api/core/model_runtime/entities/message_entities.py +++ b/api/core/model_runtime/entities/message_entities.py @@ -1,20 +1,20 @@ from abc import ABC from collections.abc import Mapping, Sequence -from enum import Enum, StrEnum +from enum import StrEnum, auto from typing import Annotated, Any, Literal, Optional, Union from pydantic import BaseModel, Field, field_serializer, field_validator -class PromptMessageRole(Enum): +class PromptMessageRole(StrEnum): """ Enum class for prompt message. """ - SYSTEM = "system" - USER = "user" - ASSISTANT = "assistant" - TOOL = "tool" + SYSTEM = auto() + USER = auto() + ASSISTANT = auto() + TOOL = auto() @classmethod def value_of(cls, value: str) -> "PromptMessageRole": @@ -54,11 +54,11 @@ class PromptMessageContentType(StrEnum): Enum class for prompt message content type. """ - TEXT = "text" - IMAGE = "image" - AUDIO = "audio" - VIDEO = "video" - DOCUMENT = "document" + TEXT = auto() + IMAGE = auto() + AUDIO = auto() + VIDEO = auto() + DOCUMENT = auto() class PromptMessageContent(ABC, BaseModel): @@ -108,8 +108,8 @@ class ImagePromptMessageContent(MultiModalPromptMessageContent): """ class DETAIL(StrEnum): - LOW = "low" - HIGH = "high" + LOW = auto() + HIGH = auto() type: Literal[PromptMessageContentType.IMAGE] = PromptMessageContentType.IMAGE detail: DETAIL = DETAIL.LOW diff --git a/api/core/model_runtime/entities/model_entities.py b/api/core/model_runtime/entities/model_entities.py index 568149cc37..8259335c50 100644 --- a/api/core/model_runtime/entities/model_entities.py +++ b/api/core/model_runtime/entities/model_entities.py @@ -1,5 +1,5 @@ from decimal import Decimal -from enum import Enum, StrEnum +from enum import StrEnum, auto from typing import Any, Optional from pydantic import BaseModel, ConfigDict, model_validator @@ -7,17 +7,17 @@ from pydantic import BaseModel, ConfigDict, model_validator from core.model_runtime.entities.common_entities import I18nObject -class ModelType(Enum): +class ModelType(StrEnum): """ Enum class for model type. """ - LLM = "llm" + LLM = auto() TEXT_EMBEDDING = "text-embedding" - RERANK = "rerank" - SPEECH2TEXT = "speech2text" - MODERATION = "moderation" - TTS = "tts" + RERANK = auto() + SPEECH2TEXT = auto() + MODERATION = auto() + TTS = auto() @classmethod def value_of(cls, origin_model_type: str) -> "ModelType": @@ -26,17 +26,17 @@ class ModelType(Enum): :return: model type """ - if origin_model_type in {"text-generation", cls.LLM.value}: + if origin_model_type in {"text-generation", cls.LLM}: return cls.LLM - elif origin_model_type in {"embeddings", cls.TEXT_EMBEDDING.value}: + elif origin_model_type in {"embeddings", cls.TEXT_EMBEDDING}: return cls.TEXT_EMBEDDING - elif origin_model_type in {"reranking", cls.RERANK.value}: + elif origin_model_type in {"reranking", cls.RERANK}: return cls.RERANK - elif origin_model_type in {"speech2text", cls.SPEECH2TEXT.value}: + elif origin_model_type in {"speech2text", cls.SPEECH2TEXT}: return cls.SPEECH2TEXT - elif origin_model_type in {"tts", cls.TTS.value}: + elif origin_model_type in {"tts", cls.TTS}: return cls.TTS - elif origin_model_type == cls.MODERATION.value: + elif origin_model_type == cls.MODERATION: return cls.MODERATION else: raise ValueError(f"invalid origin model type {origin_model_type}") @@ -63,7 +63,7 @@ class ModelType(Enum): raise ValueError(f"invalid model type {self}") -class FetchFrom(Enum): +class FetchFrom(StrEnum): """ Enum class for fetch from. """ @@ -72,7 +72,7 @@ class FetchFrom(Enum): CUSTOMIZABLE_MODEL = "customizable-model" -class ModelFeature(Enum): +class ModelFeature(StrEnum): """ Enum class for llm feature. """ @@ -80,11 +80,11 @@ class ModelFeature(Enum): TOOL_CALL = "tool-call" MULTI_TOOL_CALL = "multi-tool-call" AGENT_THOUGHT = "agent-thought" - VISION = "vision" + VISION = auto() STREAM_TOOL_CALL = "stream-tool-call" - DOCUMENT = "document" - VIDEO = "video" - AUDIO = "audio" + DOCUMENT = auto() + VIDEO = auto() + AUDIO = auto() STRUCTURED_OUTPUT = "structured-output" @@ -93,14 +93,14 @@ class DefaultParameterName(StrEnum): Enum class for parameter template variable. """ - TEMPERATURE = "temperature" - TOP_P = "top_p" - TOP_K = "top_k" - PRESENCE_PENALTY = "presence_penalty" - FREQUENCY_PENALTY = "frequency_penalty" - MAX_TOKENS = "max_tokens" - RESPONSE_FORMAT = "response_format" - JSON_SCHEMA = "json_schema" + TEMPERATURE = auto() + TOP_P = auto() + TOP_K = auto() + PRESENCE_PENALTY = auto() + FREQUENCY_PENALTY = auto() + MAX_TOKENS = auto() + RESPONSE_FORMAT = auto() + JSON_SCHEMA = auto() @classmethod def value_of(cls, value: Any) -> "DefaultParameterName": @@ -116,34 +116,34 @@ class DefaultParameterName(StrEnum): raise ValueError(f"invalid parameter name {value}") -class ParameterType(Enum): +class ParameterType(StrEnum): """ Enum class for parameter type. """ - FLOAT = "float" - INT = "int" - STRING = "string" - BOOLEAN = "boolean" - TEXT = "text" + FLOAT = auto() + INT = auto() + STRING = auto() + BOOLEAN = auto() + TEXT = auto() -class ModelPropertyKey(Enum): +class ModelPropertyKey(StrEnum): """ Enum class for model property key. """ - MODE = "mode" - CONTEXT_SIZE = "context_size" - MAX_CHUNKS = "max_chunks" - FILE_UPLOAD_LIMIT = "file_upload_limit" - SUPPORTED_FILE_EXTENSIONS = "supported_file_extensions" - MAX_CHARACTERS_PER_CHUNK = "max_characters_per_chunk" - DEFAULT_VOICE = "default_voice" - VOICES = "voices" - WORD_LIMIT = "word_limit" - AUDIO_TYPE = "audio_type" - MAX_WORKERS = "max_workers" + MODE = auto() + CONTEXT_SIZE = auto() + MAX_CHUNKS = auto() + FILE_UPLOAD_LIMIT = auto() + SUPPORTED_FILE_EXTENSIONS = auto() + MAX_CHARACTERS_PER_CHUNK = auto() + DEFAULT_VOICE = auto() + VOICES = auto() + WORD_LIMIT = auto() + AUDIO_TYPE = auto() + MAX_WORKERS = auto() class ProviderModel(BaseModel): @@ -220,13 +220,13 @@ class ModelUsage(BaseModel): pass -class PriceType(Enum): +class PriceType(StrEnum): """ Enum class for price type. """ - INPUT = "input" - OUTPUT = "output" + INPUT = auto() + OUTPUT = auto() class PriceInfo(BaseModel): diff --git a/api/core/model_runtime/entities/provider_entities.py b/api/core/model_runtime/entities/provider_entities.py index c9aa8d1474..451c2359b3 100644 --- a/api/core/model_runtime/entities/provider_entities.py +++ b/api/core/model_runtime/entities/provider_entities.py @@ -1,5 +1,5 @@ from collections.abc import Sequence -from enum import Enum +from enum import Enum, StrEnum, auto from typing import Optional from pydantic import BaseModel, ConfigDict, Field, field_validator @@ -17,16 +17,16 @@ class ConfigurateMethod(Enum): CUSTOMIZABLE_MODEL = "customizable-model" -class FormType(Enum): +class FormType(StrEnum): """ Enum class for form type. """ TEXT_INPUT = "text-input" SECRET_INPUT = "secret-input" - SELECT = "select" - RADIO = "radio" - SWITCH = "switch" + SELECT = auto() + RADIO = auto() + SWITCH = auto() class FormShowOnObject(BaseModel): diff --git a/api/core/model_runtime/model_providers/__base/large_language_model.py b/api/core/model_runtime/model_providers/__base/large_language_model.py index 24b206fdbe..1d7fd7d447 100644 --- a/api/core/model_runtime/model_providers/__base/large_language_model.py +++ b/api/core/model_runtime/model_providers/__base/large_language_model.py @@ -354,7 +354,7 @@ class LargeLanguageModel(AIModel): ) return 0 - def _calc_response_usage( + def calc_response_usage( self, model: str, credentials: dict, prompt_tokens: int, completion_tokens: int ) -> LLMUsage: """ diff --git a/api/core/model_runtime/model_providers/__base/text_embedding_model.py b/api/core/model_runtime/model_providers/__base/text_embedding_model.py index f7bba0eba1..3ce438955a 100644 --- a/api/core/model_runtime/model_providers/__base/text_embedding_model.py +++ b/api/core/model_runtime/model_providers/__base/text_embedding_model.py @@ -47,7 +47,7 @@ class TextEmbeddingModel(AIModel): model=model, credentials=credentials, texts=texts, - input_type=input_type.value, + input_type=input_type, ) except Exception as e: raise self._transform_invoke_error(e) diff --git a/api/core/model_runtime/model_providers/model_provider_factory.py b/api/core/model_runtime/model_providers/model_provider_factory.py index 6502b920f5..35c1dfb144 100644 --- a/api/core/model_runtime/model_providers/model_provider_factory.py +++ b/api/core/model_runtime/model_providers/model_provider_factory.py @@ -1,14 +1,10 @@ import hashlib import logging -import os from collections.abc import Sequence from threading import Lock from typing import Optional -from pydantic import BaseModel - import contexts -from core.helper.position_helper import get_provider_position_map, sort_to_dict_by_position_map from core.model_runtime.entities.model_entities import AIModelEntity, ModelType from core.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity from core.model_runtime.model_providers.__base.ai_model import AIModel @@ -28,48 +24,20 @@ from core.plugin.impl.model import PluginModelClient logger = logging.getLogger(__name__) -class ModelProviderExtension(BaseModel): - plugin_model_provider_entity: PluginModelProviderEntity - position: Optional[int] = None - - class ModelProviderFactory: - provider_position_map: dict[str, int] - def __init__(self, tenant_id: str): - self.provider_position_map = {} - self.tenant_id = tenant_id self.plugin_model_manager = PluginModelClient() - if not self.provider_position_map: - # get the path of current classes - current_path = os.path.abspath(__file__) - model_providers_path = os.path.dirname(current_path) - - # get _position.yaml file path - self.provider_position_map = get_provider_position_map(model_providers_path) - def get_providers(self) -> Sequence[ProviderEntity]: """ Get all providers :return: list of providers """ - # Fetch plugin model providers + # FIXME(-LAN-): Removed position map sorting since providers are fetched from plugin server + # The plugin server should return providers in the desired order plugin_providers = self.get_plugin_model_providers() - - # Convert PluginModelProviderEntity to ModelProviderExtension - model_provider_extensions = [] - for provider in plugin_providers: - model_provider_extensions.append(ModelProviderExtension(plugin_model_provider_entity=provider)) - - sorted_extensions = sort_to_dict_by_position_map( - position_map=self.provider_position_map, - data=model_provider_extensions, - name_func=lambda x: x.plugin_model_provider_entity.declaration.provider, - ) - - return [extension.plugin_model_provider_entity.declaration for extension in sorted_extensions.values()] + return [provider.declaration for provider in plugin_providers] def get_plugin_model_providers(self) -> Sequence[PluginModelProviderEntity]: """ diff --git a/api/core/model_runtime/utils/encoders.py b/api/core/model_runtime/utils/encoders.py index 962e417671..f65339fbfc 100644 --- a/api/core/model_runtime/utils/encoders.py +++ b/api/core/model_runtime/utils/encoders.py @@ -18,7 +18,7 @@ from pydantic_core import Url from pydantic_extra_types.color import Color -def _model_dump(model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any): +def _model_dump(model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any) -> Any: return model.model_dump(mode=mode, **kwargs) @@ -100,7 +100,7 @@ def jsonable_encoder( exclude_none: bool = False, custom_encoder: Optional[dict[Any, Callable[[Any], Any]]] = None, sqlalchemy_safe: bool = True, -): +) -> Any: custom_encoder = custom_encoder or {} if custom_encoder: if type(obj) in custom_encoder: diff --git a/api/core/moderation/base.py b/api/core/moderation/base.py index 752617b654..340483c894 100644 --- a/api/core/moderation/base.py +++ b/api/core/moderation/base.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from enum import Enum +from enum import StrEnum, auto from typing import Optional from pydantic import BaseModel, Field @@ -7,9 +7,9 @@ from pydantic import BaseModel, Field from core.extension.extensible import Extensible, ExtensionModule -class ModerationAction(Enum): - DIRECT_OUTPUT = "direct_output" - OVERRIDDEN = "overridden" +class ModerationAction(StrEnum): + DIRECT_OUTPUT = auto() + OVERRIDDEN = auto() class ModerationInputsResult(BaseModel): diff --git a/api/core/ops/aliyun_trace/entities/semconv.py b/api/core/ops/aliyun_trace/entities/semconv.py index 5d70264320..c9427c776a 100644 --- a/api/core/ops/aliyun_trace/entities/semconv.py +++ b/api/core/ops/aliyun_trace/entities/semconv.py @@ -1,4 +1,4 @@ -from enum import Enum +from enum import StrEnum # public GEN_AI_SESSION_ID = "gen_ai.session.id" @@ -53,7 +53,7 @@ TOOL_DESCRIPTION = "tool.description" TOOL_PARAMETERS = "tool.parameters" -class GenAISpanKind(Enum): +class GenAISpanKind(StrEnum): CHAIN = "CHAIN" RETRIEVER = "RETRIEVER" RERANKER = "RERANKER" diff --git a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py index e7c90c1229..c5fbe4d78b 100644 --- a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py +++ b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py @@ -15,6 +15,7 @@ from opentelemetry.sdk.resources import Resource from opentelemetry.sdk.trace.export import SimpleSpanProcessor from opentelemetry.sdk.trace.id_generator import RandomIdGenerator from opentelemetry.trace import SpanContext, TraceFlags, TraceState +from sqlalchemy import select from core.ops.base_trace_instance import BaseTraceInstance from core.ops.entities.config_entity import ArizeConfig, PhoenixConfig @@ -699,8 +700,8 @@ class ArizePhoenixDataTrace(BaseTraceInstance): def _get_workflow_nodes(self, workflow_run_id: str): """Helper method to get workflow nodes""" - workflow_nodes = ( - db.session.query( + workflow_nodes = db.session.scalars( + select( WorkflowNodeExecutionModel.id, WorkflowNodeExecutionModel.tenant_id, WorkflowNodeExecutionModel.app_id, @@ -713,10 +714,8 @@ class ArizePhoenixDataTrace(BaseTraceInstance): WorkflowNodeExecutionModel.elapsed_time, WorkflowNodeExecutionModel.process_data, WorkflowNodeExecutionModel.execution_metadata, - ) - .where(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id) - .all() - ) + ).where(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id) + ).all() return workflow_nodes def _construct_llm_attributes(self, prompts: dict | list | str | None) -> dict[str, str]: diff --git a/api/core/ops/entities/trace_entity.py b/api/core/ops/entities/trace_entity.py index 3bad5c92fb..1870da3781 100644 --- a/api/core/ops/entities/trace_entity.py +++ b/api/core/ops/entities/trace_entity.py @@ -35,7 +35,7 @@ class BaseTraceInfo(BaseModel): class WorkflowTraceInfo(BaseTraceInfo): - workflow_data: Any + workflow_data: Any = None conversation_id: Optional[str] = None workflow_app_log_id: Optional[str] = None workflow_id: str @@ -89,7 +89,7 @@ class SuggestedQuestionTraceInfo(BaseTraceInfo): class DatasetRetrievalTraceInfo(BaseTraceInfo): - documents: Any + documents: Any = None class ToolTraceInfo(BaseTraceInfo): @@ -97,12 +97,12 @@ class ToolTraceInfo(BaseTraceInfo): tool_inputs: dict[str, Any] tool_outputs: str metadata: dict[str, Any] - message_file_data: Any + message_file_data: Any = None error: Optional[str] = None tool_config: dict[str, Any] time_cost: Union[int, float] tool_parameters: dict[str, Any] - file_url: Union[str, None, list] + file_url: Union[str, None, list] = None class GenerateNameTraceInfo(BaseTraceInfo): @@ -113,7 +113,7 @@ class GenerateNameTraceInfo(BaseTraceInfo): class TaskData(BaseModel): app_id: str trace_info_type: str - trace_info: Any + trace_info: Any = None trace_info_info_map = { diff --git a/api/core/plugin/backwards_invocation/app.py b/api/core/plugin/backwards_invocation/app.py index 48f44da68e..86cc3839f2 100644 --- a/api/core/plugin/backwards_invocation/app.py +++ b/api/core/plugin/backwards_invocation/app.py @@ -27,7 +27,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): app = cls._get_app(app_id, tenant_id) """Retrieve app parameters.""" - if app.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: + if app.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: workflow = app.workflow if workflow is None: raise ValueError("unexpected app type") @@ -70,7 +70,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): conversation_id = conversation_id or "" - if app.mode in {AppMode.ADVANCED_CHAT.value, AppMode.AGENT_CHAT.value, AppMode.CHAT.value}: + if app.mode in {AppMode.ADVANCED_CHAT, AppMode.AGENT_CHAT, AppMode.CHAT}: if not query: raise ValueError("missing query") @@ -96,7 +96,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): """ invoke chat app """ - if app.mode == AppMode.ADVANCED_CHAT.value: + if app.mode == AppMode.ADVANCED_CHAT: workflow = app.workflow if not workflow: raise ValueError("unexpected app type") @@ -114,7 +114,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): invoke_from=InvokeFrom.SERVICE_API, streaming=stream, ) - elif app.mode == AppMode.AGENT_CHAT.value: + elif app.mode == AppMode.AGENT_CHAT: return AgentChatAppGenerator().generate( app_model=app, user=user, @@ -127,7 +127,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): invoke_from=InvokeFrom.SERVICE_API, streaming=stream, ) - elif app.mode == AppMode.CHAT.value: + elif app.mode == AppMode.CHAT: return ChatAppGenerator().generate( app_model=app, user=user, diff --git a/api/core/plugin/entities/parameters.py b/api/core/plugin/entities/parameters.py index 47290ee613..01bb011ce7 100644 --- a/api/core/plugin/entities/parameters.py +++ b/api/core/plugin/entities/parameters.py @@ -1,4 +1,5 @@ -import enum +import json +from enum import StrEnum, auto from typing import Any, Optional, Union from pydantic import BaseModel, Field, field_validator @@ -24,44 +25,44 @@ class PluginParameterOption(BaseModel): return value -class PluginParameterType(enum.StrEnum): +class PluginParameterType(StrEnum): """ all available parameter types """ - STRING = CommonParameterType.STRING.value - NUMBER = CommonParameterType.NUMBER.value - BOOLEAN = CommonParameterType.BOOLEAN.value - SELECT = CommonParameterType.SELECT.value - SECRET_INPUT = CommonParameterType.SECRET_INPUT.value - FILE = CommonParameterType.FILE.value - FILES = CommonParameterType.FILES.value - APP_SELECTOR = CommonParameterType.APP_SELECTOR.value - MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value - TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR.value - ANY = CommonParameterType.ANY.value - DYNAMIC_SELECT = CommonParameterType.DYNAMIC_SELECT.value + STRING = CommonParameterType.STRING + NUMBER = CommonParameterType.NUMBER + BOOLEAN = CommonParameterType.BOOLEAN + SELECT = CommonParameterType.SELECT + SECRET_INPUT = CommonParameterType.SECRET_INPUT + FILE = CommonParameterType.FILE + FILES = CommonParameterType.FILES + APP_SELECTOR = CommonParameterType.APP_SELECTOR + MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR + TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR + ANY = CommonParameterType.ANY + DYNAMIC_SELECT = CommonParameterType.DYNAMIC_SELECT # deprecated, should not use. - SYSTEM_FILES = CommonParameterType.SYSTEM_FILES.value + SYSTEM_FILES = CommonParameterType.SYSTEM_FILES # MCP object and array type parameters - ARRAY = CommonParameterType.ARRAY.value - OBJECT = CommonParameterType.OBJECT.value + ARRAY = CommonParameterType.ARRAY + OBJECT = CommonParameterType.OBJECT -class MCPServerParameterType(enum.StrEnum): +class MCPServerParameterType(StrEnum): """ MCP server got complex parameter types """ - ARRAY = "array" - OBJECT = "object" + ARRAY = auto() + OBJECT = auto() class PluginParameterAutoGenerate(BaseModel): - class Type(enum.StrEnum): - PROMPT_INSTRUCTION = "prompt_instruction" + class Type(StrEnum): + PROMPT_INSTRUCTION = auto() type: Type @@ -92,7 +93,7 @@ class PluginParameter(BaseModel): return v -def as_normal_type(typ: enum.StrEnum): +def as_normal_type(typ: StrEnum): if typ.value in { PluginParameterType.SECRET_INPUT, PluginParameterType.SELECT, @@ -101,7 +102,7 @@ def as_normal_type(typ: enum.StrEnum): return typ.value -def cast_parameter_value(typ: enum.StrEnum, value: Any, /): +def cast_parameter_value(typ: StrEnum, value: Any, /): try: match typ.value: case PluginParameterType.STRING | PluginParameterType.SECRET_INPUT | PluginParameterType.SELECT: @@ -162,8 +163,6 @@ def cast_parameter_value(typ: enum.StrEnum, value: Any, /): # Try to parse JSON string for arrays if isinstance(value, str): try: - import json - parsed_value = json.loads(value) if isinstance(parsed_value, list): return parsed_value @@ -176,8 +175,6 @@ def cast_parameter_value(typ: enum.StrEnum, value: Any, /): # Try to parse JSON string for objects if isinstance(value, str): try: - import json - parsed_value = json.loads(value) if isinstance(parsed_value, dict): return parsed_value @@ -193,7 +190,7 @@ def cast_parameter_value(typ: enum.StrEnum, value: Any, /): raise ValueError(f"The tool parameter value {value} is not in correct type of {as_normal_type(typ)}.") -def init_frontend_parameter(rule: PluginParameter, type: enum.StrEnum, value: Any): +def init_frontend_parameter(rule: PluginParameter, type: StrEnum, value: Any): """ init frontend parameter by rule """ diff --git a/api/core/plugin/entities/plugin.py b/api/core/plugin/entities/plugin.py index a6369636e2..261d97c2b6 100644 --- a/api/core/plugin/entities/plugin.py +++ b/api/core/plugin/entities/plugin.py @@ -1,7 +1,7 @@ import datetime -import enum import re from collections.abc import Mapping +from enum import StrEnum, auto from typing import Any, Optional from packaging.version import InvalidVersion, Version @@ -16,11 +16,11 @@ from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolProviderEntity -class PluginInstallationSource(enum.StrEnum): - Github = "github" - Marketplace = "marketplace" - Package = "package" - Remote = "remote" +class PluginInstallationSource(StrEnum): + Github = auto() + Marketplace = auto() + Package = auto() + Remote = auto() class PluginResourceRequirements(BaseModel): @@ -58,10 +58,10 @@ class PluginResourceRequirements(BaseModel): permission: Optional[Permission] = Field(default=None) -class PluginCategory(enum.StrEnum): - Tool = "tool" - Model = "model" - Extension = "extension" +class PluginCategory(StrEnum): + Tool = auto() + Model = auto() + Extension = auto() AgentStrategy = "agent-strategy" @@ -206,10 +206,10 @@ class ToolProviderID(GenericProviderID): class PluginDependency(BaseModel): - class Type(enum.StrEnum): - Github = PluginInstallationSource.Github.value - Marketplace = PluginInstallationSource.Marketplace.value - Package = PluginInstallationSource.Package.value + class Type(StrEnum): + Github = PluginInstallationSource.Github + Marketplace = PluginInstallationSource.Marketplace + Package = PluginInstallationSource.Package class Github(BaseModel): repo: str diff --git a/api/core/plugin/entities/plugin_daemon.py b/api/core/plugin/entities/plugin_daemon.py index 16ab661092..f1d6860bb4 100644 --- a/api/core/plugin/entities/plugin_daemon.py +++ b/api/core/plugin/entities/plugin_daemon.py @@ -24,7 +24,7 @@ class PluginDaemonBasicResponse(BaseModel, Generic[T]): code: int message: str - data: Optional[T] + data: Optional[T] = None class InstallPluginMessage(BaseModel): diff --git a/api/core/plugin/utils/chunk_merger.py b/api/core/plugin/utils/chunk_merger.py index ec66ba02ee..e30076f9d3 100644 --- a/api/core/plugin/utils/chunk_merger.py +++ b/api/core/plugin/utils/chunk_merger.py @@ -82,7 +82,9 @@ def merge_blob_chunks( message_class = type(resp) merged_message = message_class( type=ToolInvokeMessage.MessageType.BLOB, - message=ToolInvokeMessage.BlobMessage(blob=files[chunk_id].data[: files[chunk_id].bytes_written]), + message=ToolInvokeMessage.BlobMessage( + blob=bytes(files[chunk_id].data[: files[chunk_id].bytes_written]) + ), meta=resp.meta, ) yield cast(MessageType, merged_message) diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py index d75a230d73..fdab0af103 100644 --- a/api/core/prompt/simple_prompt_transform.py +++ b/api/core/prompt/simple_prompt_transform.py @@ -1,7 +1,7 @@ -import enum import json import os from collections.abc import Mapping, Sequence +from enum import StrEnum, auto from typing import TYPE_CHECKING, Any, Optional, cast from core.app.app_config.entities import PromptTemplateEntity @@ -25,9 +25,9 @@ if TYPE_CHECKING: from core.file.models import File -class ModelMode(enum.StrEnum): - COMPLETION = "completion" - CHAT = "chat" +class ModelMode(StrEnum): + COMPLETION = auto() + CHAT = auto() prompt_file_contents: dict[str, Any] = {} @@ -101,9 +101,22 @@ class SimplePromptTransform(PromptTransform): with_memory_prompt=histories is not None, ) - variables = {k: inputs[k] for k in prompt_template_config["custom_variable_keys"] if k in inputs} + custom_variable_keys_obj = prompt_template_config["custom_variable_keys"] + special_variable_keys_obj = prompt_template_config["special_variable_keys"] - for v in prompt_template_config["special_variable_keys"]: + # Type check for custom_variable_keys + if not isinstance(custom_variable_keys_obj, list): + raise TypeError(f"Expected list for custom_variable_keys, got {type(custom_variable_keys_obj)}") + custom_variable_keys = cast(list[str], custom_variable_keys_obj) + + # Type check for special_variable_keys + if not isinstance(special_variable_keys_obj, list): + raise TypeError(f"Expected list for special_variable_keys, got {type(special_variable_keys_obj)}") + special_variable_keys = cast(list[str], special_variable_keys_obj) + + variables = {k: inputs[k] for k in custom_variable_keys if k in inputs} + + for v in special_variable_keys: # support #context#, #query# and #histories# if v == "#context#": variables["#context#"] = context or "" @@ -113,9 +126,16 @@ class SimplePromptTransform(PromptTransform): variables["#histories#"] = histories or "" prompt_template = prompt_template_config["prompt_template"] + if not isinstance(prompt_template, PromptTemplateParser): + raise TypeError(f"Expected PromptTemplateParser, got {type(prompt_template)}") + prompt = prompt_template.format(variables) - return prompt, prompt_template_config["prompt_rules"] + prompt_rules = prompt_template_config["prompt_rules"] + if not isinstance(prompt_rules, dict): + raise TypeError(f"Expected dict for prompt_rules, got {type(prompt_rules)}") + + return prompt, prompt_rules def get_prompt_template( self, @@ -126,11 +146,11 @@ class SimplePromptTransform(PromptTransform): has_context: bool, query_in_prompt: bool, with_memory_prompt: bool = False, - ): + ) -> dict[str, object]: prompt_rules = self._get_prompt_rule(app_mode=app_mode, provider=provider, model=model) - custom_variable_keys = [] - special_variable_keys = [] + custom_variable_keys: list[str] = [] + special_variable_keys: list[str] = [] prompt = "" for order in prompt_rules["system_prompt_orders"]: diff --git a/api/core/rag/datasource/vdb/field.py b/api/core/rag/datasource/vdb/field.py index 9887e21b7c..8fc94be360 100644 --- a/api/core/rag/datasource/vdb/field.py +++ b/api/core/rag/datasource/vdb/field.py @@ -1,13 +1,13 @@ -from enum import Enum +from enum import StrEnum, auto -class Field(Enum): +class Field(StrEnum): CONTENT_KEY = "page_content" METADATA_KEY = "metadata" GROUP_KEY = "group_id" - VECTOR = "vector" + VECTOR = auto() # Sparse Vector aims to support full text search - SPARSE_VECTOR = "sparse_vector" + SPARSE_VECTOR = auto() TEXT_KEY = "text" PRIMARY_KEY = "id" DOC_ID = "metadata.doc_id" diff --git a/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py b/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py index 107ea75e6a..0eca37a129 100644 --- a/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py +++ b/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py @@ -28,8 +28,8 @@ def create_ssl_context() -> ssl.SSLContext: class HuaweiCloudVectorConfig(BaseModel): hosts: str - username: str | None - password: str | None + username: str | None = None + password: str | None = None @model_validator(mode="before") @classmethod diff --git a/api/core/rag/datasource/vdb/matrixone/matrixone_vector.py b/api/core/rag/datasource/vdb/matrixone/matrixone_vector.py index 7da830f643..3dd073ce50 100644 --- a/api/core/rag/datasource/vdb/matrixone/matrixone_vector.py +++ b/api/core/rag/datasource/vdb/matrixone/matrixone_vector.py @@ -1,8 +1,9 @@ import json import logging import uuid +from collections.abc import Callable from functools import wraps -from typing import Any, Optional +from typing import Any, Concatenate, Optional, ParamSpec, TypeVar from mo_vector.client import MoVectorClient # type: ignore from pydantic import BaseModel, model_validator @@ -17,7 +18,6 @@ from extensions.ext_redis import redis_client from models.dataset import Dataset logger = logging.getLogger(__name__) -from typing import ParamSpec, TypeVar P = ParamSpec("P") R = TypeVar("R") @@ -47,16 +47,6 @@ class MatrixoneConfig(BaseModel): return values -def ensure_client(func): - @wraps(func) - def wrapper(self, *args, **kwargs): - if self.client is None: - self.client = self._get_client(None, False) - return func(self, *args, **kwargs) - - return wrapper - - class MatrixoneVector(BaseVector): """ Matrixone vector storage implementation. @@ -216,6 +206,19 @@ class MatrixoneVector(BaseVector): self.client.delete() +T = TypeVar("T", bound=MatrixoneVector) + + +def ensure_client(func: Callable[Concatenate[T, P], R]): + @wraps(func) + def wrapper(self: T, *args: P.args, **kwargs: P.kwargs): + if self.client is None: + self.client = self._get_client(None, False) + return func(self, *args, **kwargs) + + return wrapper + + class MatrixoneVectorFactory(AbstractVectorFactory): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MatrixoneVector: if dataset.index_struct_dict: diff --git a/api/core/rag/datasource/vdb/myscale/myscale_vector.py b/api/core/rag/datasource/vdb/myscale/myscale_vector.py index b590a4dfe4..17aac25b87 100644 --- a/api/core/rag/datasource/vdb/myscale/myscale_vector.py +++ b/api/core/rag/datasource/vdb/myscale/myscale_vector.py @@ -1,7 +1,7 @@ import json import logging import uuid -from enum import Enum +from enum import StrEnum from typing import Any from clickhouse_connect import get_client @@ -27,7 +27,7 @@ class MyScaleConfig(BaseModel): fts_params: str -class SortOrder(Enum): +class SortOrder(StrEnum): ASC = "ASC" DESC = "DESC" diff --git a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py index 12d97c500f..d329220580 100644 --- a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py +++ b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py @@ -40,6 +40,19 @@ if TYPE_CHECKING: MetadataFilter = Union[DictFilter, common_types.Filter] +class PathQdrantParams(BaseModel): + path: str + + +class UrlQdrantParams(BaseModel): + url: str + api_key: Optional[str] + timeout: float + verify: bool + grpc_port: int + prefer_grpc: bool + + class QdrantConfig(BaseModel): endpoint: str api_key: Optional[str] = None @@ -50,7 +63,7 @@ class QdrantConfig(BaseModel): replication_factor: int = 1 write_consistency_factor: int = 1 - def to_qdrant_params(self): + def to_qdrant_params(self) -> PathQdrantParams | UrlQdrantParams: if self.endpoint and self.endpoint.startswith("path:"): path = self.endpoint.replace("path:", "") if not os.path.isabs(path): @@ -58,23 +71,23 @@ class QdrantConfig(BaseModel): raise ValueError("Root path is not set") path = os.path.join(self.root_path, path) - return {"path": path} + return PathQdrantParams(path=path) else: - return { - "url": self.endpoint, - "api_key": self.api_key, - "timeout": self.timeout, - "verify": self.endpoint.startswith("https"), - "grpc_port": self.grpc_port, - "prefer_grpc": self.prefer_grpc, - } + return UrlQdrantParams( + url=self.endpoint, + api_key=self.api_key, + timeout=self.timeout, + verify=self.endpoint.startswith("https"), + grpc_port=self.grpc_port, + prefer_grpc=self.prefer_grpc, + ) class QdrantVector(BaseVector): def __init__(self, collection_name: str, group_id: str, config: QdrantConfig, distance_func: str = "Cosine"): super().__init__(collection_name) self._client_config = config - self._client = qdrant_client.QdrantClient(**self._client_config.to_qdrant_params()) + self._client = qdrant_client.QdrantClient(**self._client_config.to_qdrant_params().model_dump()) self._distance_func = distance_func.upper() self._group_id = group_id diff --git a/api/core/rag/datasource/vdb/tencent/tencent_vector.py b/api/core/rag/datasource/vdb/tencent/tencent_vector.py index 4af34bbb2d..2485857070 100644 --- a/api/core/rag/datasource/vdb/tencent/tencent_vector.py +++ b/api/core/rag/datasource/vdb/tencent/tencent_vector.py @@ -24,10 +24,10 @@ logger = logging.getLogger(__name__) class TencentConfig(BaseModel): url: str - api_key: Optional[str] + api_key: Optional[str] = None timeout: float = 30 - username: Optional[str] - database: Optional[str] + username: Optional[str] = None + database: Optional[str] = None index_type: str = "HNSW" metric_type: str = "IP" shard: int = 1 diff --git a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py index 184b5f2142..e1d4422144 100644 --- a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py +++ b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py @@ -1,5 +1,6 @@ import time import uuid +from collections.abc import Sequence import requests from requests.auth import HTTPDigestAuth @@ -139,7 +140,7 @@ class TidbService: @staticmethod def batch_update_tidb_serverless_cluster_status( - tidb_serverless_list: list[TidbAuthBinding], + tidb_serverless_list: Sequence[TidbAuthBinding], project_id: str, api_url: str, iam_url: str, diff --git a/api/core/rag/extractor/entity/datasource_type.py b/api/core/rag/extractor/entity/datasource_type.py index 19ad300d11..6568f60ea2 100644 --- a/api/core/rag/extractor/entity/datasource_type.py +++ b/api/core/rag/extractor/entity/datasource_type.py @@ -1,7 +1,7 @@ -from enum import Enum +from enum import StrEnum -class DatasourceType(Enum): +class DatasourceType(StrEnum): FILE = "upload_file" NOTION = "notion_import" WEBSITE = "website_crawl" diff --git a/api/core/rag/extractor/unstructured/unstructured_doc_extractor.py b/api/core/rag/extractor/unstructured/unstructured_doc_extractor.py index 4ed8dfbbd8..5199208f70 100644 --- a/api/core/rag/extractor/unstructured/unstructured_doc_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_doc_extractor.py @@ -23,7 +23,7 @@ class UnstructuredWordExtractor(BaseExtractor): unstructured_version = tuple(int(x) for x in __unstructured_version__.split(".")) # check the file extension try: - import magic # noqa: F401 # pyright: ignore[reportUnusedImport] + import magic # noqa: F401 is_doc = detect_filetype(self._file_path) == FileType.DOC except ImportError: diff --git a/api/core/rag/index_processor/constant/built_in_field.py b/api/core/rag/index_processor/constant/built_in_field.py index c8ad53e3dd..1d9ca89ba7 100644 --- a/api/core/rag/index_processor/constant/built_in_field.py +++ b/api/core/rag/index_processor/constant/built_in_field.py @@ -1,15 +1,15 @@ -from enum import Enum, StrEnum +from enum import StrEnum, auto class BuiltInField(StrEnum): - document_name = "document_name" - uploader = "uploader" - upload_date = "upload_date" - last_update_date = "last_update_date" - source = "source" + document_name = auto() + uploader = auto() + upload_date = auto() + last_update_date = auto() + source = auto() -class MetadataDataSource(Enum): +class MetadataDataSource(StrEnum): upload_file = "file_upload" website_crawl = "website" notion_import = "notion" diff --git a/api/core/rag/index_processor/processor/parent_child_index_processor.py b/api/core/rag/index_processor/processor/parent_child_index_processor.py index d1088af853..514596200f 100644 --- a/api/core/rag/index_processor/processor/parent_child_index_processor.py +++ b/api/core/rag/index_processor/processor/parent_child_index_processor.py @@ -113,21 +113,33 @@ class ParentChildIndexProcessor(BaseIndexProcessor): # node_ids is segment's node_ids if dataset.indexing_technique == "high_quality": delete_child_chunks = kwargs.get("delete_child_chunks") or False + precomputed_child_node_ids = kwargs.get("precomputed_child_node_ids") vector = Vector(dataset) + if node_ids: - child_node_ids = ( - db.session.query(ChildChunk.index_node_id) - .join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id) - .where( - DocumentSegment.dataset_id == dataset.id, - DocumentSegment.index_node_id.in_(node_ids), - ChildChunk.dataset_id == dataset.id, + # Use precomputed child_node_ids if available (to avoid race conditions) + if precomputed_child_node_ids is not None: + child_node_ids = precomputed_child_node_ids + else: + # Fallback to original query (may fail if segments are already deleted) + child_node_ids = ( + db.session.query(ChildChunk.index_node_id) + .join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id) + .where( + DocumentSegment.dataset_id == dataset.id, + DocumentSegment.index_node_id.in_(node_ids), + ChildChunk.dataset_id == dataset.id, + ) + .all() ) - .all() - ) - child_node_ids = [child_node_id[0] for child_node_id in child_node_ids] - vector.delete_by_ids(child_node_ids) - if delete_child_chunks: + child_node_ids = [child_node_id[0] for child_node_id in child_node_ids if child_node_id[0]] + + # Delete from vector index + if child_node_ids: + vector.delete_by_ids(child_node_ids) + + # Delete from database + if delete_child_chunks and child_node_ids: db.session.query(ChildChunk).where( ChildChunk.dataset_id == dataset.id, ChildChunk.index_node_id.in_(child_node_ids) ).delete(synchronize_session=False) diff --git a/api/core/rag/splitter/text_splitter.py b/api/core/rag/splitter/text_splitter.py index c5b6ac4608..bfa14e03d6 100644 --- a/api/core/rag/splitter/text_splitter.py +++ b/api/core/rag/splitter/text_splitter.py @@ -110,9 +110,7 @@ class TextSplitter(BaseDocumentTransformer, ABC): docs = [] current_doc: list[str] = [] total = 0 - index = 0 - for d in splits: - _len = lengths[index] + for d, _len in zip(splits, lengths): if total + _len + (separator_len if len(current_doc) > 0 else 0) > self._chunk_size: if total > self._chunk_size: logger.warning( @@ -134,7 +132,6 @@ class TextSplitter(BaseDocumentTransformer, ABC): current_doc = current_doc[1:] current_doc.append(d) total += _len + (separator_len if len(current_doc) > 1 else 0) - index += 1 doc = self._join_docs(current_doc, separator) if doc is not None: docs.append(doc) diff --git a/api/core/repositories/celery_workflow_node_execution_repository.py b/api/core/repositories/celery_workflow_node_execution_repository.py index b36252dba2..95ad9f25fe 100644 --- a/api/core/repositories/celery_workflow_node_execution_repository.py +++ b/api/core/repositories/celery_workflow_node_execution_repository.py @@ -94,10 +94,10 @@ class CeleryWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository): self._creator_user_role = CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER # In-memory cache for workflow node executions - self._execution_cache: dict[str, WorkflowNodeExecution] = {} + self._execution_cache = {} # Cache for mapping workflow_execution_ids to execution IDs for efficient retrieval - self._workflow_execution_mapping: dict[str, list[str]] = {} + self._workflow_execution_mapping = {} logger.info( "Initialized CeleryWorkflowNodeExecutionRepository for tenant %s, app %s, triggered_from %s", diff --git a/api/core/tools/builtin_tool/provider.py b/api/core/tools/builtin_tool/provider.py index 68bfe5b4a5..45fd16d684 100644 --- a/api/core/tools/builtin_tool/provider.py +++ b/api/core/tools/builtin_tool/provider.py @@ -18,7 +18,7 @@ from core.tools.entities.values import ToolLabelEnum, default_tool_label_dict from core.tools.errors import ( ToolProviderNotFoundError, ) -from core.tools.utils.yaml_utils import load_yaml_file +from core.tools.utils.yaml_utils import load_yaml_file_cached class BuiltinToolProviderController(ToolProviderController): @@ -31,7 +31,7 @@ class BuiltinToolProviderController(ToolProviderController): provider = self.__class__.__module__.split(".")[-1] yaml_path = path.join(path.dirname(path.realpath(__file__)), "providers", provider, f"{provider}.yaml") try: - provider_yaml = load_yaml_file(yaml_path, ignore_error=False) + provider_yaml = load_yaml_file_cached(yaml_path) except Exception as e: raise ToolProviderNotFoundError(f"can not load provider yaml for {provider}: {e}") @@ -71,7 +71,7 @@ class BuiltinToolProviderController(ToolProviderController): for tool_file in tool_files: # get tool name tool_name = tool_file.split(".")[0] - tool = load_yaml_file(path.join(tool_path, tool_file), ignore_error=False) + tool = load_yaml_file_cached(path.join(tool_path, tool_file)) # get tool class, import the module assistant_tool_class: type = load_single_subclass_from_source( diff --git a/api/core/tools/custom_tool/provider.py b/api/core/tools/custom_tool/provider.py index 5790aea2b0..0cc992155a 100644 --- a/api/core/tools/custom_tool/provider.py +++ b/api/core/tools/custom_tool/provider.py @@ -1,4 +1,5 @@ from pydantic import Field +from sqlalchemy import select from core.entities.provider_entities import ProviderConfig from core.tools.__base.tool_provider import ToolProviderController @@ -176,11 +177,11 @@ class ApiToolProviderController(ToolProviderController): tools: list[ApiTool] = [] # get tenant api providers - db_providers: list[ApiToolProvider] = ( - db.session.query(ApiToolProvider) - .where(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == self.entity.identity.name) - .all() - ) + db_providers = db.session.scalars( + select(ApiToolProvider).where( + ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == self.entity.identity.name + ) + ).all() if db_providers and len(db_providers) != 0: for db_provider in db_providers: diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index 66304b30a5..f934b0bf96 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -1,8 +1,7 @@ import base64 import contextlib -import enum from collections.abc import Mapping -from enum import Enum +from enum import StrEnum, auto from typing import Any, Optional, Union from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_serializer, field_validator, model_validator @@ -22,37 +21,37 @@ from core.tools.entities.common_entities import I18nObject from core.tools.entities.constants import TOOL_SELECTOR_MODEL_IDENTITY -class ToolLabelEnum(Enum): - SEARCH = "search" - IMAGE = "image" - VIDEOS = "videos" - WEATHER = "weather" - FINANCE = "finance" - DESIGN = "design" - TRAVEL = "travel" - SOCIAL = "social" - NEWS = "news" - MEDICAL = "medical" - PRODUCTIVITY = "productivity" - EDUCATION = "education" - BUSINESS = "business" - ENTERTAINMENT = "entertainment" - UTILITIES = "utilities" - OTHER = "other" +class ToolLabelEnum(StrEnum): + SEARCH = auto() + IMAGE = auto() + VIDEOS = auto() + WEATHER = auto() + FINANCE = auto() + DESIGN = auto() + TRAVEL = auto() + SOCIAL = auto() + NEWS = auto() + MEDICAL = auto() + PRODUCTIVITY = auto() + EDUCATION = auto() + BUSINESS = auto() + ENTERTAINMENT = auto() + UTILITIES = auto() + OTHER = auto() -class ToolProviderType(enum.StrEnum): +class ToolProviderType(StrEnum): """ Enum class for tool provider """ - PLUGIN = "plugin" + PLUGIN = auto() BUILT_IN = "builtin" - WORKFLOW = "workflow" - API = "api" - APP = "app" + WORKFLOW = auto() + API = auto() + APP = auto() DATASET_RETRIEVAL = "dataset-retrieval" - MCP = "mcp" + MCP = auto() @classmethod def value_of(cls, value: str) -> "ToolProviderType": @@ -68,15 +67,15 @@ class ToolProviderType(enum.StrEnum): raise ValueError(f"invalid mode value {value}") -class ApiProviderSchemaType(Enum): +class ApiProviderSchemaType(StrEnum): """ Enum class for api provider schema type. """ - OPENAPI = "openapi" - SWAGGER = "swagger" - OPENAI_PLUGIN = "openai_plugin" - OPENAI_ACTIONS = "openai_actions" + OPENAPI = auto() + SWAGGER = auto() + OPENAI_PLUGIN = auto() + OPENAI_ACTIONS = auto() @classmethod def value_of(cls, value: str) -> "ApiProviderSchemaType": @@ -92,14 +91,14 @@ class ApiProviderSchemaType(Enum): raise ValueError(f"invalid mode value {value}") -class ApiProviderAuthType(Enum): +class ApiProviderAuthType(StrEnum): """ Enum class for api provider auth type. """ - NONE = "none" - API_KEY_HEADER = "api_key_header" - API_KEY_QUERY = "api_key_query" + NONE = auto() + API_KEY_HEADER = auto() + API_KEY_QUERY = auto() @classmethod def value_of(cls, value: str) -> "ApiProviderAuthType": @@ -176,10 +175,10 @@ class ToolInvokeMessage(BaseModel): return value class LogMessage(BaseModel): - class LogStatus(Enum): - START = "start" - ERROR = "error" - SUCCESS = "success" + class LogStatus(StrEnum): + START = auto() + ERROR = auto() + SUCCESS = auto() id: str label: str = Field(..., description="The label of the log") @@ -193,19 +192,19 @@ class ToolInvokeMessage(BaseModel): retriever_resources: list[RetrievalSourceMetadata] = Field(..., description="retriever resources") context: str = Field(..., description="context") - class MessageType(Enum): - TEXT = "text" - IMAGE = "image" - LINK = "link" - BLOB = "blob" - JSON = "json" - IMAGE_LINK = "image_link" - BINARY_LINK = "binary_link" - VARIABLE = "variable" - FILE = "file" - LOG = "log" - BLOB_CHUNK = "blob_chunk" - RETRIEVER_RESOURCES = "retriever_resources" + class MessageType(StrEnum): + TEXT = auto() + IMAGE = auto() + LINK = auto() + BLOB = auto() + JSON = auto() + IMAGE_LINK = auto() + BINARY_LINK = auto() + VARIABLE = auto() + FILE = auto() + LOG = auto() + BLOB_CHUNK = auto() + RETRIEVER_RESOURCES = auto() type: MessageType = MessageType.TEXT """ @@ -250,29 +249,29 @@ class ToolParameter(PluginParameter): Overrides type """ - class ToolParameterType(enum.StrEnum): + class ToolParameterType(StrEnum): """ removes TOOLS_SELECTOR from PluginParameterType """ - STRING = PluginParameterType.STRING.value - NUMBER = PluginParameterType.NUMBER.value - BOOLEAN = PluginParameterType.BOOLEAN.value - SELECT = PluginParameterType.SELECT.value - SECRET_INPUT = PluginParameterType.SECRET_INPUT.value - FILE = PluginParameterType.FILE.value - FILES = PluginParameterType.FILES.value - APP_SELECTOR = PluginParameterType.APP_SELECTOR.value - MODEL_SELECTOR = PluginParameterType.MODEL_SELECTOR.value - ANY = PluginParameterType.ANY.value - DYNAMIC_SELECT = PluginParameterType.DYNAMIC_SELECT.value + STRING = PluginParameterType.STRING + NUMBER = PluginParameterType.NUMBER + BOOLEAN = PluginParameterType.BOOLEAN + SELECT = PluginParameterType.SELECT + SECRET_INPUT = PluginParameterType.SECRET_INPUT + FILE = PluginParameterType.FILE + FILES = PluginParameterType.FILES + APP_SELECTOR = PluginParameterType.APP_SELECTOR + MODEL_SELECTOR = PluginParameterType.MODEL_SELECTOR + ANY = PluginParameterType.ANY + DYNAMIC_SELECT = PluginParameterType.DYNAMIC_SELECT # MCP object and array type parameters - ARRAY = MCPServerParameterType.ARRAY.value - OBJECT = MCPServerParameterType.OBJECT.value + ARRAY = MCPServerParameterType.ARRAY + OBJECT = MCPServerParameterType.OBJECT # deprecated, should not use. - SYSTEM_FILES = PluginParameterType.SYSTEM_FILES.value + SYSTEM_FILES = PluginParameterType.SYSTEM_FILES def as_normal_type(self): return as_normal_type(self) @@ -280,10 +279,10 @@ class ToolParameter(PluginParameter): def cast_value(self, value: Any): return cast_parameter_value(self, value) - class ToolParameterForm(Enum): - SCHEMA = "schema" # should be set while adding tool - FORM = "form" # should be set before invoking tool - LLM = "llm" # will be set by LLM + class ToolParameterForm(StrEnum): + SCHEMA = auto() # should be set while adding tool + FORM = auto() # should be set before invoking tool + LLM = auto() # will be set by LLM type: ToolParameterType = Field(..., description="The type of the parameter") human_description: Optional[I18nObject] = Field(default=None, description="The description presented to the user") @@ -446,14 +445,14 @@ class ToolLabel(BaseModel): icon: str = Field(..., description="The icon of the tool") -class ToolInvokeFrom(Enum): +class ToolInvokeFrom(StrEnum): """ Enum class for tool invoke """ - WORKFLOW = "workflow" - AGENT = "agent" - PLUGIN = "plugin" + WORKFLOW = auto() + AGENT = auto() + PLUGIN = auto() class ToolSelector(BaseModel): @@ -478,9 +477,9 @@ class ToolSelector(BaseModel): return self.model_dump() -class CredentialType(enum.StrEnum): +class CredentialType(StrEnum): API_KEY = "api-key" - OAUTH2 = "oauth2" + OAUTH2 = auto() def get_name(self): if self == CredentialType.API_KEY: diff --git a/api/core/tools/mcp_tool/tool.py b/api/core/tools/mcp_tool/tool.py index 6810ac683d..21d256ae03 100644 --- a/api/core/tools/mcp_tool/tool.py +++ b/api/core/tools/mcp_tool/tool.py @@ -67,22 +67,42 @@ class MCPTool(Tool): for content in result.content: if isinstance(content, TextContent): - try: - content_json = json.loads(content.text) - if isinstance(content_json, dict): - yield self.create_json_message(content_json) - elif isinstance(content_json, list): - for item in content_json: - yield self.create_json_message(item) - else: - yield self.create_text_message(content.text) - except json.JSONDecodeError: - yield self.create_text_message(content.text) - + yield from self._process_text_content(content) elif isinstance(content, ImageContent): - yield self.create_blob_message( - blob=base64.b64decode(content.data), meta={"mime_type": content.mimeType} - ) + yield self._process_image_content(content) + + def _process_text_content(self, content: TextContent) -> Generator[ToolInvokeMessage, None, None]: + """Process text content and yield appropriate messages.""" + try: + content_json = json.loads(content.text) + yield from self._process_json_content(content_json) + except json.JSONDecodeError: + yield self.create_text_message(content.text) + + def _process_json_content(self, content_json: Any) -> Generator[ToolInvokeMessage, None, None]: + """Process JSON content based on its type.""" + if isinstance(content_json, dict): + yield self.create_json_message(content_json) + elif isinstance(content_json, list): + yield from self._process_json_list(content_json) + else: + # For primitive types (str, int, bool, etc.), convert to string + yield self.create_text_message(str(content_json)) + + def _process_json_list(self, json_list: list) -> Generator[ToolInvokeMessage, None, None]: + """Process a list of JSON items.""" + if any(not isinstance(item, dict) for item in json_list): + # If the list contains any non-dict item, treat the entire list as a text message. + yield self.create_text_message(str(json_list)) + return + + # Otherwise, process each dictionary as a separate JSON message. + for item in json_list: + yield self.create_json_message(item) + + def _process_image_content(self, content: ImageContent) -> ToolInvokeMessage: + """Process image content and return a blob message.""" + return self.create_blob_message(blob=base64.b64decode(content.data), meta={"mime_type": content.mimeType}) def fork_tool_runtime(self, runtime: ToolRuntime) -> "MCPTool": return MCPTool( diff --git a/api/core/tools/tool_label_manager.py b/api/core/tools/tool_label_manager.py index 84b874975a..39646b7fc8 100644 --- a/api/core/tools/tool_label_manager.py +++ b/api/core/tools/tool_label_manager.py @@ -87,9 +87,7 @@ class ToolLabelManager: assert isinstance(controller, ApiToolProviderController | WorkflowToolProviderController) provider_ids.append(controller.provider_id) # ty: ignore [unresolved-attribute] - labels: list[ToolLabelBinding] = ( - db.session.query(ToolLabelBinding).where(ToolLabelBinding.tool_id.in_(provider_ids)).all() - ) + labels = db.session.scalars(select(ToolLabelBinding).where(ToolLabelBinding.tool_id.in_(provider_ids))).all() tool_labels: dict[str, list[str]] = {label.tool_id: [] for label in labels} diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index bc1f09a2fc..b29da3f0ba 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -667,9 +667,9 @@ class ToolManager: # get db api providers if "api" in filters: - db_api_providers: list[ApiToolProvider] = ( - db.session.query(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id).all() - ) + db_api_providers = db.session.scalars( + select(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id) + ).all() api_provider_controllers: list[dict[str, Any]] = [ {"provider": provider, "controller": ToolTransformService.api_provider_to_controller(provider)} @@ -690,9 +690,9 @@ class ToolManager: if "workflow" in filters: # get workflow providers - workflow_providers: list[WorkflowToolProvider] = ( - db.session.query(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id).all() - ) + workflow_providers = db.session.scalars( + select(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id) + ).all() workflow_provider_controllers: list[WorkflowToolProviderController] = [] for workflow_provider in workflow_providers: diff --git a/api/core/tools/utils/yaml_utils.py b/api/core/tools/utils/yaml_utils.py index 8a0a91a50c..e9b5dab7d3 100644 --- a/api/core/tools/utils/yaml_utils.py +++ b/api/core/tools/utils/yaml_utils.py @@ -1,4 +1,5 @@ import logging +from functools import lru_cache from pathlib import Path from typing import Any @@ -8,28 +9,25 @@ from yaml import YAMLError logger = logging.getLogger(__name__) -def load_yaml_file(file_path: str, ignore_error: bool = True, default_value: Any = {}): - """ - Safe loading a YAML file - :param file_path: the path of the YAML file - :param ignore_error: - if True, return default_value if error occurs and the error will be logged in debug level - if False, raise error if error occurs - :param default_value: the value returned when errors ignored - :return: an object of the YAML content - """ +def _load_yaml_file(*, file_path: str): if not file_path or not Path(file_path).exists(): - if ignore_error: - return default_value - else: - raise FileNotFoundError(f"File not found: {file_path}") + raise FileNotFoundError(f"File not found: {file_path}") with open(file_path, encoding="utf-8") as yaml_file: try: yaml_content = yaml.safe_load(yaml_file) - return yaml_content or default_value + return yaml_content except Exception as e: - if ignore_error: - return default_value - else: - raise YAMLError(f"Failed to load YAML file {file_path}: {e}") from e + raise YAMLError(f"Failed to load YAML file {file_path}: {e}") from e + + +@lru_cache(maxsize=128) +def load_yaml_file_cached(file_path: str) -> Any: + """ + Cached version of load_yaml_file for static configuration files. + Only use for files that don't change during runtime (e.g., position files) + + :param file_path: the path of the YAML file + :return: an object of the YAML content + """ + return _load_yaml_file(file_path=file_path) diff --git a/api/core/variables/segment_group.py b/api/core/variables/segment_group.py index b363255b2c..0a41b64228 100644 --- a/api/core/variables/segment_group.py +++ b/api/core/variables/segment_group.py @@ -4,7 +4,7 @@ from .types import SegmentType class SegmentGroup(Segment): value_type: SegmentType = SegmentType.GROUP - value: list[Segment] + value: list[Segment] = None # type: ignore @property def text(self): diff --git a/api/core/variables/segments.py b/api/core/variables/segments.py index cfef193633..28644b0169 100644 --- a/api/core/variables/segments.py +++ b/api/core/variables/segments.py @@ -19,7 +19,7 @@ class Segment(BaseModel): model_config = ConfigDict(frozen=True) value_type: SegmentType - value: Any + value: Any = None @field_validator("value_type") @classmethod @@ -74,12 +74,12 @@ class NoneSegment(Segment): class StringSegment(Segment): value_type: SegmentType = SegmentType.STRING - value: str + value: str = None # type: ignore class FloatSegment(Segment): value_type: SegmentType = SegmentType.FLOAT - value: float + value: float = None # type: ignore # NOTE(QuantumGhost): seems that the equality for FloatSegment with `NaN` value has some problems. # The following tests cannot pass. # @@ -98,12 +98,12 @@ class FloatSegment(Segment): class IntegerSegment(Segment): value_type: SegmentType = SegmentType.INTEGER - value: int + value: int = None # type: ignore class ObjectSegment(Segment): value_type: SegmentType = SegmentType.OBJECT - value: Mapping[str, Any] + value: Mapping[str, Any] = None # type: ignore @property def text(self) -> str: @@ -136,7 +136,7 @@ class ArraySegment(Segment): class FileSegment(Segment): value_type: SegmentType = SegmentType.FILE - value: File + value: File = None # type: ignore @property def markdown(self) -> str: @@ -153,17 +153,17 @@ class FileSegment(Segment): class BooleanSegment(Segment): value_type: SegmentType = SegmentType.BOOLEAN - value: bool + value: bool = None # type: ignore class ArrayAnySegment(ArraySegment): value_type: SegmentType = SegmentType.ARRAY_ANY - value: Sequence[Any] + value: Sequence[Any] = None # type: ignore class ArrayStringSegment(ArraySegment): value_type: SegmentType = SegmentType.ARRAY_STRING - value: Sequence[str] + value: Sequence[str] = None # type: ignore @property def text(self) -> str: @@ -175,17 +175,17 @@ class ArrayStringSegment(ArraySegment): class ArrayNumberSegment(ArraySegment): value_type: SegmentType = SegmentType.ARRAY_NUMBER - value: Sequence[float | int] + value: Sequence[float | int] = None # type: ignore class ArrayObjectSegment(ArraySegment): value_type: SegmentType = SegmentType.ARRAY_OBJECT - value: Sequence[Mapping[str, Any]] + value: Sequence[Mapping[str, Any]] = None # type: ignore class ArrayFileSegment(ArraySegment): value_type: SegmentType = SegmentType.ARRAY_FILE - value: Sequence[File] + value: Sequence[File] = None # type: ignore @property def markdown(self) -> str: @@ -205,7 +205,7 @@ class ArrayFileSegment(ArraySegment): class ArrayBooleanSegment(ArraySegment): value_type: SegmentType = SegmentType.ARRAY_BOOLEAN - value: Sequence[bool] + value: Sequence[bool] = None # type: ignore def get_segment_discriminator(v: Any) -> SegmentType | None: diff --git a/api/core/workflow/errors.py b/api/core/workflow/errors.py index 594bb2b32e..63513bdc9f 100644 --- a/api/core/workflow/errors.py +++ b/api/core/workflow/errors.py @@ -3,6 +3,6 @@ from core.workflow.nodes.base import BaseNode class WorkflowNodeRunFailedError(Exception): def __init__(self, node: BaseNode, err_msg: str): - self._node = node - self._error = err_msg + self.node = node + self.error = err_msg super().__init__(f"Node {node.title} run failed: {err_msg}") diff --git a/api/core/workflow/graph_engine/entities/runtime_route_state.py b/api/core/workflow/graph_engine/entities/runtime_route_state.py index 54440df725..d010ce05b8 100644 --- a/api/core/workflow/graph_engine/entities/runtime_route_state.py +++ b/api/core/workflow/graph_engine/entities/runtime_route_state.py @@ -1,6 +1,6 @@ import uuid from datetime import datetime -from enum import Enum +from enum import StrEnum, auto from typing import Optional from pydantic import BaseModel, Field @@ -11,12 +11,12 @@ from libs.datetime_utils import naive_utc_now class RouteNodeState(BaseModel): - class Status(Enum): - RUNNING = "running" - SUCCESS = "success" - FAILED = "failed" - PAUSED = "paused" - EXCEPTION = "exception" + class Status(StrEnum): + RUNNING = auto() + SUCCESS = auto() + FAILED = auto() + PAUSED = auto() + EXCEPTION = auto() id: str = Field(default_factory=lambda: str(uuid.uuid4())) """node state id""" diff --git a/api/core/workflow/nodes/agent/entities.py b/api/core/workflow/nodes/agent/entities.py index 11b11068e7..ce6eb33ecc 100644 --- a/api/core/workflow/nodes/agent/entities.py +++ b/api/core/workflow/nodes/agent/entities.py @@ -1,4 +1,4 @@ -from enum import Enum, StrEnum +from enum import IntEnum, StrEnum, auto from typing import Any, Literal, Union from pydantic import BaseModel @@ -25,9 +25,9 @@ class AgentNodeData(BaseNodeData): agent_parameters: dict[str, AgentInput] -class ParamsAutoGenerated(Enum): - CLOSE = 0 - OPEN = 1 +class ParamsAutoGenerated(IntEnum): + CLOSE = auto() + OPEN = auto() class AgentOldVersionModelFeatures(StrEnum): @@ -38,8 +38,8 @@ class AgentOldVersionModelFeatures(StrEnum): TOOL_CALL = "tool-call" MULTI_TOOL_CALL = "multi-tool-call" AGENT_THOUGHT = "agent-thought" - VISION = "vision" + VISION = auto() STREAM_TOOL_CALL = "stream-tool-call" - DOCUMENT = "document" - VIDEO = "video" - AUDIO = "audio" + DOCUMENT = auto() + VIDEO = auto() + AUDIO = auto() diff --git a/api/core/workflow/nodes/answer/entities.py b/api/core/workflow/nodes/answer/entities.py index a05cc44c99..850ff14880 100644 --- a/api/core/workflow/nodes/answer/entities.py +++ b/api/core/workflow/nodes/answer/entities.py @@ -1,5 +1,5 @@ from collections.abc import Sequence -from enum import Enum +from enum import StrEnum, auto from pydantic import BaseModel, Field @@ -19,9 +19,9 @@ class GenerateRouteChunk(BaseModel): Generate Route Chunk. """ - class ChunkType(Enum): - VAR = "var" - TEXT = "text" + class ChunkType(StrEnum): + VAR = auto() + TEXT = auto() type: ChunkType = Field(..., description="generate route chunk type") diff --git a/api/core/workflow/nodes/base/entities.py b/api/core/workflow/nodes/base/entities.py index 708da21177..90e45e9d25 100644 --- a/api/core/workflow/nodes/base/entities.py +++ b/api/core/workflow/nodes/base/entities.py @@ -23,7 +23,7 @@ NumberType = Union[int, float] class DefaultValue(BaseModel): - value: Any + value: Any = None type: DefaultValueType key: str diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index d357fea7dd..3a77541807 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -259,7 +259,7 @@ class KnowledgeRetrievalNode(BaseNode): ) all_documents = [] dataset_retrieval = DatasetRetrieval() - if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE.value: + if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE: # fetch model config if node_data.single_retrieval_config is None: raise ValueError("single_retrieval_config is required") @@ -291,7 +291,7 @@ class KnowledgeRetrievalNode(BaseNode): metadata_filter_document_ids=metadata_filter_document_ids, metadata_condition=metadata_condition, ) - elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value: + elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE: if node_data.multiple_retrieval_config is None: raise ValueError("multiple_retrieval_config is required") if node_data.multiple_retrieval_config.reranking_mode == "reranking_model": diff --git a/api/core/workflow/nodes/list_operator/node.py b/api/core/workflow/nodes/list_operator/node.py index eb7b9fc2c6..cf46870254 100644 --- a/api/core/workflow/nodes/list_operator/node.py +++ b/api/core/workflow/nodes/list_operator/node.py @@ -67,8 +67,8 @@ class ListOperatorNode(BaseNode): return "1" def _run(self): - inputs: dict[str, list] = {} - process_data: dict[str, list] = {} + inputs: dict[str, Sequence[object]] = {} + process_data: dict[str, Sequence[object]] = {} outputs: dict[str, Any] = {} variable = self.graph_runtime_state.variable_pool.get(self._node_data.variable) diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index c34a06d981..fdcdac1ec2 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -1183,7 +1183,8 @@ def _combine_message_content_with_role( return AssistantPromptMessage(content=contents) case PromptMessageRole.SYSTEM: return SystemPromptMessage(content=contents) - raise NotImplementedError(f"Role {role} is not supported") + case _: + raise NotImplementedError(f"Role {role} is not supported") def _render_jinja2_message( diff --git a/api/core/workflow/nodes/variable_assigner/common/helpers.py b/api/core/workflow/nodes/variable_assigner/common/helpers.py index 8caee27363..04a7323739 100644 --- a/api/core/workflow/nodes/variable_assigner/common/helpers.py +++ b/api/core/workflow/nodes/variable_assigner/common/helpers.py @@ -16,7 +16,7 @@ class UpdatedVariable(BaseModel): name: str selector: Sequence[str] value_type: SegmentType - new_value: Any + new_value: Any = None _T = TypeVar("_T", bound=MutableMapping[str, Any]) diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index b69a9971b5..7aab4037c6 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -367,7 +367,7 @@ class WorkflowEntry: raise ValueError(f"Variable key {node_variable} not found in user inputs.") # environment variable already exist in variable pool, not from user inputs - if variable_pool.get(variable_selector): + if variable_pool.get(variable_selector) and variable_selector[0] == ENVIRONMENT_VARIABLE_NODE_ID: continue # fetch variable node id from variable selector diff --git a/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py b/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py index b8b5a89dc5..69959acd19 100644 --- a/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py +++ b/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py @@ -1,3 +1,5 @@ +from sqlalchemy import select + from events.app_event import app_model_config_was_updated from extensions.ext_database import db from models.dataset import AppDatasetJoin @@ -13,7 +15,7 @@ def handle(sender, **kwargs): dataset_ids = get_dataset_ids_from_model_config(app_model_config) - app_dataset_joins = db.session.query(AppDatasetJoin).where(AppDatasetJoin.app_id == app.id).all() + app_dataset_joins = db.session.scalars(select(AppDatasetJoin).where(AppDatasetJoin.app_id == app.id)).all() removed_dataset_ids: set[str] = set() if not app_dataset_joins: diff --git a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py index fcc3b63fa7..898ec1f153 100644 --- a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py +++ b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py @@ -1,5 +1,7 @@ from typing import cast +from sqlalchemy import select + from core.workflow.nodes import NodeType from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData from events.app_event import app_published_workflow_was_updated @@ -15,7 +17,7 @@ def handle(sender, **kwargs): published_workflow = cast(Workflow, published_workflow) dataset_ids = get_dataset_ids_from_workflow(published_workflow) - app_dataset_joins = db.session.query(AppDatasetJoin).where(AppDatasetJoin.app_id == app.id).all() + app_dataset_joins = db.session.scalars(select(AppDatasetJoin).where(AppDatasetJoin.app_id == app.id)).all() removed_dataset_ids: set[str] = set() if not app_dataset_joins: diff --git a/api/extensions/ext_login.py b/api/extensions/ext_login.py index cd01a31068..5571c0d9ba 100644 --- a/api/extensions/ext_login.py +++ b/api/extensions/ext_login.py @@ -86,9 +86,7 @@ def load_user_from_request(request_from_flask_login): if not app_mcp_server: raise NotFound("App MCP server not found.") end_user = ( - db.session.query(EndUser) - .where(EndUser.external_user_id == app_mcp_server.id, EndUser.type == "mcp") - .first() + db.session.query(EndUser).where(EndUser.session_id == app_mcp_server.id, EndUser.type == "mcp").first() ) if not end_user: raise NotFound("End user not found.") diff --git a/api/extensions/storage/clickzetta_volume/file_lifecycle.py b/api/extensions/storage/clickzetta_volume/file_lifecycle.py index ef6b12fd59..43fe771bcd 100644 --- a/api/extensions/storage/clickzetta_volume/file_lifecycle.py +++ b/api/extensions/storage/clickzetta_volume/file_lifecycle.py @@ -9,19 +9,19 @@ import json import logging from dataclasses import asdict, dataclass from datetime import datetime -from enum import Enum +from enum import StrEnum, auto from typing import Any, Optional logger = logging.getLogger(__name__) -class FileStatus(Enum): +class FileStatus(StrEnum): """File status enumeration""" - ACTIVE = "active" # Active status - ARCHIVED = "archived" # Archived - DELETED = "deleted" # Deleted (soft delete) - BACKUP = "backup" # Backup file + ACTIVE = auto() # Active status + ARCHIVED = auto() # Archived + DELETED = auto() # Deleted (soft delete) + BACKUP = auto() # Backup file @dataclass diff --git a/api/extensions/storage/clickzetta_volume/volume_permissions.py b/api/extensions/storage/clickzetta_volume/volume_permissions.py index 243df92efe..2431b08d81 100644 --- a/api/extensions/storage/clickzetta_volume/volume_permissions.py +++ b/api/extensions/storage/clickzetta_volume/volume_permissions.py @@ -5,13 +5,13 @@ According to ClickZetta's permission model, different Volume types have differen """ import logging -from enum import Enum +from enum import StrEnum from typing import Optional logger = logging.getLogger(__name__) -class VolumePermission(Enum): +class VolumePermission(StrEnum): """Volume permission type enumeration""" READ = "SELECT" # Corresponds to ClickZetta's SELECT permission diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py index 9433b312cf..f2c37e1a4b 100644 --- a/api/factories/file_factory.py +++ b/api/factories/file_factory.py @@ -462,9 +462,9 @@ class StorageKeyLoader: upload_file_row = upload_files.get(model_id) if upload_file_row is None: raise ValueError(f"Upload file not found for id: {model_id}") - file._storage_key = upload_file_row.key + file.storage_key = upload_file_row.key elif file.transfer_method == FileTransferMethod.TOOL_FILE: tool_file_row = tool_files.get(model_id) if tool_file_row is None: raise ValueError(f"Tool file not found for id: {model_id}") - file._storage_key = tool_file_row.file_key + file.storage_key = tool_file_row.file_key diff --git a/api/fields/_value_type_serializer.py b/api/fields/_value_type_serializer.py index 8288bd54a3..b2b793d40e 100644 --- a/api/fields/_value_type_serializer.py +++ b/api/fields/_value_type_serializer.py @@ -12,4 +12,7 @@ def serialize_value_type(v: _VarTypedDict | Segment) -> str: if isinstance(v, Segment): return v.value_type.exposed_type().value else: - return v["value_type"].exposed_type().value + value_type = v.get("value_type") + if value_type is None: + raise ValueError("value_type is required but not provided") + return value_type.exposed_type().value diff --git a/api/libs/email_i18n.py b/api/libs/email_i18n.py index 3c039dff53..5258823a07 100644 --- a/api/libs/email_i18n.py +++ b/api/libs/email_i18n.py @@ -7,7 +7,7 @@ eliminates the need for repetitive language switching logic. """ from dataclasses import dataclass -from enum import Enum +from enum import StrEnum, auto from typing import Any, Optional, Protocol from flask import render_template @@ -17,26 +17,30 @@ from extensions.ext_mail import mail from services.feature_service import BrandingModel, FeatureService -class EmailType(Enum): +class EmailType(StrEnum): """Enumeration of supported email types.""" - RESET_PASSWORD = "reset_password" - INVITE_MEMBER = "invite_member" - EMAIL_CODE_LOGIN = "email_code_login" - CHANGE_EMAIL_OLD = "change_email_old" - CHANGE_EMAIL_NEW = "change_email_new" - CHANGE_EMAIL_COMPLETED = "change_email_completed" - OWNER_TRANSFER_CONFIRM = "owner_transfer_confirm" - OWNER_TRANSFER_OLD_NOTIFY = "owner_transfer_old_notify" - OWNER_TRANSFER_NEW_NOTIFY = "owner_transfer_new_notify" - ACCOUNT_DELETION_SUCCESS = "account_deletion_success" - ACCOUNT_DELETION_VERIFICATION = "account_deletion_verification" - ENTERPRISE_CUSTOM = "enterprise_custom" - QUEUE_MONITOR_ALERT = "queue_monitor_alert" - DOCUMENT_CLEAN_NOTIFY = "document_clean_notify" + RESET_PASSWORD = auto() + RESET_PASSWORD_WHEN_ACCOUNT_NOT_EXIST = auto() + INVITE_MEMBER = auto() + EMAIL_CODE_LOGIN = auto() + CHANGE_EMAIL_OLD = auto() + CHANGE_EMAIL_NEW = auto() + CHANGE_EMAIL_COMPLETED = auto() + OWNER_TRANSFER_CONFIRM = auto() + OWNER_TRANSFER_OLD_NOTIFY = auto() + OWNER_TRANSFER_NEW_NOTIFY = auto() + ACCOUNT_DELETION_SUCCESS = auto() + ACCOUNT_DELETION_VERIFICATION = auto() + ENTERPRISE_CUSTOM = auto() + QUEUE_MONITOR_ALERT = auto() + DOCUMENT_CLEAN_NOTIFY = auto() + EMAIL_REGISTER = auto() + EMAIL_REGISTER_WHEN_ACCOUNT_EXIST = auto() + RESET_PASSWORD_WHEN_ACCOUNT_NOT_EXIST_NO_REGISTER = auto() -class EmailLanguage(Enum): +class EmailLanguage(StrEnum): """Supported email languages with fallback handling.""" EN_US = "en-US" @@ -441,6 +445,54 @@ def create_default_email_config() -> EmailI18nConfig: branded_template_path="clean_document_job_mail_template_zh-CN.html", ), }, + EmailType.EMAIL_REGISTER: { + EmailLanguage.EN_US: EmailTemplate( + subject="Register Your {application_title} Account", + template_path="register_email_template_en-US.html", + branded_template_path="without-brand/register_email_template_en-US.html", + ), + EmailLanguage.ZH_HANS: EmailTemplate( + subject="注册您的 {application_title} 账户", + template_path="register_email_template_zh-CN.html", + branded_template_path="without-brand/register_email_template_zh-CN.html", + ), + }, + EmailType.EMAIL_REGISTER_WHEN_ACCOUNT_EXIST: { + EmailLanguage.EN_US: EmailTemplate( + subject="Register Your {application_title} Account", + template_path="register_email_when_account_exist_template_en-US.html", + branded_template_path="without-brand/register_email_when_account_exist_template_en-US.html", + ), + EmailLanguage.ZH_HANS: EmailTemplate( + subject="注册您的 {application_title} 账户", + template_path="register_email_when_account_exist_template_zh-CN.html", + branded_template_path="without-brand/register_email_when_account_exist_template_zh-CN.html", + ), + }, + EmailType.RESET_PASSWORD_WHEN_ACCOUNT_NOT_EXIST: { + EmailLanguage.EN_US: EmailTemplate( + subject="Reset Your {application_title} Password", + template_path="reset_password_mail_when_account_not_exist_template_en-US.html", + branded_template_path="without-brand/reset_password_mail_when_account_not_exist_template_en-US.html", + ), + EmailLanguage.ZH_HANS: EmailTemplate( + subject="重置您的 {application_title} 密码", + template_path="reset_password_mail_when_account_not_exist_template_zh-CN.html", + branded_template_path="without-brand/reset_password_mail_when_account_not_exist_template_zh-CN.html", + ), + }, + EmailType.RESET_PASSWORD_WHEN_ACCOUNT_NOT_EXIST_NO_REGISTER: { + EmailLanguage.EN_US: EmailTemplate( + subject="Reset Your {application_title} Password", + template_path="reset_password_mail_when_account_not_exist_no_register_template_en-US.html", + branded_template_path="without-brand/reset_password_mail_when_account_not_exist_no_register_template_en-US.html", + ), + EmailLanguage.ZH_HANS: EmailTemplate( + subject="重置您的 {application_title} 密码", + template_path="reset_password_mail_when_account_not_exist_no_register_template_zh-CN.html", + branded_template_path="without-brand/reset_password_mail_when_account_not_exist_no_register_template_zh-CN.html", + ), + }, } return EmailI18nConfig(templates=templates) diff --git a/api/libs/external_api.py b/api/libs/external_api.py index cee80f7f24..cf91b0117f 100644 --- a/api/libs/external_api.py +++ b/api/libs/external_api.py @@ -69,6 +69,8 @@ def register_external_error_handlers(api: Api): headers["WWW-Authenticate"] = 'Bearer realm="api"' return data, status_code, headers + _ = handle_http_exception + @api.errorhandler(ValueError) def handle_value_error(e: ValueError): got_request_exception.send(current_app, exception=e) @@ -76,6 +78,8 @@ def register_external_error_handlers(api: Api): data = {"code": "invalid_param", "message": str(e), "status": status_code} return data, status_code + _ = handle_value_error + @api.errorhandler(AppInvokeQuotaExceededError) def handle_quota_exceeded(e: AppInvokeQuotaExceededError): got_request_exception.send(current_app, exception=e) @@ -83,15 +87,17 @@ def register_external_error_handlers(api: Api): data = {"code": "too_many_requests", "message": str(e), "status": status_code} return data, status_code + _ = handle_quota_exceeded + @api.errorhandler(Exception) def handle_general_exception(e: Exception): got_request_exception.send(current_app, exception=e) status_code = 500 - data: dict[str, Any] = getattr(e, "data", {"message": http_status_message(status_code)}) + data = getattr(e, "data", {"message": http_status_message(status_code)}) # 🔒 Normalize non-mapping data (e.g., if someone set e.data = Response) - if not isinstance(data, Mapping): + if not isinstance(data, dict): data = {"message": str(e)} data.setdefault("code", "unknown") @@ -101,10 +107,12 @@ def register_external_error_handlers(api: Api): exc_info: Any = sys.exc_info() if exc_info[1] is None: exc_info = None - current_app.log_exception(exc_info) # ty: ignore [invalid-argument-type] + current_app.log_exception(exc_info) return data, status_code + _ = handle_general_exception + class ExternalApi(Api): _authorizations = { diff --git a/api/libs/helper.py b/api/libs/helper.py index 139cb329de..09dbfac6cb 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -68,7 +68,7 @@ class AppIconUrlField(fields.Raw): if isinstance(obj, dict) and "app" in obj: obj = obj["app"] - if isinstance(obj, App | Site) and obj.icon_type == IconType.IMAGE.value: + if isinstance(obj, App | Site) and obj.icon_type == IconType.IMAGE: return file_helpers.get_signed_file_url(obj.icon) return None @@ -167,13 +167,6 @@ class DatetimeString: return value -def _get_float(value): - try: - return float(value) - except (TypeError, ValueError): - raise ValueError(f"{value} is not a valid float") - - def timezone(timezone_string): if timezone_string and timezone_string in available_timezones(): return timezone_string diff --git a/api/models/account.py b/api/models/account.py index 019159d2da..aefef673df 100644 --- a/api/models/account.py +++ b/api/models/account.py @@ -7,6 +7,7 @@ import sqlalchemy as sa from flask_login import UserMixin # type: ignore[import-untyped] from sqlalchemy import DateTime, String, func, select from sqlalchemy.orm import Mapped, Session, mapped_column, reconstructor +from typing_extensions import deprecated from models.base import Base @@ -187,7 +188,28 @@ class Account(UserMixin, Base): return TenantAccountRole.is_admin_role(self.role) @property + @deprecated("Use has_edit_permission instead.") def is_editor(self): + """Determines if the account has edit permissions in their current tenant (workspace). + + This property checks if the current role has editing privileges, which includes: + - `OWNER` + - `ADMIN` + - `EDITOR` + + Note: This checks for any role with editing permission, not just the 'EDITOR' role specifically. + """ + return self.has_edit_permission + + @property + def has_edit_permission(self): + """Determines if the account has editing permissions in their current tenant (workspace). + + This property checks if the current role has editing privileges, which includes: + - `OWNER` + - `ADMIN` + - `EDITOR` + """ return TenantAccountRole.is_editing_role(self.role) @property @@ -218,10 +240,12 @@ class Tenant(Base): updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp()) def get_accounts(self) -> list[Account]: - return ( - db.session.query(Account) - .where(Account.id == TenantAccountJoin.account_id, TenantAccountJoin.tenant_id == self.id) - .all() + return list( + db.session.scalars( + select(Account).where( + Account.id == TenantAccountJoin.account_id, TenantAccountJoin.tenant_id == self.id + ) + ).all() ) @property diff --git a/api/models/dataset.py b/api/models/dataset.py index 07f3eb18db..7314945053 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -208,7 +208,9 @@ class Dataset(Base): @property def doc_metadata(self): - dataset_metadatas = db.session.query(DatasetMetadata).where(DatasetMetadata.dataset_id == self.id).all() + dataset_metadatas = db.session.scalars( + select(DatasetMetadata).where(DatasetMetadata.dataset_id == self.id) + ).all() doc_metadata = [ { @@ -222,35 +224,35 @@ class Dataset(Base): doc_metadata.append( { "id": "built-in", - "name": BuiltInField.document_name.value, + "name": BuiltInField.document_name, "type": "string", } ) doc_metadata.append( { "id": "built-in", - "name": BuiltInField.uploader.value, + "name": BuiltInField.uploader, "type": "string", } ) doc_metadata.append( { "id": "built-in", - "name": BuiltInField.upload_date.value, + "name": BuiltInField.upload_date, "type": "time", } ) doc_metadata.append( { "id": "built-in", - "name": BuiltInField.last_update_date.value, + "name": BuiltInField.last_update_date, "type": "time", } ) doc_metadata.append( { "id": "built-in", - "name": BuiltInField.source.value, + "name": BuiltInField.source, "type": "string", } ) @@ -542,7 +544,7 @@ class Document(Base): "id": "built-in", "name": BuiltInField.source, "type": "string", - "value": MetadataDataSource[self.data_source_type].value, + "value": MetadataDataSource[self.data_source_type], } ) return built_in_fields @@ -1055,13 +1057,11 @@ class ExternalKnowledgeApis(Base): @property def dataset_bindings(self) -> list[dict[str, Any]]: - external_knowledge_bindings = ( - db.session.query(ExternalKnowledgeBindings) - .where(ExternalKnowledgeBindings.external_knowledge_api_id == self.id) - .all() - ) + external_knowledge_bindings = db.session.scalars( + select(ExternalKnowledgeBindings).where(ExternalKnowledgeBindings.external_knowledge_api_id == self.id) + ).all() dataset_ids = [binding.dataset_id for binding in external_knowledge_bindings] - datasets = db.session.query(Dataset).where(Dataset.id.in_(dataset_ids)).all() + datasets = db.session.scalars(select(Dataset).where(Dataset.id.in_(dataset_ids))).all() dataset_bindings: list[dict[str, Any]] = [] for dataset in datasets: dataset_bindings.append({"id": dataset.id, "name": dataset.name}) diff --git a/api/models/model.py b/api/models/model.py index f8ead1f872..50c07268dd 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -3,7 +3,7 @@ import re import uuid from collections.abc import Mapping from datetime import datetime -from enum import Enum, StrEnum +from enum import StrEnum, auto from typing import TYPE_CHECKING, Any, Literal, Optional, cast from core.plugin.entities.plugin import GenericProviderID @@ -62,9 +62,9 @@ class AppMode(StrEnum): raise ValueError(f"invalid mode value {value}") -class IconType(Enum): - IMAGE = "image" - EMOJI = "emoji" +class IconType(StrEnum): + IMAGE = auto() + EMOJI = auto() class App(Base): @@ -149,15 +149,15 @@ class App(Base): if app_model_config.agent_mode_dict.get("enabled", False) and app_model_config.agent_mode_dict.get( "strategy", "" ) in {"function_call", "react"}: - self.mode = AppMode.AGENT_CHAT.value + self.mode = AppMode.AGENT_CHAT db.session.commit() return True return False @property def mode_compatible_with_agent(self) -> str: - if self.mode == AppMode.CHAT.value and self.is_agent: - return AppMode.AGENT_CHAT.value + if self.mode == AppMode.CHAT and self.is_agent: + return AppMode.AGENT_CHAT return str(self.mode) @@ -713,7 +713,7 @@ class Conversation(Base): model_config = {} app_model_config: Optional[AppModelConfig] = None - if self.mode == AppMode.ADVANCED_CHAT.value: + if self.mode == AppMode.ADVANCED_CHAT: if self.override_model_configs: override_model_configs = json.loads(self.override_model_configs) model_config = override_model_configs @@ -812,7 +812,7 @@ class Conversation(Base): @property def status_count(self): - messages = db.session.query(Message).where(Message.conversation_id == self.id).all() + messages = db.session.scalars(select(Message).where(Message.conversation_id == self.id)).all() status_counts = { WorkflowExecutionStatus.RUNNING: 0, WorkflowExecutionStatus.SUCCEEDED: 0, @@ -1090,7 +1090,7 @@ class Message(Base): @property def feedbacks(self): - feedbacks = db.session.query(MessageFeedback).where(MessageFeedback.message_id == self.id).all() + feedbacks = db.session.scalars(select(MessageFeedback).where(MessageFeedback.message_id == self.id)).all() return feedbacks @property @@ -1145,7 +1145,7 @@ class Message(Base): def message_files(self) -> list[dict[str, Any]]: from factories import file_factory - message_files = db.session.query(MessageFile).where(MessageFile.message_id == self.id).all() + message_files = db.session.scalars(select(MessageFile).where(MessageFile.message_id == self.id)).all() current_app = db.session.query(App).where(App.id == self.app_id).first() if not current_app: raise ValueError(f"App {self.app_id} not found") @@ -1459,6 +1459,14 @@ class OperationLog(Base): updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) +class DefaultEndUserSessionID(StrEnum): + """ + End User Session ID enum. + """ + + DEFAULT_SESSION_ID = "DEFAULT-USER" + + class EndUser(Base, UserMixin): __tablename__ = "end_users" __table_args__ = ( diff --git a/api/models/provider.py b/api/models/provider.py index 9a344ea56d..17094f6d6e 100644 --- a/api/models/provider.py +++ b/api/models/provider.py @@ -1,5 +1,5 @@ from datetime import datetime -from enum import Enum +from enum import StrEnum, auto from functools import cached_property from typing import Optional @@ -12,9 +12,9 @@ from .engine import db from .types import StringUUID -class ProviderType(Enum): - CUSTOM = "custom" - SYSTEM = "system" +class ProviderType(StrEnum): + CUSTOM = auto() + SYSTEM = auto() @staticmethod def value_of(value: str) -> "ProviderType": @@ -24,14 +24,14 @@ class ProviderType(Enum): raise ValueError(f"No matching enum found for value '{value}'") -class ProviderQuotaType(Enum): - PAID = "paid" +class ProviderQuotaType(StrEnum): + PAID = auto() """hosted paid quota""" - FREE = "free" + FREE = auto() """third-party free quota""" - TRIAL = "trial" + TRIAL = auto() """hosted trial quota""" @staticmethod diff --git a/api/models/workflow.py b/api/models/workflow.py index 4686b38b01..78582dd3fb 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -2,7 +2,7 @@ import json import logging from collections.abc import Mapping, Sequence from datetime import datetime -from enum import Enum, StrEnum +from enum import StrEnum, auto from typing import TYPE_CHECKING, Any, Optional, Union, cast from uuid import uuid4 @@ -41,13 +41,13 @@ from .types import EnumText, StringUUID logger = logging.getLogger(__name__) -class WorkflowType(Enum): +class WorkflowType(StrEnum): """ Workflow Type Enum """ - WORKFLOW = "workflow" - CHAT = "chat" + WORKFLOW = auto() + CHAT = auto() @classmethod def value_of(cls, value: str) -> "WorkflowType": @@ -777,7 +777,7 @@ class WorkflowNodeExecutionModel(Base): return extras -class WorkflowAppLogCreatedFrom(Enum): +class WorkflowAppLogCreatedFrom(StrEnum): """ Workflow App Log Created From Enum """ diff --git a/api/pyproject.toml b/api/pyproject.toml index c59140e246..b0a02589c7 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -168,6 +168,8 @@ dev = [ "types-redis>=4.6.0.20241004", "celery-types>=0.23.0", "mypy~=1.17.1", + "locust>=2.40.4", + "sseclient-py>=1.8.0", ] ############################################################ diff --git a/api/pyrightconfig.json b/api/pyrightconfig.json index a3a5f2044e..7c59c2ca28 100644 --- a/api/pyrightconfig.json +++ b/api/pyrightconfig.json @@ -1,24 +1,44 @@ { - "include": ["models", "configs"], - "exclude": [".venv", "tests/", "migrations/"], - "ignore": [ - "core/", - "controllers/", - "tasks/", - "services/", - "schedule/", - "extensions/", - "utils/", - "repositories/", - "libs/", - "fields/", - "factories/", - "events/", - "contexts/", - "constants/", - "commands.py" + "include": ["."], + "exclude": [ + ".venv", + "tests/", + "migrations/", + "core/rag", + "extensions", + "libs", + "controllers/console/datasets", + "controllers/service_api/dataset", + "core/ops", + "core/tools", + "core/model_runtime", + "core/workflow", + "core/app/app_config/easy_ui_based_app/dataset" ], "typeCheckingMode": "strict", + "allowedUntypedLibraries": [ + "flask_restx", + "flask_login", + "opentelemetry.instrumentation.celery", + "opentelemetry.instrumentation.flask", + "opentelemetry.instrumentation.requests", + "opentelemetry.instrumentation.sqlalchemy", + "opentelemetry.instrumentation.redis" + ], + "reportUnknownMemberType": "hint", + "reportUnknownParameterType": "hint", + "reportUnknownArgumentType": "hint", + "reportUnknownVariableType": "hint", + "reportUnknownLambdaType": "hint", + "reportMissingParameterType": "hint", + "reportMissingTypeArgument": "hint", + "reportUnnecessaryContains": "hint", + "reportUnnecessaryComparison": "hint", + "reportUnnecessaryCast": "hint", + "reportUnnecessaryIsInstance": "hint", + "reportUntypedFunctionDecorator": "hint", + + "reportAttributeAccessIssue": "hint", "pythonVersion": "3.11", "pythonPlatform": "All" } diff --git a/api/schedule/clean_unused_datasets_task.py b/api/schedule/clean_unused_datasets_task.py index 63e6132b6a..2b1e6b47cc 100644 --- a/api/schedule/clean_unused_datasets_task.py +++ b/api/schedule/clean_unused_datasets_task.py @@ -96,11 +96,11 @@ def clean_unused_datasets_task(): break for dataset in datasets: - dataset_query = ( - db.session.query(DatasetQuery) - .where(DatasetQuery.created_at > clean_day, DatasetQuery.dataset_id == dataset.id) - .all() - ) + dataset_query = db.session.scalars( + select(DatasetQuery).where( + DatasetQuery.created_at > clean_day, DatasetQuery.dataset_id == dataset.id + ) + ).all() if not dataset_query or len(dataset_query) == 0: try: @@ -121,15 +121,13 @@ def clean_unused_datasets_task(): if should_clean: # Add auto disable log if required if add_logs: - documents = ( - db.session.query(Document) - .where( + documents = db.session.scalars( + select(Document).where( Document.dataset_id == dataset.id, Document.enabled == True, Document.archived == False, ) - .all() - ) + ).all() for document in documents: dataset_auto_disable_log = DatasetAutoDisableLog( tenant_id=dataset.tenant_id, diff --git a/api/schedule/mail_clean_document_notify_task.py b/api/schedule/mail_clean_document_notify_task.py index 9e32ecc716..ef6edd6709 100644 --- a/api/schedule/mail_clean_document_notify_task.py +++ b/api/schedule/mail_clean_document_notify_task.py @@ -3,6 +3,7 @@ import time from collections import defaultdict import click +from sqlalchemy import select import app from configs import dify_config @@ -31,9 +32,9 @@ def mail_clean_document_notify_task(): # send document clean notify mail try: - dataset_auto_disable_logs = ( - db.session.query(DatasetAutoDisableLog).where(DatasetAutoDisableLog.notified == False).all() - ) + dataset_auto_disable_logs = db.session.scalars( + select(DatasetAutoDisableLog).where(DatasetAutoDisableLog.notified == False) + ).all() # group by tenant_id dataset_auto_disable_logs_map: dict[str, list[DatasetAutoDisableLog]] = defaultdict(list) for dataset_auto_disable_log in dataset_auto_disable_logs: diff --git a/api/schedule/update_tidb_serverless_status_task.py b/api/schedule/update_tidb_serverless_status_task.py index 1bfeb869e2..1befa0e8b5 100644 --- a/api/schedule/update_tidb_serverless_status_task.py +++ b/api/schedule/update_tidb_serverless_status_task.py @@ -1,6 +1,8 @@ import time +from collections.abc import Sequence import click +from sqlalchemy import select import app from configs import dify_config @@ -15,11 +17,9 @@ def update_tidb_serverless_status_task(): start_at = time.perf_counter() try: # check the number of idle tidb serverless - tidb_serverless_list = ( - db.session.query(TidbAuthBinding) - .where(TidbAuthBinding.active == False, TidbAuthBinding.status == "CREATING") - .all() - ) + tidb_serverless_list = db.session.scalars( + select(TidbAuthBinding).where(TidbAuthBinding.active == False, TidbAuthBinding.status == "CREATING") + ).all() if len(tidb_serverless_list) == 0: return # update tidb serverless status @@ -32,7 +32,7 @@ def update_tidb_serverless_status_task(): click.echo(click.style(f"Update tidb serverless status task success latency: {end_at - start_at}", fg="green")) -def update_clusters(tidb_serverless_list: list[TidbAuthBinding]): +def update_clusters(tidb_serverless_list: Sequence[TidbAuthBinding]): try: # batch 20 for i in range(0, len(tidb_serverless_list), 20): diff --git a/api/services/account_service.py b/api/services/account_service.py index a76792f88e..b10887e88e 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -37,7 +37,6 @@ from services.billing_service import BillingService from services.errors.account import ( AccountAlreadyInTenantError, AccountLoginError, - AccountNotFoundError, AccountNotLinkTenantError, AccountPasswordError, AccountRegisterError, @@ -65,7 +64,11 @@ from tasks.mail_owner_transfer_task import ( send_old_owner_transfer_notify_email_task, send_owner_transfer_confirm_task, ) -from tasks.mail_reset_password_task import send_reset_password_mail_task +from tasks.mail_register_task import send_email_register_mail_task, send_email_register_mail_task_when_account_exist +from tasks.mail_reset_password_task import ( + send_reset_password_mail_task, + send_reset_password_mail_task_when_account_not_exist, +) logger = logging.getLogger(__name__) @@ -82,8 +85,9 @@ REFRESH_TOKEN_EXPIRY = timedelta(days=dify_config.REFRESH_TOKEN_EXPIRE_DAYS) class AccountService: reset_password_rate_limiter = RateLimiter(prefix="reset_password_rate_limit", max_attempts=1, time_window=60 * 1) + email_register_rate_limiter = RateLimiter(prefix="email_register_rate_limit", max_attempts=1, time_window=60 * 1) email_code_login_rate_limiter = RateLimiter( - prefix="email_code_login_rate_limit", max_attempts=1, time_window=60 * 1 + prefix="email_code_login_rate_limit", max_attempts=3, time_window=300 * 1 ) email_code_account_deletion_rate_limiter = RateLimiter( prefix="email_code_account_deletion_rate_limit", max_attempts=1, time_window=60 * 1 @@ -95,6 +99,7 @@ class AccountService: FORGOT_PASSWORD_MAX_ERROR_LIMITS = 5 CHANGE_EMAIL_MAX_ERROR_LIMITS = 5 OWNER_TRANSFER_MAX_ERROR_LIMITS = 5 + EMAIL_REGISTER_MAX_ERROR_LIMITS = 5 @staticmethod def _get_refresh_token_key(refresh_token: str) -> str: @@ -171,7 +176,7 @@ class AccountService: account = db.session.query(Account).filter_by(email=email).first() if not account: - raise AccountNotFoundError() + raise AccountPasswordError("Invalid email or password.") if account.status == AccountStatus.BANNED.value: raise AccountLoginError("Account is banned.") @@ -246,6 +251,8 @@ class AccountService: account.name = name if password: + valid_password(password) + # generate password salt salt = secrets.token_bytes(16) base64_salt = base64.b64encode(salt).decode() @@ -294,7 +301,9 @@ class AccountService: if cls.email_code_account_deletion_rate_limiter.is_rate_limited(email): from controllers.console.auth.error import EmailCodeAccountDeletionRateLimitExceededError - raise EmailCodeAccountDeletionRateLimitExceededError() + raise EmailCodeAccountDeletionRateLimitExceededError( + int(cls.email_code_account_deletion_rate_limiter.time_window / 60) + ) send_account_deletion_verification_code.delay(to=email, code=code) @@ -433,6 +442,7 @@ class AccountService: account: Optional[Account] = None, email: Optional[str] = None, language: str = "en-US", + is_allow_register: bool = False, ): account_email = account.email if account else email if account_email is None: @@ -441,18 +451,59 @@ class AccountService: if cls.reset_password_rate_limiter.is_rate_limited(account_email): from controllers.console.auth.error import PasswordResetRateLimitExceededError - raise PasswordResetRateLimitExceededError() + raise PasswordResetRateLimitExceededError(int(cls.reset_password_rate_limiter.time_window / 60)) code, token = cls.generate_reset_password_token(account_email, account) - send_reset_password_mail_task.delay( - language=language, - to=account_email, - code=code, - ) + if account: + send_reset_password_mail_task.delay( + language=language, + to=account_email, + code=code, + ) + else: + send_reset_password_mail_task_when_account_not_exist.delay( + language=language, + to=account_email, + is_allow_register=is_allow_register, + ) cls.reset_password_rate_limiter.increment_rate_limit(account_email) return token + @classmethod + def send_email_register_email( + cls, + account: Optional[Account] = None, + email: Optional[str] = None, + language: str = "en-US", + ): + account_email = account.email if account else email + if account_email is None: + raise ValueError("Email must be provided.") + + if cls.email_register_rate_limiter.is_rate_limited(account_email): + from controllers.console.auth.error import EmailRegisterRateLimitExceededError + + raise EmailRegisterRateLimitExceededError(int(cls.email_register_rate_limiter.time_window / 60)) + + code, token = cls.generate_email_register_token(account_email) + + if account: + send_email_register_mail_task_when_account_exist.delay( + language=language, + to=account_email, + account_name=account.name, + ) + + else: + send_email_register_mail_task.delay( + language=language, + to=account_email, + code=code, + ) + cls.email_register_rate_limiter.increment_rate_limit(account_email) + return token + @classmethod def send_change_email_email( cls, @@ -471,7 +522,7 @@ class AccountService: if cls.change_email_rate_limiter.is_rate_limited(account_email): from controllers.console.auth.error import EmailChangeRateLimitExceededError - raise EmailChangeRateLimitExceededError() + raise EmailChangeRateLimitExceededError(int(cls.change_email_rate_limiter.time_window / 60)) code, token = cls.generate_change_email_token(account_email, account, old_email=old_email) @@ -515,7 +566,7 @@ class AccountService: if cls.owner_transfer_rate_limiter.is_rate_limited(account_email): from controllers.console.auth.error import OwnerTransferRateLimitExceededError - raise OwnerTransferRateLimitExceededError() + raise OwnerTransferRateLimitExceededError(int(cls.owner_transfer_rate_limiter.time_window / 60)) code, token = cls.generate_owner_transfer_token(account_email, account) workspace_name = workspace_name or "" @@ -585,6 +636,19 @@ class AccountService: ) return code, token + @classmethod + def generate_email_register_token( + cls, + email: str, + code: Optional[str] = None, + additional_data: dict[str, Any] = {}, + ): + if not code: + code = "".join([str(secrets.randbelow(exclusive_upper_bound=10)) for _ in range(6)]) + additional_data["code"] = code + token = TokenManager.generate_token(email=email, token_type="email_register", additional_data=additional_data) + return code, token + @classmethod def generate_change_email_token( cls, @@ -623,6 +687,10 @@ class AccountService: def revoke_reset_password_token(cls, token: str): TokenManager.revoke_token(token, "reset_password") + @classmethod + def revoke_email_register_token(cls, token: str): + TokenManager.revoke_token(token, "email_register") + @classmethod def revoke_change_email_token(cls, token: str): TokenManager.revoke_token(token, "change_email") @@ -635,6 +703,10 @@ class AccountService: def get_reset_password_data(cls, token: str) -> Optional[dict[str, Any]]: return TokenManager.get_token_data(token, "reset_password") + @classmethod + def get_email_register_data(cls, token: str) -> Optional[dict[str, Any]]: + return TokenManager.get_token_data(token, "email_register") + @classmethod def get_change_email_data(cls, token: str) -> Optional[dict[str, Any]]: return TokenManager.get_token_data(token, "change_email") @@ -656,7 +728,7 @@ class AccountService: if cls.email_code_login_rate_limiter.is_rate_limited(email): from controllers.console.auth.error import EmailCodeLoginRateLimitExceededError - raise EmailCodeLoginRateLimitExceededError() + raise EmailCodeLoginRateLimitExceededError(int(cls.email_code_login_rate_limiter.time_window / 60)) code = "".join([str(secrets.randbelow(exclusive_upper_bound=10)) for _ in range(6)]) token = TokenManager.generate_token( @@ -742,6 +814,16 @@ class AccountService: count = int(count) + 1 redis_client.setex(key, dify_config.FORGOT_PASSWORD_LOCKOUT_DURATION, count) + @staticmethod + @redis_fallback(default_return=None) + def add_email_register_error_rate_limit(email: str) -> None: + key = f"email_register_error_rate_limit:{email}" + count = redis_client.get(key) + if count is None: + count = 0 + count = int(count) + 1 + redis_client.setex(key, dify_config.EMAIL_REGISTER_LOCKOUT_DURATION, count) + @staticmethod @redis_fallback(default_return=False) def is_forgot_password_error_rate_limit(email: str) -> bool: @@ -761,6 +843,24 @@ class AccountService: key = f"forgot_password_error_rate_limit:{email}" redis_client.delete(key) + @staticmethod + @redis_fallback(default_return=False) + def is_email_register_error_rate_limit(email: str) -> bool: + key = f"email_register_error_rate_limit:{email}" + count = redis_client.get(key) + if count is None: + return False + count = int(count) + if count > AccountService.EMAIL_REGISTER_MAX_ERROR_LIMITS: + return True + return False + + @staticmethod + @redis_fallback(default_return=None) + def reset_email_register_error_rate_limit(email: str): + key = f"email_register_error_rate_limit:{email}" + redis_client.delete(key) + @staticmethod @redis_fallback(default_return=None) def add_change_email_error_rate_limit(email: str): @@ -1318,7 +1418,7 @@ class RegisterService: def get_invitation_if_token_valid( cls, workspace_id: Optional[str], email: str, token: str ) -> Optional[dict[str, Any]]: - invitation_data = cls._get_invitation_by_token(token, workspace_id, email) + invitation_data = cls.get_invitation_by_token(token, workspace_id, email) if not invitation_data: return None @@ -1355,7 +1455,7 @@ class RegisterService: } @classmethod - def _get_invitation_by_token( + def get_invitation_by_token( cls, token: str, workspace_id: Optional[str] = None, email: Optional[str] = None ) -> Optional[dict[str, str]]: if workspace_id is not None and email is not None: diff --git a/api/services/advanced_prompt_template_service.py b/api/services/advanced_prompt_template_service.py index 6f0ab2546a..f2ffa3b170 100644 --- a/api/services/advanced_prompt_template_service.py +++ b/api/services/advanced_prompt_template_service.py @@ -32,14 +32,14 @@ class AdvancedPromptTemplateService: def get_common_prompt(cls, app_mode: str, model_mode: str, has_context: str): context_prompt = copy.deepcopy(CONTEXT) - if app_mode == AppMode.CHAT.value: + if app_mode == AppMode.CHAT: if model_mode == "completion": return cls.get_completion_prompt( copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt ) elif model_mode == "chat": return cls.get_chat_prompt(copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt) - elif app_mode == AppMode.COMPLETION.value: + elif app_mode == AppMode.COMPLETION: if model_mode == "completion": return cls.get_completion_prompt( copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt @@ -73,7 +73,7 @@ class AdvancedPromptTemplateService: def get_baichuan_prompt(cls, app_mode: str, model_mode: str, has_context: str): baichuan_context_prompt = copy.deepcopy(BAICHUAN_CONTEXT) - if app_mode == AppMode.CHAT.value: + if app_mode == AppMode.CHAT: if model_mode == "completion": return cls.get_completion_prompt( copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, baichuan_context_prompt @@ -82,7 +82,7 @@ class AdvancedPromptTemplateService: return cls.get_chat_prompt( copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt ) - elif app_mode == AppMode.COMPLETION.value: + elif app_mode == AppMode.COMPLETION: if model_mode == "completion": return cls.get_completion_prompt( copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG), diff --git a/api/services/annotation_service.py b/api/services/annotation_service.py index ba86a31240..34681ba111 100644 --- a/api/services/annotation_service.py +++ b/api/services/annotation_service.py @@ -263,11 +263,9 @@ class AppAnnotationService: db.session.delete(annotation) - annotation_hit_histories = ( - db.session.query(AppAnnotationHitHistory) - .where(AppAnnotationHitHistory.annotation_id == annotation_id) - .all() - ) + annotation_hit_histories = db.session.scalars( + select(AppAnnotationHitHistory).where(AppAnnotationHitHistory.annotation_id == annotation_id) + ).all() if annotation_hit_histories: for annotation_hit_history in annotation_hit_histories: db.session.delete(annotation_hit_history) @@ -349,7 +347,7 @@ class AppAnnotationService: try: # Skip the first row - df = pd.read_csv(file, dtype=str) + df = pd.read_csv(file.stream, dtype=str) result = [] for _, row in df.iterrows(): content = {"question": row.iloc[0], "answer": row.iloc[1]} @@ -463,15 +461,23 @@ class AppAnnotationService: annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first() if annotation_setting: collection_binding_detail = annotation_setting.collection_binding_detail - return { - "id": annotation_setting.id, - "enabled": True, - "score_threshold": annotation_setting.score_threshold, - "embedding_model": { - "embedding_provider_name": collection_binding_detail.provider_name, - "embedding_model_name": collection_binding_detail.model_name, - }, - } + if collection_binding_detail: + return { + "id": annotation_setting.id, + "enabled": True, + "score_threshold": annotation_setting.score_threshold, + "embedding_model": { + "embedding_provider_name": collection_binding_detail.provider_name, + "embedding_model_name": collection_binding_detail.model_name, + }, + } + else: + return { + "id": annotation_setting.id, + "enabled": True, + "score_threshold": annotation_setting.score_threshold, + "embedding_model": {}, + } return {"enabled": False} @classmethod @@ -506,15 +512,23 @@ class AppAnnotationService: collection_binding_detail = annotation_setting.collection_binding_detail - return { - "id": annotation_setting.id, - "enabled": True, - "score_threshold": annotation_setting.score_threshold, - "embedding_model": { - "embedding_provider_name": collection_binding_detail.provider_name, - "embedding_model_name": collection_binding_detail.model_name, - }, - } + if collection_binding_detail: + return { + "id": annotation_setting.id, + "enabled": True, + "score_threshold": annotation_setting.score_threshold, + "embedding_model": { + "embedding_provider_name": collection_binding_detail.provider_name, + "embedding_model_name": collection_binding_detail.model_name, + }, + } + else: + return { + "id": annotation_setting.id, + "enabled": True, + "score_threshold": annotation_setting.score_threshold, + "embedding_model": {}, + } @classmethod def clear_all_annotations(cls, app_id: str): diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index 2ed73ffec1..49ff28d191 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -99,17 +99,17 @@ def _check_version_compatibility(imported_version: str) -> ImportStatus: class PendingData(BaseModel): import_mode: str yaml_content: str - name: str | None - description: str | None - icon_type: str | None - icon: str | None - icon_background: str | None - app_id: str | None + name: str | None = None + description: str | None = None + icon_type: str | None = None + icon: str | None = None + icon_background: str | None = None + app_id: str | None = None class CheckDependenciesPendingData(BaseModel): dependencies: list[PluginDependency] - app_id: str | None + app_id: str | None = None class AppDslService: diff --git a/api/services/app_generate_service.py b/api/services/app_generate_service.py index e812fcc992..c1ef206c99 100644 --- a/api/services/app_generate_service.py +++ b/api/services/app_generate_service.py @@ -60,7 +60,7 @@ class AppGenerateService: request_id = RateLimit.gen_request_key() try: request_id = rate_limit.enter(request_id) - if app_model.mode == AppMode.COMPLETION.value: + if app_model.mode == AppMode.COMPLETION: return rate_limit.generate( CompletionAppGenerator.convert_to_event_stream( CompletionAppGenerator().generate( @@ -69,7 +69,7 @@ class AppGenerateService: ), request_id=request_id, ) - elif app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent: + elif app_model.mode == AppMode.AGENT_CHAT or app_model.is_agent: return rate_limit.generate( AgentChatAppGenerator.convert_to_event_stream( AgentChatAppGenerator().generate( @@ -78,7 +78,7 @@ class AppGenerateService: ), request_id, ) - elif app_model.mode == AppMode.CHAT.value: + elif app_model.mode == AppMode.CHAT: return rate_limit.generate( ChatAppGenerator.convert_to_event_stream( ChatAppGenerator().generate( @@ -87,7 +87,7 @@ class AppGenerateService: ), request_id=request_id, ) - elif app_model.mode == AppMode.ADVANCED_CHAT.value: + elif app_model.mode == AppMode.ADVANCED_CHAT: workflow_id = args.get("workflow_id") workflow = cls._get_workflow(app_model, invoke_from, workflow_id) return rate_limit.generate( @@ -103,7 +103,7 @@ class AppGenerateService: ), request_id=request_id, ) - elif app_model.mode == AppMode.WORKFLOW.value: + elif app_model.mode == AppMode.WORKFLOW: workflow_id = args.get("workflow_id") workflow = cls._get_workflow(app_model, invoke_from, workflow_id) return rate_limit.generate( @@ -155,14 +155,14 @@ class AppGenerateService: @classmethod def generate_single_iteration(cls, app_model: App, user: Account, node_id: str, args: Any, streaming: bool = True): - if app_model.mode == AppMode.ADVANCED_CHAT.value: + if app_model.mode == AppMode.ADVANCED_CHAT: workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER) return AdvancedChatAppGenerator.convert_to_event_stream( AdvancedChatAppGenerator().single_iteration_generate( app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming ) ) - elif app_model.mode == AppMode.WORKFLOW.value: + elif app_model.mode == AppMode.WORKFLOW: workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER) return AdvancedChatAppGenerator.convert_to_event_stream( WorkflowAppGenerator().single_iteration_generate( @@ -174,14 +174,14 @@ class AppGenerateService: @classmethod def generate_single_loop(cls, app_model: App, user: Account, node_id: str, args: Any, streaming: bool = True): - if app_model.mode == AppMode.ADVANCED_CHAT.value: + if app_model.mode == AppMode.ADVANCED_CHAT: workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER) return AdvancedChatAppGenerator.convert_to_event_stream( AdvancedChatAppGenerator().single_loop_generate( app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming ) ) - elif app_model.mode == AppMode.WORKFLOW.value: + elif app_model.mode == AppMode.WORKFLOW: workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER) return AdvancedChatAppGenerator.convert_to_event_stream( WorkflowAppGenerator().single_loop_generate( diff --git a/api/services/app_service.py b/api/services/app_service.py index 9b200a570d..c553ec8c19 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -40,15 +40,15 @@ class AppService: filters = [App.tenant_id == tenant_id, App.is_universal == False] if args["mode"] == "workflow": - filters.append(App.mode == AppMode.WORKFLOW.value) + filters.append(App.mode == AppMode.WORKFLOW) elif args["mode"] == "completion": - filters.append(App.mode == AppMode.COMPLETION.value) + filters.append(App.mode == AppMode.COMPLETION) elif args["mode"] == "chat": - filters.append(App.mode == AppMode.CHAT.value) + filters.append(App.mode == AppMode.CHAT) elif args["mode"] == "advanced-chat": - filters.append(App.mode == AppMode.ADVANCED_CHAT.value) + filters.append(App.mode == AppMode.ADVANCED_CHAT) elif args["mode"] == "agent-chat": - filters.append(App.mode == AppMode.AGENT_CHAT.value) + filters.append(App.mode == AppMode.AGENT_CHAT) if args.get("is_created_by_me", False): filters.append(App.created_by == user_id) @@ -171,7 +171,7 @@ class AppService: assert isinstance(current_user, Account) assert current_user.current_tenant_id is not None # get original app model config - if app.mode == AppMode.AGENT_CHAT.value or app.is_agent: + if app.mode == AppMode.AGENT_CHAT or app.is_agent: model_config = app.app_model_config if not model_config: return app diff --git a/api/services/audio_service.py b/api/services/audio_service.py index 9b1999d813..a7cd5e9487 100644 --- a/api/services/audio_service.py +++ b/api/services/audio_service.py @@ -31,7 +31,7 @@ logger = logging.getLogger(__name__) class AudioService: @classmethod def transcript_asr(cls, app_model: App, file: FileStorage, end_user: Optional[str] = None): - if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: + if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: workflow = app_model.workflow if workflow is None: raise ValueError("Speech to text is not enabled") @@ -88,7 +88,7 @@ class AudioService: def invoke_tts(text_content: str, app_model: App, voice: Optional[str] = None, is_draft: bool = False): with app.app_context(): if voice is None: - if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: + if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: if is_draft: workflow = WorkflowService().get_draft_workflow(app_model=app_model) else: diff --git a/api/services/auth/api_key_auth_service.py b/api/services/auth/api_key_auth_service.py index f6e960b413..055cf65816 100644 --- a/api/services/auth/api_key_auth_service.py +++ b/api/services/auth/api_key_auth_service.py @@ -1,5 +1,7 @@ import json +from sqlalchemy import select + from core.helper import encrypter from extensions.ext_database import db from models.source import DataSourceApiKeyAuthBinding @@ -9,11 +11,11 @@ from services.auth.api_key_auth_factory import ApiKeyAuthFactory class ApiKeyAuthService: @staticmethod def get_provider_auth_list(tenant_id: str): - data_source_api_key_bindings = ( - db.session.query(DataSourceApiKeyAuthBinding) - .where(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.disabled.is_(False)) - .all() - ) + data_source_api_key_bindings = db.session.scalars( + select(DataSourceApiKeyAuthBinding).where( + DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.disabled.is_(False) + ) + ).all() return data_source_api_key_bindings @staticmethod diff --git a/api/services/clear_free_plan_tenant_expired_logs.py b/api/services/clear_free_plan_tenant_expired_logs.py index 2f1b63664f..f8f89d7428 100644 --- a/api/services/clear_free_plan_tenant_expired_logs.py +++ b/api/services/clear_free_plan_tenant_expired_logs.py @@ -6,6 +6,7 @@ from concurrent.futures import ThreadPoolExecutor import click from flask import Flask, current_app +from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker from configs import dify_config @@ -115,7 +116,7 @@ class ClearFreePlanTenantExpiredLogs: @classmethod def process_tenant(cls, flask_app: Flask, tenant_id: str, days: int, batch: int): with flask_app.app_context(): - apps = db.session.query(App).where(App.tenant_id == tenant_id).all() + apps = db.session.scalars(select(App).where(App.tenant_id == tenant_id)).all() app_ids = [app.id for app in apps] while True: with Session(db.engine).no_autoflush as session: @@ -407,6 +408,7 @@ class ClearFreePlanTenantExpiredLogs: datetime.timedelta(hours=1), ] + tenant_count = 0 for test_interval in test_intervals: tenant_count = ( session.query(Tenant.id) diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py index d017ce54ab..a29204aabf 100644 --- a/api/services/conversation_service.py +++ b/api/services/conversation_service.py @@ -222,8 +222,8 @@ class ConversationService: # Filter for variables created after the last_id stmt = stmt.where(ConversationVariable.created_at > last_variable.created_at) - # Apply limit to query - query_stmt = stmt.limit(limit) # Get one extra to check if there are more + # Apply limit to query: fetch one extra row to determine has_more + query_stmt = stmt.limit(limit + 1) rows = session.scalars(query_stmt).all() has_more = False diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 2b151f9a8e..8c8b368d42 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -6,6 +6,7 @@ import secrets import time import uuid from collections import Counter +from collections.abc import Sequence from typing import Any, Literal, Optional import sqlalchemy as sa @@ -134,11 +135,14 @@ class DatasetService: # Check if tag_ids is not empty to avoid WHERE false condition if tag_ids and len(tag_ids) > 0: - target_ids = TagService.get_target_ids_by_tag_ids( - "knowledge", - tenant_id, # ty: ignore [invalid-argument-type] - tag_ids, - ) + if tenant_id is not None: + target_ids = TagService.get_target_ids_by_tag_ids( + "knowledge", + tenant_id, + tag_ids, + ) + else: + target_ids = [] if target_ids and len(target_ids) > 0: query = query.where(Dataset.id.in_(target_ids)) else: @@ -217,7 +221,7 @@ class DatasetService: and retrieval_model.reranking_model.reranking_model_name ): # check if reranking model setting is valid - DatasetService.check_embedding_model_setting( + DatasetService.check_reranking_model_setting( tenant_id, retrieval_model.reranking_model.reranking_provider_name, retrieval_model.reranking_model.reranking_model_name, @@ -738,14 +742,12 @@ class DatasetService: } # get recent 30 days auto disable logs start_date = datetime.datetime.now() - datetime.timedelta(days=30) - dataset_auto_disable_logs = ( - db.session.query(DatasetAutoDisableLog) - .where( + dataset_auto_disable_logs = db.session.scalars( + select(DatasetAutoDisableLog).where( DatasetAutoDisableLog.dataset_id == dataset_id, DatasetAutoDisableLog.created_at >= start_date, ) - .all() - ) + ).all() if dataset_auto_disable_logs: return { "document_ids": [log.document_id for log in dataset_auto_disable_logs], @@ -882,69 +884,58 @@ class DocumentService: return document @staticmethod - def get_document_by_ids(document_ids: list[str]) -> list[Document]: - documents = ( - db.session.query(Document) - .where( + def get_document_by_ids(document_ids: list[str]) -> Sequence[Document]: + documents = db.session.scalars( + select(Document).where( Document.id.in_(document_ids), Document.enabled == True, Document.indexing_status == "completed", Document.archived == False, ) - .all() - ) + ).all() return documents @staticmethod - def get_document_by_dataset_id(dataset_id: str) -> list[Document]: - documents = ( - db.session.query(Document) - .where( + def get_document_by_dataset_id(dataset_id: str) -> Sequence[Document]: + documents = db.session.scalars( + select(Document).where( Document.dataset_id == dataset_id, Document.enabled == True, ) - .all() - ) + ).all() return documents @staticmethod - def get_working_documents_by_dataset_id(dataset_id: str) -> list[Document]: - documents = ( - db.session.query(Document) - .where( + def get_working_documents_by_dataset_id(dataset_id: str) -> Sequence[Document]: + documents = db.session.scalars( + select(Document).where( Document.dataset_id == dataset_id, Document.enabled == True, Document.indexing_status == "completed", Document.archived == False, ) - .all() - ) + ).all() return documents @staticmethod - def get_error_documents_by_dataset_id(dataset_id: str) -> list[Document]: - documents = ( - db.session.query(Document) - .where(Document.dataset_id == dataset_id, Document.indexing_status.in_(["error", "paused"])) - .all() - ) + def get_error_documents_by_dataset_id(dataset_id: str) -> Sequence[Document]: + documents = db.session.scalars( + select(Document).where(Document.dataset_id == dataset_id, Document.indexing_status.in_(["error", "paused"])) + ).all() return documents @staticmethod - def get_batch_documents(dataset_id: str, batch: str) -> list[Document]: + def get_batch_documents(dataset_id: str, batch: str) -> Sequence[Document]: assert isinstance(current_user, Account) - - documents = ( - db.session.query(Document) - .where( + documents = db.session.scalars( + select(Document).where( Document.batch == batch, Document.dataset_id == dataset_id, Document.tenant_id == current_user.current_tenant_id, ) - .all() - ) + ).all() return documents @@ -981,13 +972,14 @@ class DocumentService: # Check if document_ids is not empty to avoid WHERE false condition if not document_ids or len(document_ids) == 0: return - documents = db.session.query(Document).where(Document.id.in_(document_ids)).all() + documents = db.session.scalars(select(Document).where(Document.id.in_(document_ids))).all() file_ids = [ document.data_source_info_dict["upload_file_id"] for document in documents if document.data_source_type == "upload_file" and document.data_source_info_dict ] - batch_clean_document_task.delay(document_ids, dataset.id, dataset.doc_form, file_ids) + if dataset.doc_form is not None: + batch_clean_document_task.delay(document_ids, dataset.id, dataset.doc_form, file_ids) for document in documents: db.session.delete(document) @@ -1012,7 +1004,7 @@ class DocumentService: if dataset.built_in_field_enabled: if document.doc_metadata: doc_metadata = copy.deepcopy(document.doc_metadata) - doc_metadata[BuiltInField.document_name.value] = name + doc_metadata[BuiltInField.document_name] = name document.doc_metadata = doc_metadata document.name = name @@ -2373,7 +2365,22 @@ class SegmentService: if segment.enabled: # send delete segment index task redis_client.setex(indexing_cache_key, 600, 1) - delete_segment_from_index_task.delay([segment.index_node_id], dataset.id, document.id) + + # Get child chunk IDs before parent segment is deleted + child_node_ids = [] + if segment.index_node_id: + child_chunks = ( + db.session.query(ChildChunk.index_node_id) + .where( + ChildChunk.segment_id == segment.id, + ChildChunk.dataset_id == dataset.id, + ) + .all() + ) + child_node_ids = [chunk[0] for chunk in child_chunks if chunk[0]] + + delete_segment_from_index_task.delay([segment.index_node_id], dataset.id, document.id, child_node_ids) + db.session.delete(segment) # update document word count assert document.word_count is not None @@ -2383,9 +2390,13 @@ class SegmentService: @classmethod def delete_segments(cls, segment_ids: list, document: Document, dataset: Dataset): - assert isinstance(current_user, Account) - segments = ( - db.session.query(DocumentSegment.index_node_id, DocumentSegment.word_count) + assert current_user is not None + # Check if segment_ids is not empty to avoid WHERE false condition + if not segment_ids or len(segment_ids) == 0: + return + segments_info = ( + db.session.query(DocumentSegment) + .with_entities(DocumentSegment.index_node_id, DocumentSegment.id, DocumentSegment.word_count) .where( DocumentSegment.id.in_(segment_ids), DocumentSegment.dataset_id == dataset.id, @@ -2395,18 +2406,36 @@ class SegmentService: .all() ) - if not segments: + if not segments_info: return - index_node_ids = [seg.index_node_id for seg in segments] - total_words = sum(seg.word_count for seg in segments) + index_node_ids = [info[0] for info in segments_info] + segment_db_ids = [info[1] for info in segments_info] + total_words = sum(info[2] for info in segments_info if info[2] is not None) + + # Get child chunk IDs before parent segments are deleted + child_node_ids = [] + if index_node_ids: + child_chunks = ( + db.session.query(ChildChunk.index_node_id) + .where( + ChildChunk.segment_id.in_(segment_db_ids), + ChildChunk.dataset_id == dataset.id, + ) + .all() + ) + child_node_ids = [chunk[0] for chunk in child_chunks if chunk[0]] + + # Start async cleanup with both parent and child node IDs + if index_node_ids or child_node_ids: + delete_segment_from_index_task.delay(index_node_ids, dataset.id, document.id, child_node_ids) document.word_count = ( document.word_count - total_words if document.word_count and document.word_count > total_words else 0 ) db.session.add(document) - delete_segment_from_index_task.delay(index_node_ids, dataset.id, document.id) + # Delete database records db.session.query(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)).delete() db.session.commit() @@ -2420,16 +2449,14 @@ class SegmentService: if not segment_ids or len(segment_ids) == 0: return if action == "enable": - segments = ( - db.session.query(DocumentSegment) - .where( + segments = db.session.scalars( + select(DocumentSegment).where( DocumentSegment.id.in_(segment_ids), DocumentSegment.dataset_id == dataset.id, DocumentSegment.document_id == document.id, DocumentSegment.enabled == False, ) - .all() - ) + ).all() if not segments: return real_deal_segment_ids = [] @@ -2447,16 +2474,14 @@ class SegmentService: enable_segments_to_index_task.delay(real_deal_segment_ids, dataset.id, document.id) elif action == "disable": - segments = ( - db.session.query(DocumentSegment) - .where( + segments = db.session.scalars( + select(DocumentSegment).where( DocumentSegment.id.in_(segment_ids), DocumentSegment.dataset_id == dataset.id, DocumentSegment.document_id == document.id, DocumentSegment.enabled == True, ) - .all() - ) + ).all() if not segments: return real_deal_segment_ids = [] @@ -2528,16 +2553,13 @@ class SegmentService: dataset: Dataset, ) -> list[ChildChunk]: assert isinstance(current_user, Account) - - child_chunks = ( - db.session.query(ChildChunk) - .where( + child_chunks = db.session.scalars( + select(ChildChunk).where( ChildChunk.dataset_id == dataset.id, ChildChunk.document_id == document.id, ChildChunk.segment_id == segment.id, ) - .all() - ) + ).all() child_chunks_map = {chunk.id: chunk for chunk in child_chunks} new_child_chunks, update_child_chunks, delete_child_chunks, new_child_chunks_args = [], [], [], [] @@ -2688,56 +2710,6 @@ class SegmentService: return paginated_segments.items, paginated_segments.total - @classmethod - def update_segment_by_id( - cls, tenant_id: str, dataset_id: str, document_id: str, segment_id: str, segment_data: dict, user_id: str - ) -> tuple[DocumentSegment, Document]: - """Update a segment by its ID with validation and checks.""" - # check dataset - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() - if not dataset: - raise NotFound("Dataset not found.") - - # check user's model setting - DatasetService.check_dataset_model_setting(dataset) - - # check document - document = DocumentService.get_document(dataset_id, document_id) - if not document: - raise NotFound("Document not found.") - - # check embedding model setting if high quality - if dataset.indexing_technique == "high_quality": - try: - model_manager = ModelManager() - model_manager.get_model_instance( - tenant_id=user_id, - provider=dataset.embedding_model_provider, - model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model, - ) - except LLMBadRequestError: - raise ValueError( - "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider." - ) - except ProviderTokenNotInitError as ex: - raise ValueError(ex.description) - - # check segment - segment = ( - db.session.query(DocumentSegment) - .where(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == tenant_id) - .first() - ) - if not segment: - raise NotFound("Segment not found.") - - # validate and update segment - cls.segment_create_args_validate(segment_data, document) - updated_segment = cls.update_segment(SegmentUpdateArgs(**segment_data), segment, document, dataset) - - return updated_segment, document - @classmethod def get_segment_by_id(cls, segment_id: str, tenant_id: str) -> Optional[DocumentSegment]: """Get a segment by its ID.""" @@ -2797,19 +2769,13 @@ class DatasetCollectionBindingService: class DatasetPermissionService: @classmethod def get_dataset_partial_member_list(cls, dataset_id): - user_list_query = ( - db.session.query( + user_list_query = db.session.scalars( + select( DatasetPermission.account_id, - ) - .where(DatasetPermission.dataset_id == dataset_id) - .all() - ) + ).where(DatasetPermission.dataset_id == dataset_id) + ).all() - user_list = [] - for user in user_list_query: - user_list.append(user.account_id) - - return user_list + return user_list_query @classmethod def update_partial_member_list(cls, tenant_id, dataset_id, user_list): diff --git a/api/services/enterprise/plugin_manager_service.py b/api/services/enterprise/plugin_manager_service.py index 5b6d06d3e6..93a23f6d08 100644 --- a/api/services/enterprise/plugin_manager_service.py +++ b/api/services/enterprise/plugin_manager_service.py @@ -6,6 +6,7 @@ from pydantic import BaseModel from services.enterprise.base import EnterprisePluginManagerRequest from services.errors.base import BaseServiceError +logger = logging.getLogger(__name__) class PluginCredentialType(enum.Enum): MODEL = 0 # must be 0 for API contract compatibility @@ -47,6 +48,9 @@ class PluginManagerService: if not ret.get("result", False): raise CredentialPolicyViolationError("Credentials not available: Please use ENTERPRISE global credentials") - logging.debug( - f"Credential policy compliance checked for {body.provider} with credential {body.dify_credential_id}, result: {ret.get('result', False)}" + logger.debug( + "Credential policy compliance checked for %s with credential %s, result: %s", + body.provider, + body.dify_credential_id, + ret.get("result", False), ) diff --git a/api/services/external_knowledge_service.py b/api/services/external_knowledge_service.py index 3262a00663..3911b763b6 100644 --- a/api/services/external_knowledge_service.py +++ b/api/services/external_knowledge_service.py @@ -181,7 +181,7 @@ class ExternalDatasetService: do http request depending on api bundle """ - kwargs = { + kwargs: dict[str, Any] = { "url": settings.url, "headers": settings.headers, "follow_redirects": True, diff --git a/api/services/file_service.py b/api/services/file_service.py index 8a4655d25e..364a872a91 100644 --- a/api/services/file_service.py +++ b/api/services/file_service.py @@ -1,7 +1,7 @@ import hashlib import os import uuid -from typing import Any, Literal, Union +from typing import Literal, Union from werkzeug.exceptions import NotFound @@ -35,7 +35,7 @@ class FileService: filename: str, content: bytes, mimetype: str, - user: Union[Account, EndUser, Any], + user: Union[Account, EndUser], source: Literal["datasets"] | None = None, source_url: str = "", ) -> UploadFile: diff --git a/api/services/message_service.py b/api/services/message_service.py index 13c8e948ca..7c1f74e488 100644 --- a/api/services/message_service.py +++ b/api/services/message_service.py @@ -229,7 +229,7 @@ class MessageService: model_manager = ModelManager() - if app_model.mode == AppMode.ADVANCED_CHAT.value: + if app_model.mode == AppMode.ADVANCED_CHAT: workflow_service = WorkflowService() if invoke_from == InvokeFrom.DEBUGGER: workflow = workflow_service.get_draft_workflow(app_model=app_model) diff --git a/api/services/metadata_service.py b/api/services/metadata_service.py index 05fa5a95bc..208ecdb79e 100644 --- a/api/services/metadata_service.py +++ b/api/services/metadata_service.py @@ -131,11 +131,11 @@ class MetadataService: @staticmethod def get_built_in_fields(): return [ - {"name": BuiltInField.document_name.value, "type": "string"}, - {"name": BuiltInField.uploader.value, "type": "string"}, - {"name": BuiltInField.upload_date.value, "type": "time"}, - {"name": BuiltInField.last_update_date.value, "type": "time"}, - {"name": BuiltInField.source.value, "type": "string"}, + {"name": BuiltInField.document_name, "type": "string"}, + {"name": BuiltInField.uploader, "type": "string"}, + {"name": BuiltInField.upload_date, "type": "time"}, + {"name": BuiltInField.last_update_date, "type": "time"}, + {"name": BuiltInField.source, "type": "string"}, ] @staticmethod @@ -153,11 +153,11 @@ class MetadataService: doc_metadata = {} else: doc_metadata = copy.deepcopy(document.doc_metadata) - doc_metadata[BuiltInField.document_name.value] = document.name - doc_metadata[BuiltInField.uploader.value] = document.uploader - doc_metadata[BuiltInField.upload_date.value] = document.upload_date.timestamp() - doc_metadata[BuiltInField.last_update_date.value] = document.last_update_date.timestamp() - doc_metadata[BuiltInField.source.value] = MetadataDataSource[document.data_source_type].value + doc_metadata[BuiltInField.document_name] = document.name + doc_metadata[BuiltInField.uploader] = document.uploader + doc_metadata[BuiltInField.upload_date] = document.upload_date.timestamp() + doc_metadata[BuiltInField.last_update_date] = document.last_update_date.timestamp() + doc_metadata[BuiltInField.source] = MetadataDataSource[document.data_source_type] document.doc_metadata = doc_metadata db.session.add(document) dataset.built_in_field_enabled = True @@ -183,11 +183,11 @@ class MetadataService: doc_metadata = {} else: doc_metadata = copy.deepcopy(document.doc_metadata) - doc_metadata.pop(BuiltInField.document_name.value, None) - doc_metadata.pop(BuiltInField.uploader.value, None) - doc_metadata.pop(BuiltInField.upload_date.value, None) - doc_metadata.pop(BuiltInField.last_update_date.value, None) - doc_metadata.pop(BuiltInField.source.value, None) + doc_metadata.pop(BuiltInField.document_name, None) + doc_metadata.pop(BuiltInField.uploader, None) + doc_metadata.pop(BuiltInField.upload_date, None) + doc_metadata.pop(BuiltInField.last_update_date, None) + doc_metadata.pop(BuiltInField.source, None) document.doc_metadata = doc_metadata db.session.add(document) document_ids.append(document.id) @@ -211,11 +211,11 @@ class MetadataService: for metadata_value in operation.metadata_list: doc_metadata[metadata_value.name] = metadata_value.value if dataset.built_in_field_enabled: - doc_metadata[BuiltInField.document_name.value] = document.name - doc_metadata[BuiltInField.uploader.value] = document.uploader - doc_metadata[BuiltInField.upload_date.value] = document.upload_date.timestamp() - doc_metadata[BuiltInField.last_update_date.value] = document.last_update_date.timestamp() - doc_metadata[BuiltInField.source.value] = MetadataDataSource[document.data_source_type].value + doc_metadata[BuiltInField.document_name] = document.name + doc_metadata[BuiltInField.uploader] = document.uploader + doc_metadata[BuiltInField.upload_date] = document.upload_date.timestamp() + doc_metadata[BuiltInField.last_update_date] = document.last_update_date.timestamp() + doc_metadata[BuiltInField.source] = MetadataDataSource[document.data_source_type] document.doc_metadata = doc_metadata db.session.add(document) db.session.commit() diff --git a/api/services/model_load_balancing_service.py b/api/services/model_load_balancing_service.py index c638087f63..33d7dacba0 100644 --- a/api/services/model_load_balancing_service.py +++ b/api/services/model_load_balancing_service.py @@ -3,7 +3,7 @@ import logging from json import JSONDecodeError from typing import Optional, Union -from sqlalchemy import or_ +from sqlalchemy import or_, select from constants import HIDDEN_VALUE from core.entities.provider_configuration import ProviderConfiguration @@ -165,7 +165,7 @@ class ModelLoadBalancingService: try: if load_balancing_config.encrypted_config: - credentials = json.loads(load_balancing_config.encrypted_config) + credentials: dict[str, object] = json.loads(load_balancing_config.encrypted_config) else: credentials = {} except JSONDecodeError: @@ -180,11 +180,13 @@ class ModelLoadBalancingService: for variable in credential_secret_variables: if variable in credentials: try: - credentials[variable] = encrypter.decrypt_token_with_decoding( - credentials.get(variable), # ty: ignore [invalid-argument-type] - decoding_rsa_key, - decoding_cipher_rsa, - ) + token_value = credentials.get(variable) + if isinstance(token_value, str): + credentials[variable] = encrypter.decrypt_token_with_decoding( + token_value, + decoding_rsa_key, + decoding_cipher_rsa, + ) except ValueError: pass @@ -320,16 +322,14 @@ class ModelLoadBalancingService: if not isinstance(configs, list): raise ValueError("Invalid load balancing configs") - current_load_balancing_configs = ( - db.session.query(LoadBalancingModelConfig) - .where( + current_load_balancing_configs = db.session.scalars( + select(LoadBalancingModelConfig).where( LoadBalancingModelConfig.tenant_id == tenant_id, LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider, LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(), LoadBalancingModelConfig.model_name == model, ) - .all() - ) + ).all() # id as key, config as value current_load_balancing_configs_dict = {config.id: config for config in current_load_balancing_configs} @@ -345,8 +345,9 @@ class ModelLoadBalancingService: credential_id = config.get("credential_id") enabled = config.get("enabled") + credential_record: ProviderCredential | ProviderModelCredential | None = None + if credential_id: - credential_record: ProviderCredential | ProviderModelCredential | None = None if config_from == "predefined-model": credential_record = ( db.session.query(ProviderCredential) diff --git a/api/services/plugin/github_service.py b/api/services/plugin/github_service.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/api/services/plugin/plugin_migration.py b/api/services/plugin/plugin_migration.py index 8dbf117fd3..34f10ae407 100644 --- a/api/services/plugin/plugin_migration.py +++ b/api/services/plugin/plugin_migration.py @@ -99,6 +99,7 @@ class PluginMigration: datetime.timedelta(hours=1), ] + tenant_count = 0 for test_interval in test_intervals: tenant_count = ( session.query(Tenant.id) @@ -255,7 +256,7 @@ class PluginMigration: return [] agent_app_model_config_ids = [ - app.app_model_config_id for app in apps if app.is_agent or app.mode == AppMode.AGENT_CHAT.value + app.app_model_config_id for app in apps if app.is_agent or app.mode == AppMode.AGENT_CHAT ] rs = session.query(AppModelConfig).where(AppModelConfig.id.in_(agent_app_model_config_ids)).all() diff --git a/api/services/recommend_app/database/database_retrieval.py b/api/services/recommend_app/database/database_retrieval.py index e19f53f120..a9733e0826 100644 --- a/api/services/recommend_app/database/database_retrieval.py +++ b/api/services/recommend_app/database/database_retrieval.py @@ -1,5 +1,7 @@ from typing import Optional +from sqlalchemy import select + from constants.languages import languages from extensions.ext_database import db from models.model import App, RecommendedApp @@ -31,18 +33,14 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase): :param language: language :return: """ - recommended_apps = ( - db.session.query(RecommendedApp) - .where(RecommendedApp.is_listed == True, RecommendedApp.language == language) - .all() - ) + recommended_apps = db.session.scalars( + select(RecommendedApp).where(RecommendedApp.is_listed == True, RecommendedApp.language == language) + ).all() if len(recommended_apps) == 0: - recommended_apps = ( - db.session.query(RecommendedApp) - .where(RecommendedApp.is_listed == True, RecommendedApp.language == languages[0]) - .all() - ) + recommended_apps = db.session.scalars( + select(RecommendedApp).where(RecommendedApp.is_listed == True, RecommendedApp.language == languages[0]) + ).all() categories = set() recommended_apps_result = [] diff --git a/api/services/tag_service.py b/api/services/tag_service.py index a16bdb46cd..dd67b19966 100644 --- a/api/services/tag_service.py +++ b/api/services/tag_service.py @@ -2,7 +2,7 @@ import uuid from typing import Optional from flask_login import current_user -from sqlalchemy import func +from sqlalchemy import func, select from werkzeug.exceptions import NotFound from extensions.ext_database import db @@ -29,35 +29,30 @@ class TagService: # Check if tag_ids is not empty to avoid WHERE false condition if not tag_ids or len(tag_ids) == 0: return [] - tags = ( - db.session.query(Tag) - .where(Tag.id.in_(tag_ids), Tag.tenant_id == current_tenant_id, Tag.type == tag_type) - .all() - ) + tags = db.session.scalars( + select(Tag).where(Tag.id.in_(tag_ids), Tag.tenant_id == current_tenant_id, Tag.type == tag_type) + ).all() if not tags: return [] tag_ids = [tag.id for tag in tags] # Check if tag_ids is not empty to avoid WHERE false condition if not tag_ids or len(tag_ids) == 0: return [] - tag_bindings = ( - db.session.query(TagBinding.target_id) - .where(TagBinding.tag_id.in_(tag_ids), TagBinding.tenant_id == current_tenant_id) - .all() - ) - if not tag_bindings: - return [] - results = [tag_binding.target_id for tag_binding in tag_bindings] - return results + tag_bindings = db.session.scalars( + select(TagBinding.target_id).where( + TagBinding.tag_id.in_(tag_ids), TagBinding.tenant_id == current_tenant_id + ) + ).all() + return tag_bindings @staticmethod def get_tag_by_tag_name(tag_type: str, current_tenant_id: str, tag_name: str): if not tag_type or not tag_name: return [] - tags = ( - db.session.query(Tag) - .where(Tag.name == tag_name, Tag.tenant_id == current_tenant_id, Tag.type == tag_type) - .all() + tags = list( + db.session.scalars( + select(Tag).where(Tag.name == tag_name, Tag.tenant_id == current_tenant_id, Tag.type == tag_type) + ).all() ) if not tags: return [] @@ -117,7 +112,7 @@ class TagService: raise NotFound("Tag not found") db.session.delete(tag) # delete tag binding - tag_bindings = db.session.query(TagBinding).where(TagBinding.tag_id == tag_id).all() + tag_bindings = db.session.scalars(select(TagBinding).where(TagBinding.tag_id == tag_id)).all() if tag_bindings: for tag_binding in tag_bindings: db.session.delete(tag_binding) diff --git a/api/services/tools/api_tools_manage_service.py b/api/services/tools/api_tools_manage_service.py index 78e587abee..f86d7e51bf 100644 --- a/api/services/tools/api_tools_manage_service.py +++ b/api/services/tools/api_tools_manage_service.py @@ -4,6 +4,7 @@ from collections.abc import Mapping from typing import Any, cast from httpx import get +from sqlalchemy import select from core.entities.provider_entities import ProviderConfig from core.model_runtime.utils.encoders import jsonable_encoder @@ -443,9 +444,7 @@ class ApiToolManageService: list api tools """ # get all api providers - db_providers: list[ApiToolProvider] = ( - db.session.query(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id).all() or [] - ) + db_providers = db.session.scalars(select(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id)).all() result: list[ToolProviderApiEntity] = [] diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index bce389b949..cb31111485 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -223,8 +223,8 @@ class BuiltinToolManageService: """ add builtin tool provider """ - try: - with Session(db.engine) as session: + with Session(db.engine) as session: + try: lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider}" with redis_client.lock(lock, timeout=20): provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) @@ -285,9 +285,9 @@ class BuiltinToolManageService: session.add(db_provider) session.commit() - except Exception as e: - session.rollback() - raise ValueError(str(e)) + except Exception as e: + session.rollback() + raise ValueError(str(e)) return {"result": "success"} @staticmethod diff --git a/api/services/tools/workflow_tools_manage_service.py b/api/services/tools/workflow_tools_manage_service.py index 2f8a91ed82..2449536d5c 100644 --- a/api/services/tools/workflow_tools_manage_service.py +++ b/api/services/tools/workflow_tools_manage_service.py @@ -3,7 +3,7 @@ from collections.abc import Mapping from datetime import datetime from typing import Any -from sqlalchemy import or_ +from sqlalchemy import or_, select from core.model_runtime.utils.encoders import jsonable_encoder from core.tools.__base.tool_provider import ToolProviderController @@ -186,7 +186,9 @@ class WorkflowToolManageService: :param tenant_id: the tenant id :return: the list of tools """ - db_tools = db.session.query(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id).all() + db_tools = db.session.scalars( + select(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id) + ).all() tools: list[WorkflowToolProviderController] = [] for provider in db_tools: diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index 2994856b54..376f96c03a 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -18,6 +18,7 @@ from core.helper import encrypter from core.model_runtime.entities.llm_entities import LLMMode from core.model_runtime.utils.encoders import jsonable_encoder from core.prompt.simple_prompt_transform import SimplePromptTransform +from core.prompt.utils.prompt_template_parser import PromptTemplateParser from core.workflow.nodes import NodeType from events.app_event import app_was_created from extensions.ext_database import db @@ -64,7 +65,7 @@ class WorkflowConverter: new_app = App() new_app.tenant_id = app_model.tenant_id new_app.name = name or app_model.name + "(workflow)" - new_app.mode = AppMode.ADVANCED_CHAT.value if app_model.mode == AppMode.CHAT.value else AppMode.WORKFLOW.value + new_app.mode = AppMode.ADVANCED_CHAT if app_model.mode == AppMode.CHAT else AppMode.WORKFLOW new_app.icon_type = icon_type or app_model.icon_type new_app.icon = icon or app_model.icon new_app.icon_background = icon_background or app_model.icon_background @@ -202,7 +203,7 @@ class WorkflowConverter: app_mode_enum = AppMode.value_of(app_model.mode) app_config: EasyUIBasedAppConfig if app_mode_enum == AppMode.AGENT_CHAT or app_model.is_agent: - app_model.mode = AppMode.AGENT_CHAT.value + app_model.mode = AppMode.AGENT_CHAT app_config = AgentChatAppConfigManager.get_app_config( app_model=app_model, app_model_config=app_model_config ) @@ -278,7 +279,7 @@ class WorkflowConverter: "app_id": app_model.id, "tool_variable": tool_variable, "inputs": inputs, - "query": "{{#sys.query#}}" if app_model.mode == AppMode.CHAT.value else "", + "query": "{{#sys.query#}}" if app_model.mode == AppMode.CHAT else "", }, } @@ -420,7 +421,11 @@ class WorkflowConverter: query_in_prompt=False, ) - template = prompt_template_config["prompt_template"].template + prompt_template_obj = prompt_template_config["prompt_template"] + if not isinstance(prompt_template_obj, PromptTemplateParser): + raise TypeError(f"Expected PromptTemplateParser, got {type(prompt_template_obj)}") + + template = prompt_template_obj.template if not template: prompts = [] else: @@ -457,7 +462,11 @@ class WorkflowConverter: query_in_prompt=False, ) - template = prompt_template_config["prompt_template"].template + prompt_template_obj = prompt_template_config["prompt_template"] + if not isinstance(prompt_template_obj, PromptTemplateParser): + raise TypeError(f"Expected PromptTemplateParser, got {type(prompt_template_obj)}") + + template = prompt_template_obj.template template = self._replace_template_variables( template=template, variables=start_node["data"]["variables"], @@ -467,6 +476,9 @@ class WorkflowConverter: prompts = {"text": template} prompt_rules = prompt_template_config["prompt_rules"] + if not isinstance(prompt_rules, dict): + raise TypeError(f"Expected dict for prompt_rules, got {type(prompt_rules)}") + role_prefix = { "user": prompt_rules.get("human_prefix", "Human"), "assistant": prompt_rules.get("assistant_prefix", "Assistant"), @@ -606,7 +618,7 @@ class WorkflowConverter: :param app_model: App instance :return: AppMode """ - if app_model.mode == AppMode.COMPLETION.value: + if app_model.mode == AppMode.COMPLETION: return AppMode.WORKFLOW else: return AppMode.ADVANCED_CHAT diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 0a14007349..e680de3502 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -769,10 +769,10 @@ class WorkflowService: ) error = node_run_result.error if not run_succeeded else None except WorkflowNodeRunFailedError as e: - node = e._node + node = e.node run_succeeded = False node_run_result = None - error = e._error + error = e.error # Create a NodeExecution domain model node_execution = WorkflowNodeExecution( @@ -828,7 +828,7 @@ class WorkflowService: # chatbot convert to workflow mode workflow_converter = WorkflowConverter() - if app_model.mode not in {AppMode.CHAT.value, AppMode.COMPLETION.value}: + if app_model.mode not in {AppMode.CHAT, AppMode.COMPLETION}: raise ValueError(f"Current App mode: {app_model.mode} is not supported convert to workflow.") # convert to workflow @@ -844,11 +844,11 @@ class WorkflowService: return new_app def validate_features_structure(self, app_model: App, features: dict): - if app_model.mode == AppMode.ADVANCED_CHAT.value: + if app_model.mode == AppMode.ADVANCED_CHAT: return AdvancedChatAppConfigManager.config_validate( tenant_id=app_model.tenant_id, config=features, only_structure_validate=True ) - elif app_model.mode == AppMode.WORKFLOW.value: + elif app_model.mode == AppMode.WORKFLOW: return WorkflowAppConfigManager.config_validate( tenant_id=app_model.tenant_id, config=features, only_structure_validate=True ) diff --git a/api/services/workspace_service.py b/api/services/workspace_service.py index d4fc68a084..292ac6e008 100644 --- a/api/services/workspace_service.py +++ b/api/services/workspace_service.py @@ -12,7 +12,7 @@ class WorkspaceService: def get_tenant_info(cls, tenant: Tenant): if not tenant: return None - tenant_info = { + tenant_info: dict[str, object] = { "id": tenant.id, "name": tenant.name, "plan": tenant.plan, diff --git a/api/tasks/annotation/enable_annotation_reply_task.py b/api/tasks/annotation/enable_annotation_reply_task.py index 3498e08426..cdc07c77a8 100644 --- a/api/tasks/annotation/enable_annotation_reply_task.py +++ b/api/tasks/annotation/enable_annotation_reply_task.py @@ -3,6 +3,7 @@ import time import click from celery import shared_task +from sqlalchemy import select from core.rag.datasource.vdb.vector_factory import Vector from core.rag.models.document import Document @@ -39,7 +40,7 @@ def enable_annotation_reply_task( db.session.close() return - annotations = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app_id).all() + annotations = db.session.scalars(select(MessageAnnotation).where(MessageAnnotation.app_id == app_id)).all() enable_app_annotation_key = f"enable_app_annotation_{str(app_id)}" enable_app_annotation_job_key = f"enable_app_annotation_job_{str(job_id)}" diff --git a/api/tasks/batch_clean_document_task.py b/api/tasks/batch_clean_document_task.py index 08e2c4a556..212f8c3c6a 100644 --- a/api/tasks/batch_clean_document_task.py +++ b/api/tasks/batch_clean_document_task.py @@ -3,6 +3,7 @@ import time import click from celery import shared_task +from sqlalchemy import select from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.tools.utils.web_reader_tool import get_image_upload_file_ids @@ -34,7 +35,9 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form if not dataset: raise Exception("Document has no dataset") - segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids)).all() + segments = db.session.scalars( + select(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids)) + ).all() # check segment is exist if segments: index_node_ids = [segment.index_node_id for segment in segments] @@ -59,7 +62,7 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form db.session.commit() if file_ids: - files = db.session.query(UploadFile).where(UploadFile.id.in_(file_ids)).all() + files = db.session.scalars(select(UploadFile).where(UploadFile.id.in_(file_ids))).all() for file in files: try: storage.delete(file.key) diff --git a/api/tasks/clean_dataset_task.py b/api/tasks/clean_dataset_task.py index 9d12b6a589..5f2a355d16 100644 --- a/api/tasks/clean_dataset_task.py +++ b/api/tasks/clean_dataset_task.py @@ -3,6 +3,7 @@ import time import click from celery import shared_task +from sqlalchemy import select from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.tools.utils.web_reader_tool import get_image_upload_file_ids @@ -55,8 +56,8 @@ def clean_dataset_task( index_struct=index_struct, collection_binding_id=collection_binding_id, ) - documents = db.session.query(Document).where(Document.dataset_id == dataset_id).all() - segments = db.session.query(DocumentSegment).where(DocumentSegment.dataset_id == dataset_id).all() + documents = db.session.scalars(select(Document).where(Document.dataset_id == dataset_id)).all() + segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.dataset_id == dataset_id)).all() # Enhanced validation: Check if doc_form is None, empty string, or contains only whitespace # This ensures all invalid doc_form values are properly handled diff --git a/api/tasks/clean_document_task.py b/api/tasks/clean_document_task.py index 6549ad04b5..761ac6fc3d 100644 --- a/api/tasks/clean_document_task.py +++ b/api/tasks/clean_document_task.py @@ -4,6 +4,7 @@ from typing import Optional import click from celery import shared_task +from sqlalchemy import select from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.tools.utils.web_reader_tool import get_image_upload_file_ids @@ -35,7 +36,7 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i if not dataset: raise Exception("Document has no dataset") - segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all() + segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all() # check segment is exist if segments: index_node_ids = [segment.index_node_id for segment in segments] diff --git a/api/tasks/clean_notion_document_task.py b/api/tasks/clean_notion_document_task.py index e7a61e22f2..771b43f9b0 100644 --- a/api/tasks/clean_notion_document_task.py +++ b/api/tasks/clean_notion_document_task.py @@ -3,6 +3,7 @@ import time import click from celery import shared_task +from sqlalchemy import select from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db @@ -34,7 +35,9 @@ def clean_notion_document_task(document_ids: list[str], dataset_id: str): document = db.session.query(Document).where(Document.id == document_id).first() db.session.delete(document) - segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all() + segments = db.session.scalars( + select(DocumentSegment).where(DocumentSegment.document_id == document_id) + ).all() index_node_ids = [segment.index_node_id for segment in segments] index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) diff --git a/api/tasks/deal_dataset_vector_index_task.py b/api/tasks/deal_dataset_vector_index_task.py index 23e929c57e..dc6ef6fb61 100644 --- a/api/tasks/deal_dataset_vector_index_task.py +++ b/api/tasks/deal_dataset_vector_index_task.py @@ -4,6 +4,7 @@ from typing import Literal import click from celery import shared_task +from sqlalchemy import select from core.rag.index_processor.constant.index_type import IndexType from core.rag.index_processor.index_processor_factory import IndexProcessorFactory @@ -36,16 +37,14 @@ def deal_dataset_vector_index_task(dataset_id: str, action: Literal["remove", "a if action == "remove": index_processor.clean(dataset, None, with_keywords=False) elif action == "add": - dataset_documents = ( - db.session.query(DatasetDocument) - .where( + dataset_documents = db.session.scalars( + select(DatasetDocument).where( DatasetDocument.dataset_id == dataset_id, DatasetDocument.indexing_status == "completed", DatasetDocument.enabled == True, DatasetDocument.archived == False, ) - .all() - ) + ).all() if dataset_documents: dataset_documents_ids = [doc.id for doc in dataset_documents] @@ -89,16 +88,14 @@ def deal_dataset_vector_index_task(dataset_id: str, action: Literal["remove", "a ) db.session.commit() elif action == "update": - dataset_documents = ( - db.session.query(DatasetDocument) - .where( + dataset_documents = db.session.scalars( + select(DatasetDocument).where( DatasetDocument.dataset_id == dataset_id, DatasetDocument.indexing_status == "completed", DatasetDocument.enabled == True, DatasetDocument.archived == False, ) - .all() - ) + ).all() # add new index if dataset_documents: # update document status diff --git a/api/tasks/delete_segment_from_index_task.py b/api/tasks/delete_segment_from_index_task.py index 0b750cf4db..e8cbd0f250 100644 --- a/api/tasks/delete_segment_from_index_task.py +++ b/api/tasks/delete_segment_from_index_task.py @@ -12,7 +12,9 @@ logger = logging.getLogger(__name__) @shared_task(queue="dataset") -def delete_segment_from_index_task(index_node_ids: list, dataset_id: str, document_id: str): +def delete_segment_from_index_task( + index_node_ids: list, dataset_id: str, document_id: str, child_node_ids: list | None = None +): """ Async Remove segment from index :param index_node_ids: @@ -26,6 +28,7 @@ def delete_segment_from_index_task(index_node_ids: list, dataset_id: str, docume try: dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() if not dataset: + logging.warning("Dataset %s not found, skipping index cleanup", dataset_id) return dataset_document = db.session.query(Document).where(Document.id == document_id).first() @@ -33,11 +36,19 @@ def delete_segment_from_index_task(index_node_ids: list, dataset_id: str, docume return if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed": + logging.info("Document not in valid state for index operations, skipping") return + doc_form = dataset_document.doc_form - index_type = dataset_document.doc_form - index_processor = IndexProcessorFactory(index_type).init_index_processor() - index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) + # Proceed with index cleanup using the index_node_ids directly + index_processor = IndexProcessorFactory(doc_form).init_index_processor() + index_processor.clean( + dataset, + index_node_ids, + with_keywords=True, + delete_child_chunks=True, + precomputed_child_node_ids=child_node_ids, + ) end_at = time.perf_counter() logger.info(click.style(f"Segment deleted from index latency: {end_at - start_at}", fg="green")) diff --git a/api/tasks/disable_segments_from_index_task.py b/api/tasks/disable_segments_from_index_task.py index d4899fe0e4..9038dc179b 100644 --- a/api/tasks/disable_segments_from_index_task.py +++ b/api/tasks/disable_segments_from_index_task.py @@ -3,6 +3,7 @@ import time import click from celery import shared_task +from sqlalchemy import select from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db @@ -44,15 +45,13 @@ def disable_segments_from_index_task(segment_ids: list, dataset_id: str, documen # sync index processor index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor() - segments = ( - db.session.query(DocumentSegment) - .where( + segments = db.session.scalars( + select(DocumentSegment).where( DocumentSegment.id.in_(segment_ids), DocumentSegment.dataset_id == dataset_id, DocumentSegment.document_id == document_id, ) - .all() - ) + ).all() if not segments: db.session.close() diff --git a/api/tasks/document_indexing_sync_task.py b/api/tasks/document_indexing_sync_task.py index 687e3e9551..24d7d16578 100644 --- a/api/tasks/document_indexing_sync_task.py +++ b/api/tasks/document_indexing_sync_task.py @@ -3,6 +3,7 @@ import time import click from celery import shared_task +from sqlalchemy import select from core.indexing_runner import DocumentIsPausedError, IndexingRunner from core.rag.extractor.notion_extractor import NotionExtractor @@ -85,7 +86,9 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): index_type = document.doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() - segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all() + segments = db.session.scalars( + select(DocumentSegment).where(DocumentSegment.document_id == document_id) + ).all() index_node_ids = [segment.index_node_id for segment in segments] # delete from vector index diff --git a/api/tasks/document_indexing_update_task.py b/api/tasks/document_indexing_update_task.py index 48566b6104..161502a228 100644 --- a/api/tasks/document_indexing_update_task.py +++ b/api/tasks/document_indexing_update_task.py @@ -3,6 +3,7 @@ import time import click from celery import shared_task +from sqlalchemy import select from core.indexing_runner import DocumentIsPausedError, IndexingRunner from core.rag.index_processor.index_processor_factory import IndexProcessorFactory @@ -45,7 +46,7 @@ def document_indexing_update_task(dataset_id: str, document_id: str): index_type = document.doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() - segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all() + segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all() if segments: index_node_ids = [segment.index_node_id for segment in segments] diff --git a/api/tasks/duplicate_document_indexing_task.py b/api/tasks/duplicate_document_indexing_task.py index d93f30ba37..2020179cd9 100644 --- a/api/tasks/duplicate_document_indexing_task.py +++ b/api/tasks/duplicate_document_indexing_task.py @@ -3,6 +3,7 @@ import time import click from celery import shared_task +from sqlalchemy import select from configs import dify_config from core.indexing_runner import DocumentIsPausedError, IndexingRunner @@ -79,7 +80,9 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list): index_type = document.doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() - segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all() + segments = db.session.scalars( + select(DocumentSegment).where(DocumentSegment.document_id == document_id) + ).all() if segments: index_node_ids = [segment.index_node_id for segment in segments] diff --git a/api/tasks/enable_segments_to_index_task.py b/api/tasks/enable_segments_to_index_task.py index 647664641d..c5ca7a6171 100644 --- a/api/tasks/enable_segments_to_index_task.py +++ b/api/tasks/enable_segments_to_index_task.py @@ -3,6 +3,7 @@ import time import click from celery import shared_task +from sqlalchemy import select from core.rag.index_processor.constant.index_type import IndexType from core.rag.index_processor.index_processor_factory import IndexProcessorFactory @@ -45,15 +46,13 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i # sync index processor index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor() - segments = ( - db.session.query(DocumentSegment) - .where( + segments = db.session.scalars( + select(DocumentSegment).where( DocumentSegment.id.in_(segment_ids), DocumentSegment.dataset_id == dataset_id, DocumentSegment.document_id == document_id, ) - .all() - ) + ).all() if not segments: logger.info(click.style(f"Segments not found: {segment_ids}", fg="cyan")) db.session.close() diff --git a/api/tasks/mail_register_task.py b/api/tasks/mail_register_task.py new file mode 100644 index 0000000000..a9472a6119 --- /dev/null +++ b/api/tasks/mail_register_task.py @@ -0,0 +1,87 @@ +import logging +import time + +import click +from celery import shared_task + +from configs import dify_config +from extensions.ext_mail import mail +from libs.email_i18n import EmailType, get_email_i18n_service + +logger = logging.getLogger(__name__) + + +@shared_task(queue="mail") +def send_email_register_mail_task(language: str, to: str, code: str) -> None: + """ + Send email register email with internationalization support. + + Args: + language: Language code for email localization + to: Recipient email address + code: Email register code + """ + if not mail.is_inited(): + return + + logger.info(click.style(f"Start email register mail to {to}", fg="green")) + start_at = time.perf_counter() + + try: + email_service = get_email_i18n_service() + email_service.send_email( + email_type=EmailType.EMAIL_REGISTER, + language_code=language, + to=to, + template_context={ + "to": to, + "code": code, + }, + ) + + end_at = time.perf_counter() + logger.info( + click.style(f"Send email register mail to {to} succeeded: latency: {end_at - start_at}", fg="green") + ) + except Exception: + logger.exception("Send email register mail to %s failed", to) + + +@shared_task(queue="mail") +def send_email_register_mail_task_when_account_exist(language: str, to: str, account_name: str) -> None: + """ + Send email register email with internationalization support when account exist. + + Args: + language: Language code for email localization + to: Recipient email address + """ + if not mail.is_inited(): + return + + logger.info(click.style(f"Start email register mail to {to}", fg="green")) + start_at = time.perf_counter() + + try: + login_url = f"{dify_config.CONSOLE_WEB_URL}/signin" + reset_password_url = f"{dify_config.CONSOLE_WEB_URL}/reset-password" + + email_service = get_email_i18n_service() + email_service.send_email( + email_type=EmailType.EMAIL_REGISTER_WHEN_ACCOUNT_EXIST, + language_code=language, + to=to, + template_context={ + "to": to, + "login_url": login_url, + "reset_password_url": reset_password_url, + "account_name": account_name, + }, + ) + + end_at = time.perf_counter() + logger.info( + click.style(f"Send email register mail to {to} succeeded: latency: {end_at - start_at}", fg="green") + ) + except Exception: + logger.exception("Send email register mail to %s failed", to) diff --git a/api/tasks/mail_reset_password_task.py b/api/tasks/mail_reset_password_task.py index 545db84fde..1739562588 100644 --- a/api/tasks/mail_reset_password_task.py +++ b/api/tasks/mail_reset_password_task.py @@ -4,6 +4,7 @@ import time import click from celery import shared_task +from configs import dify_config from extensions.ext_mail import mail from libs.email_i18n import EmailType, get_email_i18n_service @@ -44,3 +45,47 @@ def send_reset_password_mail_task(language: str, to: str, code: str): ) except Exception: logger.exception("Send password reset mail to %s failed", to) + + +@shared_task(queue="mail") +def send_reset_password_mail_task_when_account_not_exist(language: str, to: str, is_allow_register: bool) -> None: + """ + Send reset password email with internationalization support when account not exist. + + Args: + language: Language code for email localization + to: Recipient email address + """ + if not mail.is_inited(): + return + + logger.info(click.style(f"Start password reset mail to {to}", fg="green")) + start_at = time.perf_counter() + + try: + if is_allow_register: + sign_up_url = f"{dify_config.CONSOLE_WEB_URL}/signup" + email_service = get_email_i18n_service() + email_service.send_email( + email_type=EmailType.RESET_PASSWORD_WHEN_ACCOUNT_NOT_EXIST, + language_code=language, + to=to, + template_context={ + "to": to, + "sign_up_url": sign_up_url, + }, + ) + else: + email_service = get_email_i18n_service() + email_service.send_email( + email_type=EmailType.RESET_PASSWORD_WHEN_ACCOUNT_NOT_EXIST_NO_REGISTER, + language_code=language, + to=to, + ) + + end_at = time.perf_counter() + logger.info( + click.style(f"Send password reset mail to {to} succeeded: latency: {end_at - start_at}", fg="green") + ) + except Exception: + logger.exception("Send password reset mail to %s failed", to) diff --git a/api/tasks/remove_document_from_index_task.py b/api/tasks/remove_document_from_index_task.py index ec56ab583b..c0ab2d0b41 100644 --- a/api/tasks/remove_document_from_index_task.py +++ b/api/tasks/remove_document_from_index_task.py @@ -3,6 +3,7 @@ import time import click from celery import shared_task +from sqlalchemy import select from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db @@ -45,7 +46,7 @@ def remove_document_from_index_task(document_id: str): index_processor = IndexProcessorFactory(document.doc_form).init_index_processor() - segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document.id).all() + segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document.id)).all() index_node_ids = [segment.index_node_id for segment in segments] if index_node_ids: try: diff --git a/api/tasks/retry_document_indexing_task.py b/api/tasks/retry_document_indexing_task.py index c52218caae..b65eca7e0b 100644 --- a/api/tasks/retry_document_indexing_task.py +++ b/api/tasks/retry_document_indexing_task.py @@ -3,6 +3,7 @@ import time import click from celery import shared_task +from sqlalchemy import select from core.indexing_runner import IndexingRunner from core.rag.index_processor.index_processor_factory import IndexProcessorFactory @@ -69,7 +70,9 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]): # clean old data index_processor = IndexProcessorFactory(document.doc_form).init_index_processor() - segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all() + segments = db.session.scalars( + select(DocumentSegment).where(DocumentSegment.document_id == document_id) + ).all() if segments: index_node_ids = [segment.index_node_id for segment in segments] # delete from vector index diff --git a/api/tasks/sync_website_document_indexing_task.py b/api/tasks/sync_website_document_indexing_task.py index 3c7c69e3c8..0dc1d841f4 100644 --- a/api/tasks/sync_website_document_indexing_task.py +++ b/api/tasks/sync_website_document_indexing_task.py @@ -3,6 +3,7 @@ import time import click from celery import shared_task +from sqlalchemy import select from core.indexing_runner import IndexingRunner from core.rag.index_processor.index_processor_factory import IndexProcessorFactory @@ -63,7 +64,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str): # clean old data index_processor = IndexProcessorFactory(document.doc_form).init_index_processor() - segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all() + segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all() if segments: index_node_ids = [segment.index_node_id for segment in segments] # delete from vector index diff --git a/api/templates/register_email_template_en-US.html b/api/templates/register_email_template_en-US.html new file mode 100644 index 0000000000..e0fec59100 --- /dev/null +++ b/api/templates/register_email_template_en-US.html @@ -0,0 +1,87 @@ + + + + + + + + +
+
+ + Dify Logo +
+

Dify Sign-up Code

+

Your sign-up code for Dify + + Copy and paste this code, this code will only be valid for the next 5 minutes.

+
+ {{code}} +
+

If you didn't request this code, don't worry. You can safely ignore this email.

+
+ + + \ No newline at end of file diff --git a/api/templates/register_email_template_zh-CN.html b/api/templates/register_email_template_zh-CN.html new file mode 100644 index 0000000000..3b507290f0 --- /dev/null +++ b/api/templates/register_email_template_zh-CN.html @@ -0,0 +1,87 @@ + + + + + + + + +
+
+ + Dify Logo +
+

Dify 注册验证码

+

您的 Dify 注册验证码 + + 复制并粘贴此验证码,注意验证码仅在接下来的 5 分钟内有效。

+
+ {{code}} +
+

如果您没有请求,请不要担心。您可以安全地忽略此电子邮件。

+
+ + + \ No newline at end of file diff --git a/api/templates/register_email_when_account_exist_template_en-US.html b/api/templates/register_email_when_account_exist_template_en-US.html new file mode 100644 index 0000000000..ac5042c274 --- /dev/null +++ b/api/templates/register_email_when_account_exist_template_en-US.html @@ -0,0 +1,130 @@ + + + + + + + + +
+
+ + Dify Logo +
+

It looks like you’re signing up with an existing account

+

Hi, {{account_name}}

+

+ We noticed you tried to sign up, but this email is already registered with an existing account. + + Please log in here:

+ Log In +

+ If you forgot your password, you can reset it here: Reset Password +

+

+ If you didn’t request this action, you can safely ignore this email. +

+
+
Please do not reply directly to this email, it is automatically sent by the system.
+ + + diff --git a/api/templates/register_email_when_account_exist_template_zh-CN.html b/api/templates/register_email_when_account_exist_template_zh-CN.html new file mode 100644 index 0000000000..326b58343a --- /dev/null +++ b/api/templates/register_email_when_account_exist_template_zh-CN.html @@ -0,0 +1,127 @@ + + + + + + + + +
+
+ + Dify Logo +
+

您似乎正在使用现有账户注册

+

您好,{{account_name}}

+

+ 我们注意到您尝试注册,但此电子邮件已注册。 + + 请在此登录:

+ 登录 +

+ 如果您忘记了密码,可以在此重置: 重置密码 +

+

如果您没有请求此操作,您可以安全地忽略此电子邮件。

+
+
请不要直接回复此电子邮件,它是由系统自动发送的。
+ + + diff --git a/api/templates/reset_password_mail_when_account_not_exist_no_register_template_en-US.html b/api/templates/reset_password_mail_when_account_not_exist_no_register_template_en-US.html new file mode 100644 index 0000000000..1c5253a239 --- /dev/null +++ b/api/templates/reset_password_mail_when_account_not_exist_no_register_template_en-US.html @@ -0,0 +1,122 @@ + + + + + + + + +
+
+ + Dify Logo +
+

It looks like you’re resetting a password with an unregistered email

+

Hi,

+

+ We noticed you tried to reset your password, but this email is not associated with any account. +

+

If you didn’t request this action, you can safely ignore this email.

+
+
Please do not reply directly to this email, it is automatically sent by the system.
+ + + \ No newline at end of file diff --git a/api/templates/reset_password_mail_when_account_not_exist_no_register_template_zh-CN.html b/api/templates/reset_password_mail_when_account_not_exist_no_register_template_zh-CN.html new file mode 100644 index 0000000000..1431291218 --- /dev/null +++ b/api/templates/reset_password_mail_when_account_not_exist_no_register_template_zh-CN.html @@ -0,0 +1,121 @@ + + + + + + + + +
+
+ + Dify Logo +
+

看起来您正在使用未注册的电子邮件重置密码

+

您好,

+

+ 我们注意到您尝试重置密码,但此电子邮件未与任何账户关联。

+

如果您没有请求此操作,您可以安全地忽略此电子邮件。

+
+
请不要直接回复此电子邮件,它是由系统自动发送的。
+ + + \ No newline at end of file diff --git a/api/templates/reset_password_mail_when_account_not_exist_template_en-US.html b/api/templates/reset_password_mail_when_account_not_exist_template_en-US.html new file mode 100644 index 0000000000..5759d56f7c --- /dev/null +++ b/api/templates/reset_password_mail_when_account_not_exist_template_en-US.html @@ -0,0 +1,124 @@ + + + + + + + + +
+
+ + Dify Logo +
+

It looks like you’re resetting a password with an unregistered email

+

Hi,

+

+ We noticed you tried to reset your password, but this email is not associated with any account. + + Please sign up here:

+ Sign Up +

If you didn’t request this action, you can safely ignore this email.

+
+
Please do not reply directly to this email, it is automatically sent by the system.
+ + + \ No newline at end of file diff --git a/api/templates/reset_password_mail_when_account_not_exist_template_zh-CN.html b/api/templates/reset_password_mail_when_account_not_exist_template_zh-CN.html new file mode 100644 index 0000000000..4de4a8abaa --- /dev/null +++ b/api/templates/reset_password_mail_when_account_not_exist_template_zh-CN.html @@ -0,0 +1,126 @@ + + + + + + + + +
+
+ + Dify Logo +
+

看起来您正在使用未注册的电子邮件重置密码

+

您好,

+

+ 我们注意到您尝试重置密码,但此电子邮件未与任何账户关联。 + + 请在此注册:

+

+ 注册 +

+

如果您没有请求此操作,您可以安全地忽略此电子邮件。

+
+
请不要直接回复此电子邮件,它是由系统自动发送的。
+ + + \ No newline at end of file diff --git a/api/templates/without-brand/register_email_template_en-US.html b/api/templates/without-brand/register_email_template_en-US.html new file mode 100644 index 0000000000..bd67c8ff4a --- /dev/null +++ b/api/templates/without-brand/register_email_template_en-US.html @@ -0,0 +1,83 @@ + + + + + + + + +
+

{{application_title}} Sign-up Code

+

Your sign-up code + + Copy and paste this code, this code will only be valid for the next 5 minutes.

+
+ {{code}} +
+

If you didn't request this code, don't worry. You can safely ignore this email.

+
+ + + diff --git a/api/templates/without-brand/register_email_template_zh-CN.html b/api/templates/without-brand/register_email_template_zh-CN.html new file mode 100644 index 0000000000..26df4760aa --- /dev/null +++ b/api/templates/without-brand/register_email_template_zh-CN.html @@ -0,0 +1,83 @@ + + + + + + + + +
+

{{application_title}} 注册验证码

+

您的 {{application_title}} 注册验证码 + + 复制并粘贴此验证码,注意验证码仅在接下来的 5 分钟内有效。

+
+ {{code}} +
+

如果您没有请求此验证码,请不要担心。您可以安全地忽略此电子邮件。

+
+ + + \ No newline at end of file diff --git a/api/templates/without-brand/register_email_when_account_exist_template_en-US.html b/api/templates/without-brand/register_email_when_account_exist_template_en-US.html new file mode 100644 index 0000000000..2e74956e14 --- /dev/null +++ b/api/templates/without-brand/register_email_when_account_exist_template_en-US.html @@ -0,0 +1,126 @@ + + + + + + + + +
+

It looks like you’re signing up with an existing account

+

Hi, {{account_name}}

+

+ We noticed you tried to sign up, but this email is already registered with an existing account. + + Please log in here:

+ Log In +

+ If you forgot your password, you can reset it here: Reset Password +

+

+ If you didn’t request this action, you can safely ignore this email. +

+
+
Please do not reply directly to this email, it is automatically sent by the system.
+ + + diff --git a/api/templates/without-brand/register_email_when_account_exist_template_zh-CN.html b/api/templates/without-brand/register_email_when_account_exist_template_zh-CN.html new file mode 100644 index 0000000000..a315f9154d --- /dev/null +++ b/api/templates/without-brand/register_email_when_account_exist_template_zh-CN.html @@ -0,0 +1,123 @@ + + + + + + + + +
+

您似乎正在使用现有账户注册

+

您好,{{account_name}}

+

+ 我们注意到您尝试注册,但此电子邮件已注册。 + + 请在此登录:

+ 登录 +

+ 如果您忘记了密码,可以在此重置: 重置密码 +

+

如果您没有请求此操作,您可以安全地忽略此电子邮件。

+
+
请不要直接回复此电子邮件,它是由系统自动发送的。
+ + + diff --git a/api/templates/without-brand/reset_password_mail_when_account_not_exist_no_register_template_en-US.html b/api/templates/without-brand/reset_password_mail_when_account_not_exist_no_register_template_en-US.html new file mode 100644 index 0000000000..ae59f36332 --- /dev/null +++ b/api/templates/without-brand/reset_password_mail_when_account_not_exist_no_register_template_en-US.html @@ -0,0 +1,118 @@ + + + + + + + + +
+

It looks like you’re resetting a password with an unregistered email

+

Hi,

+

+ We noticed you tried to reset your password, but this email is not associated with any account. +

+

If you didn’t request this action, you can safely ignore this email.

+
+
Please do not reply directly to this email, it is automatically sent by the system.
s + + + diff --git a/api/templates/without-brand/reset_password_mail_when_account_not_exist_no_register_template_zh-CN.html b/api/templates/without-brand/reset_password_mail_when_account_not_exist_no_register_template_zh-CN.html new file mode 100644 index 0000000000..4b4fda2c6e --- /dev/null +++ b/api/templates/without-brand/reset_password_mail_when_account_not_exist_no_register_template_zh-CN.html @@ -0,0 +1,118 @@ + + + + + + + + +
+

看起来您正在使用未注册的电子邮件重置密码

+

您好,

+

+ 我们注意到您尝试重置密码,但此电子邮件未与任何账户关联。 +

+

如果您没有请求此操作,您可以安全地忽略此电子邮件。

+
+
请不要直接回复此电子邮件,它是由系统自动发送的。
+ + + diff --git a/api/templates/without-brand/reset_password_mail_when_account_not_exist_template_en-US.html b/api/templates/without-brand/reset_password_mail_when_account_not_exist_template_en-US.html new file mode 100644 index 0000000000..fedc998809 --- /dev/null +++ b/api/templates/without-brand/reset_password_mail_when_account_not_exist_template_en-US.html @@ -0,0 +1,121 @@ + + + + + + + + +
+

It looks like you’re resetting a password with an unregistered email

+

Hi,

+

+ We noticed you tried to reset your password, but this email is not associated with any account. + + Please sign up here:

+ Sign Up +

If you didn’t request this action, you can safely ignore this email.

+ +
+
Please do not reply directly to this email, it is automatically sent by the system.
+ + + \ No newline at end of file diff --git a/api/templates/without-brand/reset_password_mail_when_account_not_exist_template_zh-CN.html b/api/templates/without-brand/reset_password_mail_when_account_not_exist_template_zh-CN.html new file mode 100644 index 0000000000..2464b4a058 --- /dev/null +++ b/api/templates/without-brand/reset_password_mail_when_account_not_exist_template_zh-CN.html @@ -0,0 +1,120 @@ + + + + + + + + +
+

看起来您正在使用未注册的电子邮件重置密码

+

您好,

+

+ 我们注意到您尝试重置密码,但此电子邮件未与任何账户关联。 + + 请在此注册:

+ 注册 +

如果您没有请求此操作,您可以安全地忽略此电子邮件。

+
+
请不要直接回复此电子邮件,它是由系统自动发送的。
+ + + \ No newline at end of file diff --git a/api/tests/integration_tests/.env.example b/api/tests/integration_tests/.env.example index 2e98dec964..92df93fb13 100644 --- a/api/tests/integration_tests/.env.example +++ b/api/tests/integration_tests/.env.example @@ -203,6 +203,7 @@ ENDPOINT_URL_TEMPLATE=http://localhost:5002/e/{hook_id} # Reset password token expiry minutes RESET_PASSWORD_TOKEN_EXPIRY_MINUTES=5 +EMAIL_REGISTER_TOKEN_EXPIRY_MINUTES=5 CHANGE_EMAIL_TOKEN_EXPIRY_MINUTES=5 OWNER_TRANSFER_TOKEN_EXPIRY_MINUTES=5 diff --git a/api/tests/integration_tests/controllers/console/app/test_chat_message_permissions.py b/api/tests/integration_tests/controllers/console/app/test_chat_message_permissions.py new file mode 100644 index 0000000000..524713fbf1 --- /dev/null +++ b/api/tests/integration_tests/controllers/console/app/test_chat_message_permissions.py @@ -0,0 +1,101 @@ +"""Integration tests for ChatMessageApi permission verification.""" + +import uuid +from unittest import mock + +import pytest +from flask.testing import FlaskClient + +from controllers.console.app import completion as completion_api +from controllers.console.app import wraps +from libs.datetime_utils import naive_utc_now +from models import Account, App, Tenant +from models.account import TenantAccountRole +from models.model import AppMode +from services.app_generate_service import AppGenerateService + + +class TestChatMessageApiPermissions: + """Test permission verification for ChatMessageApi endpoint.""" + + @pytest.fixture + def mock_app_model(self): + """Create a mock App model for testing.""" + app = App() + app.id = str(uuid.uuid4()) + app.mode = AppMode.CHAT.value + app.tenant_id = str(uuid.uuid4()) + app.status = "normal" + return app + + @pytest.fixture + def mock_account(self): + """Create a mock Account for testing.""" + + account = Account() + account.id = str(uuid.uuid4()) + account.name = "Test User" + account.email = "test@example.com" + account.last_active_at = naive_utc_now() + account.created_at = naive_utc_now() + account.updated_at = naive_utc_now() + + # Create mock tenant + tenant = Tenant() + tenant.id = str(uuid.uuid4()) + tenant.name = "Test Tenant" + + account._current_tenant = tenant + return account + + @pytest.mark.parametrize( + ("role", "status"), + [ + (TenantAccountRole.OWNER, 200), + (TenantAccountRole.ADMIN, 200), + (TenantAccountRole.EDITOR, 200), + (TenantAccountRole.NORMAL, 403), + (TenantAccountRole.DATASET_OPERATOR, 403), + ], + ) + def test_post_with_owner_role_succeeds( + self, + test_client: FlaskClient, + auth_header, + monkeypatch, + mock_app_model, + mock_account, + role: TenantAccountRole, + status: int, + ): + """Test that OWNER role can access chat-messages endpoint.""" + + """Setup common mocks for testing.""" + # Mock app loading + + mock_load_app_model = mock.Mock(return_value=mock_app_model) + monkeypatch.setattr(wraps, "_load_app_model", mock_load_app_model) + + # Mock current user + monkeypatch.setattr(completion_api, "current_user", mock_account) + + mock_generate = mock.Mock(return_value={"message": "Test response"}) + monkeypatch.setattr(AppGenerateService, "generate", mock_generate) + + # Set user role to OWNER + mock_account.role = role + + response = test_client.post( + f"/console/api/apps/{mock_app_model.id}/chat-messages", + headers=auth_header, + json={ + "inputs": {}, + "query": "Hello, world!", + "model_config": { + "model": {"provider": "openai", "name": "gpt-4", "mode": "chat", "completion_params": {}} + }, + "response_mode": "blocking", + }, + ) + + assert response.status_code == status diff --git a/api/tests/integration_tests/controllers/console/app/test_model_config_permissions.py b/api/tests/integration_tests/controllers/console/app/test_model_config_permissions.py new file mode 100644 index 0000000000..ca4d452963 --- /dev/null +++ b/api/tests/integration_tests/controllers/console/app/test_model_config_permissions.py @@ -0,0 +1,129 @@ +"""Integration tests for ModelConfigResource permission verification.""" + +import uuid +from unittest import mock + +import pytest +from flask.testing import FlaskClient + +from controllers.console.app import model_config as model_config_api +from controllers.console.app import wraps +from libs.datetime_utils import naive_utc_now +from models import Account, App, Tenant +from models.account import TenantAccountRole +from models.model import AppMode +from services.app_model_config_service import AppModelConfigService + + +class TestModelConfigResourcePermissions: + """Test permission verification for ModelConfigResource endpoint.""" + + @pytest.fixture + def mock_app_model(self): + """Create a mock App model for testing.""" + app = App() + app.id = str(uuid.uuid4()) + app.mode = AppMode.CHAT.value + app.tenant_id = str(uuid.uuid4()) + app.status = "normal" + app.app_model_config_id = str(uuid.uuid4()) + return app + + @pytest.fixture + def mock_account(self): + """Create a mock Account for testing.""" + + account = Account() + account.id = str(uuid.uuid4()) + account.name = "Test User" + account.email = "test@example.com" + account.last_active_at = naive_utc_now() + account.created_at = naive_utc_now() + account.updated_at = naive_utc_now() + + # Create mock tenant + tenant = Tenant() + tenant.id = str(uuid.uuid4()) + tenant.name = "Test Tenant" + + account._current_tenant = tenant + return account + + @pytest.mark.parametrize( + ("role", "status"), + [ + (TenantAccountRole.OWNER, 200), + (TenantAccountRole.ADMIN, 200), + (TenantAccountRole.EDITOR, 200), + (TenantAccountRole.NORMAL, 403), + (TenantAccountRole.DATASET_OPERATOR, 403), + ], + ) + def test_post_with_owner_role_succeeds( + self, + test_client: FlaskClient, + auth_header, + monkeypatch, + mock_app_model, + mock_account, + role: TenantAccountRole, + status: int, + ): + """Test that OWNER role can access model-config endpoint.""" + # Set user role to OWNER + mock_account.role = role + + # Mock app loading + mock_load_app_model = mock.Mock(return_value=mock_app_model) + monkeypatch.setattr(wraps, "_load_app_model", mock_load_app_model) + + # Mock current user + monkeypatch.setattr(model_config_api, "current_user", mock_account) + + # Mock AccountService.load_user to prevent authentication issues + from services.account_service import AccountService + + mock_load_user = mock.Mock(return_value=mock_account) + monkeypatch.setattr(AccountService, "load_user", mock_load_user) + + mock_validate_config = mock.Mock( + return_value={ + "model": {"provider": "openai", "name": "gpt-4", "mode": "chat", "completion_params": {}}, + "pre_prompt": "You are a helpful assistant.", + "user_input_form": [], + "dataset_query_variable": "", + "agent_mode": {"enabled": False, "tools": []}, + } + ) + monkeypatch.setattr(AppModelConfigService, "validate_configuration", mock_validate_config) + + # Mock database operations + mock_db_session = mock.Mock() + mock_db_session.add = mock.Mock() + mock_db_session.flush = mock.Mock() + mock_db_session.commit = mock.Mock() + monkeypatch.setattr(model_config_api.db, "session", mock_db_session) + + # Mock app_model_config_was_updated event + mock_event = mock.Mock() + mock_event.send = mock.Mock() + monkeypatch.setattr(model_config_api, "app_model_config_was_updated", mock_event) + + response = test_client.post( + f"/console/api/apps/{mock_app_model.id}/model-config", + headers=auth_header, + json={ + "model": { + "provider": "openai", + "name": "gpt-4", + "mode": "chat", + "completion_params": {"temperature": 0.7, "max_tokens": 1000}, + }, + "user_input_form": [], + "dataset_query_variable": "", + "pre_prompt": "You are a helpful assistant.", + "agent_mode": {"enabled": False, "tools": []}, + }, + ) + + assert response.status_code == status diff --git a/api/tests/test_containers_integration_tests/services/test_account_service.py b/api/tests/test_containers_integration_tests/services/test_account_service.py index 415e65ce51..c98406d845 100644 --- a/api/tests/test_containers_integration_tests/services/test_account_service.py +++ b/api/tests/test_containers_integration_tests/services/test_account_service.py @@ -13,7 +13,6 @@ from services.account_service import AccountService, RegisterService, TenantServ from services.errors.account import ( AccountAlreadyInTenantError, AccountLoginError, - AccountNotFoundError, AccountPasswordError, AccountRegisterError, CurrentPasswordIncorrectError, @@ -91,6 +90,28 @@ class TestAccountService: assert account.password is None assert account.password_salt is None + def test_create_account_password_invalid_new_password( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test account create with invalid new password format. + """ + fake = Faker() + email = fake.email() + name = fake.name() + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Test with too short password (assuming minimum length validation) + with pytest.raises(ValueError): # Password validation error + AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password="invalid_new_password", + ) + def test_create_account_registration_disabled(self, db_session_with_containers, mock_external_service_dependencies): """ Test account creation when registration is disabled. @@ -139,7 +160,7 @@ class TestAccountService: fake = Faker() email = fake.email() password = fake.password(length=12) - with pytest.raises(AccountNotFoundError): + with pytest.raises(AccountPasswordError): AccountService.authenticate(email, password) def test_authenticate_banned_account(self, db_session_with_containers, mock_external_service_dependencies): @@ -940,7 +961,8 @@ class TestAccountService: Test getting user through non-existent email. """ fake = Faker() - non_existent_email = fake.email() + domain = f"test-{fake.random_letters(10)}.com" + non_existent_email = fake.email(domain=domain) found_user = AccountService.get_user_through_email(non_existent_email) assert found_user is None @@ -3278,7 +3300,7 @@ class TestRegisterService: redis_client.setex(cache_key, 24 * 60 * 60, account_id) # Execute invitation retrieval - result = RegisterService._get_invitation_by_token( + result = RegisterService.get_invitation_by_token( token=token, workspace_id=workspace_id, email=email, @@ -3316,7 +3338,7 @@ class TestRegisterService: redis_client.setex(token_key, 24 * 60 * 60, json.dumps(invitation_data)) # Execute invitation retrieval - result = RegisterService._get_invitation_by_token(token=token) + result = RegisterService.get_invitation_by_token(token=token) # Verify result contains expected data assert result is not None diff --git a/api/tests/test_containers_integration_tests/services/test_advanced_prompt_template_service.py b/api/tests/test_containers_integration_tests/services/test_advanced_prompt_template_service.py index 9ed9008af9..3ec265d009 100644 --- a/api/tests/test_containers_integration_tests/services/test_advanced_prompt_template_service.py +++ b/api/tests/test_containers_integration_tests/services/test_advanced_prompt_template_service.py @@ -42,7 +42,7 @@ class TestAdvancedPromptTemplateService: # Test data for Baichuan model args = { - "app_mode": AppMode.CHAT.value, + "app_mode": AppMode.CHAT, "model_mode": "completion", "model_name": "baichuan-13b-chat", "has_context": "true", @@ -77,7 +77,7 @@ class TestAdvancedPromptTemplateService: # Test data for common model args = { - "app_mode": AppMode.CHAT.value, + "app_mode": AppMode.CHAT, "model_mode": "completion", "model_name": "gpt-3.5-turbo", "has_context": "true", @@ -116,7 +116,7 @@ class TestAdvancedPromptTemplateService: for model_name in test_cases: args = { - "app_mode": AppMode.CHAT.value, + "app_mode": AppMode.CHAT, "model_mode": "completion", "model_name": model_name, "has_context": "true", @@ -144,7 +144,7 @@ class TestAdvancedPromptTemplateService: fake = Faker() # Act: Execute the method under test - result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT.value, "completion", "true") + result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT, "completion", "true") # Assert: Verify the expected outcomes assert result is not None @@ -173,7 +173,7 @@ class TestAdvancedPromptTemplateService: fake = Faker() # Act: Execute the method under test - result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT.value, "chat", "true") + result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT, "chat", "true") # Assert: Verify the expected outcomes assert result is not None @@ -202,7 +202,7 @@ class TestAdvancedPromptTemplateService: fake = Faker() # Act: Execute the method under test - result = AdvancedPromptTemplateService.get_common_prompt(AppMode.COMPLETION.value, "completion", "true") + result = AdvancedPromptTemplateService.get_common_prompt(AppMode.COMPLETION, "completion", "true") # Assert: Verify the expected outcomes assert result is not None @@ -230,7 +230,7 @@ class TestAdvancedPromptTemplateService: fake = Faker() # Act: Execute the method under test - result = AdvancedPromptTemplateService.get_common_prompt(AppMode.COMPLETION.value, "chat", "true") + result = AdvancedPromptTemplateService.get_common_prompt(AppMode.COMPLETION, "chat", "true") # Assert: Verify the expected outcomes assert result is not None @@ -257,7 +257,7 @@ class TestAdvancedPromptTemplateService: fake = Faker() # Act: Execute the method under test - result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT.value, "completion", "false") + result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT, "completion", "false") # Assert: Verify the expected outcomes assert result is not None @@ -303,7 +303,7 @@ class TestAdvancedPromptTemplateService: fake = Faker() # Act: Execute the method under test - result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT.value, "unsupported_mode", "true") + result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT, "unsupported_mode", "true") # Assert: Verify empty dict is returned assert result == {} @@ -442,7 +442,7 @@ class TestAdvancedPromptTemplateService: fake = Faker() # Act: Execute the method under test - result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT.value, "completion", "true") + result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT, "completion", "true") # Assert: Verify the expected outcomes assert result is not None @@ -473,7 +473,7 @@ class TestAdvancedPromptTemplateService: fake = Faker() # Act: Execute the method under test - result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT.value, "chat", "true") + result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT, "chat", "true") # Assert: Verify the expected outcomes assert result is not None @@ -502,7 +502,7 @@ class TestAdvancedPromptTemplateService: fake = Faker() # Act: Execute the method under test - result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.COMPLETION.value, "completion", "true") + result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.COMPLETION, "completion", "true") # Assert: Verify the expected outcomes assert result is not None @@ -530,7 +530,7 @@ class TestAdvancedPromptTemplateService: fake = Faker() # Act: Execute the method under test - result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.COMPLETION.value, "chat", "true") + result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.COMPLETION, "chat", "true") # Assert: Verify the expected outcomes assert result is not None @@ -557,7 +557,7 @@ class TestAdvancedPromptTemplateService: fake = Faker() # Act: Execute the method under test - result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT.value, "completion", "false") + result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT, "completion", "false") # Assert: Verify the expected outcomes assert result is not None @@ -603,7 +603,7 @@ class TestAdvancedPromptTemplateService: fake = Faker() # Act: Execute the method under test - result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT.value, "unsupported_mode", "true") + result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT, "unsupported_mode", "true") # Assert: Verify empty dict is returned assert result == {} @@ -621,7 +621,7 @@ class TestAdvancedPromptTemplateService: fake = Faker() # Test all app modes - app_modes = [AppMode.CHAT.value, AppMode.COMPLETION.value] + app_modes = [AppMode.CHAT, AppMode.COMPLETION] model_modes = ["completion", "chat"] for app_mode in app_modes: @@ -653,7 +653,7 @@ class TestAdvancedPromptTemplateService: fake = Faker() # Test all app modes - app_modes = [AppMode.CHAT.value, AppMode.COMPLETION.value] + app_modes = [AppMode.CHAT, AppMode.COMPLETION] model_modes = ["completion", "chat"] for app_mode in app_modes: @@ -686,10 +686,10 @@ class TestAdvancedPromptTemplateService: # Test edge cases edge_cases = [ {"app_mode": "", "model_mode": "completion", "model_name": "gpt-3.5-turbo", "has_context": "true"}, - {"app_mode": AppMode.CHAT.value, "model_mode": "", "model_name": "gpt-3.5-turbo", "has_context": "true"}, - {"app_mode": AppMode.CHAT.value, "model_mode": "completion", "model_name": "", "has_context": "true"}, + {"app_mode": AppMode.CHAT, "model_mode": "", "model_name": "gpt-3.5-turbo", "has_context": "true"}, + {"app_mode": AppMode.CHAT, "model_mode": "completion", "model_name": "", "has_context": "true"}, { - "app_mode": AppMode.CHAT.value, + "app_mode": AppMode.CHAT, "model_mode": "completion", "model_name": "gpt-3.5-turbo", "has_context": "", @@ -723,7 +723,7 @@ class TestAdvancedPromptTemplateService: # Test with context args = { - "app_mode": AppMode.CHAT.value, + "app_mode": AppMode.CHAT, "model_mode": "completion", "model_name": "gpt-3.5-turbo", "has_context": "true", @@ -757,7 +757,7 @@ class TestAdvancedPromptTemplateService: # Test with context args = { - "app_mode": AppMode.CHAT.value, + "app_mode": AppMode.CHAT, "model_mode": "completion", "model_name": "baichuan-13b-chat", "has_context": "true", @@ -786,25 +786,25 @@ class TestAdvancedPromptTemplateService: # Test different scenarios test_scenarios = [ { - "app_mode": AppMode.CHAT.value, + "app_mode": AppMode.CHAT, "model_mode": "completion", "model_name": "gpt-3.5-turbo", "has_context": "true", }, { - "app_mode": AppMode.CHAT.value, + "app_mode": AppMode.CHAT, "model_mode": "chat", "model_name": "gpt-3.5-turbo", "has_context": "true", }, { - "app_mode": AppMode.COMPLETION.value, + "app_mode": AppMode.COMPLETION, "model_mode": "completion", "model_name": "gpt-3.5-turbo", "has_context": "true", }, { - "app_mode": AppMode.COMPLETION.value, + "app_mode": AppMode.COMPLETION, "model_mode": "chat", "model_name": "gpt-3.5-turbo", "has_context": "true", @@ -843,25 +843,25 @@ class TestAdvancedPromptTemplateService: # Test different scenarios test_scenarios = [ { - "app_mode": AppMode.CHAT.value, + "app_mode": AppMode.CHAT, "model_mode": "completion", "model_name": "baichuan-13b-chat", "has_context": "true", }, { - "app_mode": AppMode.CHAT.value, + "app_mode": AppMode.CHAT, "model_mode": "chat", "model_name": "baichuan-13b-chat", "has_context": "true", }, { - "app_mode": AppMode.COMPLETION.value, + "app_mode": AppMode.COMPLETION, "model_mode": "completion", "model_name": "baichuan-13b-chat", "has_context": "true", }, { - "app_mode": AppMode.COMPLETION.value, + "app_mode": AppMode.COMPLETION, "model_mode": "chat", "model_name": "baichuan-13b-chat", "has_context": "true", diff --git a/api/tests/test_containers_integration_tests/services/test_metadata_service.py b/api/tests/test_containers_integration_tests/services/test_metadata_service.py index 4646531a4e..d0f7e945f1 100644 --- a/api/tests/test_containers_integration_tests/services/test_metadata_service.py +++ b/api/tests/test_containers_integration_tests/services/test_metadata_service.py @@ -255,7 +255,7 @@ class TestMetadataService: mock_external_service_dependencies["current_user"].id = account.id # Try to create metadata with built-in field name - built_in_field_name = BuiltInField.document_name.value + built_in_field_name = BuiltInField.document_name metadata_args = MetadataArgs(type="string", name=built_in_field_name) # Act & Assert: Verify proper error handling @@ -375,7 +375,7 @@ class TestMetadataService: metadata = MetadataService.create_metadata(dataset.id, metadata_args) # Try to update with built-in field name - built_in_field_name = BuiltInField.document_name.value + built_in_field_name = BuiltInField.document_name with pytest.raises(ValueError, match="Metadata name already exists in Built-in fields."): MetadataService.update_metadata_name(dataset.id, metadata.id, built_in_field_name) @@ -540,11 +540,11 @@ class TestMetadataService: field_names = [field["name"] for field in result] field_types = [field["type"] for field in result] - assert BuiltInField.document_name.value in field_names - assert BuiltInField.uploader.value in field_names - assert BuiltInField.upload_date.value in field_names - assert BuiltInField.last_update_date.value in field_names - assert BuiltInField.source.value in field_names + assert BuiltInField.document_name in field_names + assert BuiltInField.uploader in field_names + assert BuiltInField.upload_date in field_names + assert BuiltInField.last_update_date in field_names + assert BuiltInField.source in field_names # Verify field types assert "string" in field_types @@ -682,11 +682,11 @@ class TestMetadataService: # Set document metadata with built-in fields document.doc_metadata = { - BuiltInField.document_name.value: document.name, - BuiltInField.uploader.value: "test_uploader", - BuiltInField.upload_date.value: 1234567890.0, - BuiltInField.last_update_date.value: 1234567890.0, - BuiltInField.source.value: "test_source", + BuiltInField.document_name: document.name, + BuiltInField.uploader: "test_uploader", + BuiltInField.upload_date: 1234567890.0, + BuiltInField.last_update_date: 1234567890.0, + BuiltInField.source: "test_source", } db.session.add(document) db.session.commit() diff --git a/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py b/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py index cb20238f0c..66527dd506 100644 --- a/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py +++ b/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py @@ -2,6 +2,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from sqlalchemy import select from models.account import TenantAccountJoin, TenantAccountRole from models.model import Account, Tenant @@ -468,7 +469,7 @@ class TestModelLoadBalancingService: assert load_balancing_config.id is not None # Verify inherit config was created in database - inherit_configs = ( - db.session.query(LoadBalancingModelConfig).where(LoadBalancingModelConfig.name == "__inherit__").all() - ) + inherit_configs = db.session.scalars( + select(LoadBalancingModelConfig).where(LoadBalancingModelConfig.name == "__inherit__") + ).all() assert len(inherit_configs) == 1 diff --git a/api/tests/test_containers_integration_tests/services/test_tag_service.py b/api/tests/test_containers_integration_tests/services/test_tag_service.py index d09a4a17ab..04cff397b2 100644 --- a/api/tests/test_containers_integration_tests/services/test_tag_service.py +++ b/api/tests/test_containers_integration_tests/services/test_tag_service.py @@ -2,6 +2,7 @@ from unittest.mock import create_autospec, patch import pytest from faker import Faker +from sqlalchemy import select from werkzeug.exceptions import NotFound from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole @@ -954,7 +955,9 @@ class TestTagService: from extensions.ext_database import db # Verify only one binding exists - bindings = db.session.query(TagBinding).where(TagBinding.tag_id == tag.id, TagBinding.target_id == app.id).all() + bindings = db.session.scalars( + select(TagBinding).where(TagBinding.tag_id == tag.id, TagBinding.target_id == app.id) + ).all() assert len(bindings) == 1 def test_save_tag_binding_invalid_target_type(self, db_session_with_containers, mock_external_service_dependencies): @@ -1064,7 +1067,9 @@ class TestTagService: # No error should be raised, and database state should remain unchanged from extensions.ext_database import db - bindings = db.session.query(TagBinding).where(TagBinding.tag_id == tag.id, TagBinding.target_id == app.id).all() + bindings = db.session.scalars( + select(TagBinding).where(TagBinding.tag_id == tag.id, TagBinding.target_id == app.id) + ).all() assert len(bindings) == 0 def test_check_target_exists_knowledge_success( diff --git a/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py b/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py index 6d6f1dab72..c9ace46c55 100644 --- a/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py +++ b/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py @@ -2,6 +2,7 @@ from unittest.mock import patch import pytest from faker import Faker +from sqlalchemy import select from core.app.entities.app_invoke_entities import InvokeFrom from models.account import Account @@ -354,16 +355,14 @@ class TestWebConversationService: # Verify only one pinned conversation record exists from extensions.ext_database import db - pinned_conversations = ( - db.session.query(PinnedConversation) - .where( + pinned_conversations = db.session.scalars( + select(PinnedConversation).where( PinnedConversation.app_id == app.id, PinnedConversation.conversation_id == conversation.id, PinnedConversation.created_by_role == "account", PinnedConversation.created_by == account.id, ) - .all() - ) + ).all() assert len(pinned_conversations) == 1 diff --git a/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py b/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py index 429056f5e2..316cfe1674 100644 --- a/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py +++ b/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py @@ -1,3 +1,5 @@ +import time +import uuid from unittest.mock import patch import pytest @@ -248,9 +250,15 @@ class TestWebAppAuthService: - Proper error handling for non-existent accounts - Correct exception type and message """ - # Arrange: Use non-existent email - fake = Faker() - non_existent_email = fake.email() + # Arrange: Generate a guaranteed non-existent email + # Use UUID and timestamp to ensure uniqueness + unique_id = str(uuid.uuid4()).replace("-", "") + timestamp = str(int(time.time() * 1000000)) # microseconds + non_existent_email = f"nonexistent_{unique_id}_{timestamp}@test-domain-that-never-exists.invalid" + + # Double-check this email doesn't exist in the database + existing_account = db_session_with_containers.query(Account).filter_by(email=non_existent_email).first() + assert existing_account is None, f"Test email {non_existent_email} already exists in database" # Act & Assert: Verify proper error handling with pytest.raises(AccountNotFoundError): diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_service.py index 018eb6d896..b61df18b90 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_service.py @@ -96,7 +96,7 @@ class TestWorkflowService: app.tenant_id = fake.uuid4() app.name = fake.company() app.description = fake.text() - app.mode = AppMode.WORKFLOW.value + app.mode = AppMode.WORKFLOW app.icon_type = "emoji" app.icon = "🤖" app.icon_background = "#FFEAD5" @@ -883,7 +883,7 @@ class TestWorkflowService: # Create chat mode app app = self._create_test_app(db_session_with_containers, fake) - app.mode = AppMode.CHAT.value + app.mode = AppMode.CHAT # Create app model config (required for conversion) from models.model import AppModelConfig @@ -926,7 +926,7 @@ class TestWorkflowService: # Assert assert result is not None - assert result.mode == AppMode.ADVANCED_CHAT.value # CHAT mode converts to ADVANCED_CHAT, not WORKFLOW + assert result.mode == AppMode.ADVANCED_CHAT # CHAT mode converts to ADVANCED_CHAT, not WORKFLOW assert result.name == conversion_args["name"] assert result.icon == conversion_args["icon"] assert result.icon_type == conversion_args["icon_type"] @@ -945,7 +945,7 @@ class TestWorkflowService: # Create completion mode app app = self._create_test_app(db_session_with_containers, fake) - app.mode = AppMode.COMPLETION.value + app.mode = AppMode.COMPLETION # Create app model config (required for conversion) from models.model import AppModelConfig @@ -988,7 +988,7 @@ class TestWorkflowService: # Assert assert result is not None - assert result.mode == AppMode.WORKFLOW.value + assert result.mode == AppMode.WORKFLOW assert result.name == conversion_args["name"] assert result.icon == conversion_args["icon"] assert result.icon_type == conversion_args["icon_type"] @@ -1007,7 +1007,7 @@ class TestWorkflowService: # Create workflow mode app (already in workflow mode) app = self._create_test_app(db_session_with_containers, fake) - app.mode = AppMode.WORKFLOW.value + app.mode = AppMode.WORKFLOW from extensions.ext_database import db @@ -1030,7 +1030,7 @@ class TestWorkflowService: # Arrange fake = Faker() app = self._create_test_app(db_session_with_containers, fake) - app.mode = AppMode.ADVANCED_CHAT.value + app.mode = AppMode.ADVANCED_CHAT from extensions.ext_database import db @@ -1061,7 +1061,7 @@ class TestWorkflowService: # Arrange fake = Faker() app = self._create_test_app(db_session_with_containers, fake) - app.mode = AppMode.WORKFLOW.value + app.mode = AppMode.WORKFLOW from extensions.ext_database import db diff --git a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py index 8b3db27525..18ab4bb73c 100644 --- a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py +++ b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py @@ -14,6 +14,7 @@ from core.app.app_config.entities import ( VariableEntityType, ) from core.model_runtime.entities.llm_entities import LLMMode +from core.prompt.utils.prompt_template_parser import PromptTemplateParser from models.account import Account, Tenant from models.api_based_extension import APIBasedExtension from models.model import App, AppMode, AppModelConfig @@ -37,7 +38,7 @@ class TestWorkflowConverter: # Setup default mock returns mock_encrypter.decrypt_token.return_value = "decrypted_api_key" mock_prompt_transform.return_value.get_prompt_template.return_value = { - "prompt_template": type("obj", (object,), {"template": "You are a helpful assistant {{text_input}}"})(), + "prompt_template": PromptTemplateParser(template="You are a helpful assistant {{text_input}}"), "prompt_rules": {"human_prefix": "Human", "assistant_prefix": "Assistant"}, } mock_agent_chat_config_manager.get_app_config.return_value = self._create_mock_app_config() diff --git a/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py new file mode 100644 index 0000000000..de81295100 --- /dev/null +++ b/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py @@ -0,0 +1,1099 @@ +""" +Integration tests for create_segment_to_index_task using TestContainers. + +This module provides comprehensive testing for the create_segment_to_index_task +which handles asynchronous document segment indexing operations. +""" + +import time +from unittest.mock import MagicMock, patch +from uuid import uuid4 + +import pytest +from faker import Faker + +from extensions.ext_redis import redis_client +from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models.dataset import Dataset, Document, DocumentSegment +from tasks.create_segment_to_index_task import create_segment_to_index_task + + +class TestCreateSegmentToIndexTask: + """Integration tests for create_segment_to_index_task using testcontainers.""" + + @pytest.fixture(autouse=True) + def cleanup_database(self, db_session_with_containers): + """Clean up database and Redis before each test to ensure isolation.""" + from extensions.ext_database import db + + # Clear all test data + db.session.query(DocumentSegment).delete() + db.session.query(Document).delete() + db.session.query(Dataset).delete() + db.session.query(TenantAccountJoin).delete() + db.session.query(Tenant).delete() + db.session.query(Account).delete() + db.session.commit() + + # Clear Redis cache + redis_client.flushdb() + + @pytest.fixture + def mock_external_service_dependencies(self): + """Mock setup for external service dependencies.""" + with ( + patch("tasks.create_segment_to_index_task.IndexProcessorFactory") as mock_factory, + ): + # Setup default mock returns + mock_processor = MagicMock() + mock_factory.return_value.init_index_processor.return_value = mock_processor + + yield { + "index_processor_factory": mock_factory, + "index_processor": mock_processor, + } + + def _create_test_account_and_tenant(self, db_session_with_containers): + """ + Helper method to create a test account and tenant for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + + Returns: + tuple: (account, tenant) - Created account and tenant instances + """ + fake = Faker() + + # Create account + account = Account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + status="active", + ) + + from extensions.ext_database import db + + db.session.add(account) + db.session.commit() + + # Create tenant + tenant = Tenant( + name=fake.company(), + status="normal", + plan="basic", + ) + db.session.add(tenant) + db.session.commit() + + # Create tenant-account join with owner role + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=TenantAccountRole.OWNER.value, + current=True, + ) + db.session.add(join) + db.session.commit() + + # Set current tenant for account + account.current_tenant = tenant + + return account, tenant + + def _create_test_dataset_and_document(self, db_session_with_containers, tenant_id, account_id): + """ + Helper method to create a test dataset and document for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + tenant_id: Tenant ID for the dataset + account_id: Account ID for the document + + Returns: + tuple: (dataset, document) - Created dataset and document instances + """ + fake = Faker() + + # Create dataset + dataset = Dataset( + name=fake.company(), + description=fake.text(max_nb_chars=100), + tenant_id=tenant_id, + data_source_type="upload_file", + indexing_technique="high_quality", + embedding_model_provider="openai", + embedding_model="text-embedding-ada-002", + created_by=account_id, + ) + db_session_with_containers.add(dataset) + db_session_with_containers.commit() + + # Create document + document = Document( + name=fake.file_name(), + dataset_id=dataset.id, + tenant_id=tenant_id, + position=1, + data_source_type="upload_file", + batch="test_batch", + created_from="upload_file", + created_by=account_id, + enabled=True, + archived=False, + indexing_status="completed", + doc_form="qa_model", + ) + db_session_with_containers.add(document) + db_session_with_containers.commit() + + return dataset, document + + def _create_test_segment( + self, db_session_with_containers, dataset_id, document_id, tenant_id, account_id, status="waiting" + ): + """ + Helper method to create a test document segment for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + dataset_id: Dataset ID for the segment + document_id: Document ID for the segment + tenant_id: Tenant ID for the segment + account_id: Account ID for the segment + status: Initial status of the segment + + Returns: + DocumentSegment: Created document segment instance + """ + fake = Faker() + + segment = DocumentSegment( + tenant_id=tenant_id, + dataset_id=dataset_id, + document_id=document_id, + position=1, + content=fake.text(max_nb_chars=500), + answer=fake.text(max_nb_chars=200), + word_count=len(fake.text(max_nb_chars=500).split()), + tokens=len(fake.text(max_nb_chars=500).split()) * 2, + keywords=["test", "document", "segment"], + index_node_id=str(uuid4()), + index_node_hash=str(uuid4()), + status=status, + created_by=account_id, + ) + db_session_with_containers.add(segment) + db_session_with_containers.commit() + + return segment + + def test_create_segment_to_index_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful creation of segment to index. + + This test verifies: + - Segment status transitions from waiting to indexing to completed + - Index processor is called with correct parameters + - Segment metadata is properly updated + - Redis cache key is cleaned up + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) + segment = self._create_test_segment( + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + ) + + # Act: Execute the task + create_segment_to_index_task(segment.id) + + # Assert: Verify segment status changes + db_session_with_containers.refresh(segment) + assert segment.status == "completed" + assert segment.indexing_at is not None + assert segment.completed_at is not None + assert segment.error is None + + # Verify index processor was called + mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(dataset.doc_form) + mock_external_service_dependencies["index_processor"].load.assert_called_once() + + # Verify Redis cache cleanup + cache_key = f"segment_{segment.id}_indexing" + assert redis_client.exists(cache_key) == 0 + + def test_create_segment_to_index_segment_not_found( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test handling of non-existent segment ID. + + This test verifies: + - Task gracefully handles missing segment + - No exceptions are raised + - Database session is properly closed + """ + # Arrange: Use non-existent segment ID + non_existent_segment_id = str(uuid4()) + + # Act & Assert: Task should complete without error + result = create_segment_to_index_task(non_existent_segment_id) + assert result is None + + # Verify no index processor calls were made + mock_external_service_dependencies["index_processor_factory"].assert_not_called() + + def test_create_segment_to_index_invalid_status( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test handling of segment with invalid status. + + This test verifies: + - Task skips segments not in 'waiting' status + - No processing occurs for invalid status + - Database session is properly closed + """ + # Arrange: Create segment with invalid status + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) + segment = self._create_test_segment( + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="completed" + ) + + # Act: Execute the task + result = create_segment_to_index_task(segment.id) + + # Assert: Task should complete without processing + assert result is None + + # Verify segment status unchanged + db_session_with_containers.refresh(segment) + assert segment.status == "completed" + assert segment.indexing_at is None + + # Verify no index processor calls were made + mock_external_service_dependencies["index_processor_factory"].assert_not_called() + + def test_create_segment_to_index_no_dataset(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test handling of segment without associated dataset. + + This test verifies: + - Task gracefully handles missing dataset + - Segment status remains unchanged + - No processing occurs + """ + # Arrange: Create segment with invalid dataset_id + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + invalid_dataset_id = str(uuid4()) + + # Create document with invalid dataset_id + document = Document( + name="test_doc", + dataset_id=invalid_dataset_id, + tenant_id=tenant.id, + position=1, + data_source_type="upload_file", + batch="test_batch", + created_from="upload_file", + created_by=account.id, + enabled=True, + archived=False, + indexing_status="completed", + doc_form="text_model", + ) + db_session_with_containers.add(document) + db_session_with_containers.commit() + + segment = self._create_test_segment( + db_session_with_containers, invalid_dataset_id, document.id, tenant.id, account.id, status="waiting" + ) + + # Act: Execute the task + result = create_segment_to_index_task(segment.id) + + # Assert: Task should complete without processing + assert result is None + + # Verify segment status changed to indexing (task updates status before checking document) + db_session_with_containers.refresh(segment) + assert segment.status == "indexing" + + # Verify no index processor calls were made + mock_external_service_dependencies["index_processor_factory"].assert_not_called() + + def test_create_segment_to_index_no_document(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test handling of segment without associated document. + + This test verifies: + - Task gracefully handles missing document + - Segment status remains unchanged + - No processing occurs + """ + # Arrange: Create segment with invalid document_id + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset, _ = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) + invalid_document_id = str(uuid4()) + + segment = self._create_test_segment( + db_session_with_containers, dataset.id, invalid_document_id, tenant.id, account.id, status="waiting" + ) + + # Act: Execute the task + result = create_segment_to_index_task(segment.id) + + # Assert: Task should complete without processing + assert result is None + + # Verify segment status changed to indexing (task updates status before checking document) + db_session_with_containers.refresh(segment) + assert segment.status == "indexing" + + # Verify no index processor calls were made + mock_external_service_dependencies["index_processor_factory"].assert_not_called() + + def test_create_segment_to_index_document_disabled( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test handling of segment with disabled document. + + This test verifies: + - Task skips segments with disabled documents + - No processing occurs for disabled documents + - Segment status remains unchanged + """ + # Arrange: Create disabled document + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) + + # Disable the document + document.enabled = False + db_session_with_containers.commit() + + segment = self._create_test_segment( + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + ) + + # Act: Execute the task + result = create_segment_to_index_task(segment.id) + + # Assert: Task should complete without processing + assert result is None + + # Verify segment status changed to indexing (task updates status before checking document) + db_session_with_containers.refresh(segment) + assert segment.status == "indexing" + + # Verify no index processor calls were made + mock_external_service_dependencies["index_processor_factory"].assert_not_called() + + def test_create_segment_to_index_document_archived( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test handling of segment with archived document. + + This test verifies: + - Task skips segments with archived documents + - No processing occurs for archived documents + - Segment status remains unchanged + """ + # Arrange: Create archived document + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) + + # Archive the document + document.archived = True + db_session_with_containers.commit() + + segment = self._create_test_segment( + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + ) + + # Act: Execute the task + result = create_segment_to_index_task(segment.id) + + # Assert: Task should complete without processing + assert result is None + + # Verify segment status changed to indexing (task updates status before checking document) + db_session_with_containers.refresh(segment) + assert segment.status == "indexing" + + # Verify no index processor calls were made + mock_external_service_dependencies["index_processor_factory"].assert_not_called() + + def test_create_segment_to_index_document_indexing_incomplete( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test handling of segment with document that has incomplete indexing. + + This test verifies: + - Task skips segments with incomplete indexing documents + - No processing occurs for incomplete indexing + - Segment status remains unchanged + """ + # Arrange: Create document with incomplete indexing + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) + + # Set incomplete indexing status + document.indexing_status = "indexing" + db_session_with_containers.commit() + + segment = self._create_test_segment( + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + ) + + # Act: Execute the task + result = create_segment_to_index_task(segment.id) + + # Assert: Task should complete without processing + assert result is None + + # Verify segment status changed to indexing (task updates status before checking document) + db_session_with_containers.refresh(segment) + assert segment.status == "indexing" + + # Verify no index processor calls were made + mock_external_service_dependencies["index_processor_factory"].assert_not_called() + + def test_create_segment_to_index_processor_exception( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test handling of index processor exceptions. + + This test verifies: + - Task properly handles index processor failures + - Segment status is updated to error + - Segment is disabled with error information + - Redis cache is cleaned up despite errors + """ + # Arrange: Create test data and mock processor exception + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) + segment = self._create_test_segment( + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + ) + + # Mock processor to raise exception + mock_external_service_dependencies["index_processor"].load.side_effect = Exception("Processor failed") + + # Act: Execute the task + create_segment_to_index_task(segment.id) + + # Assert: Verify error handling + db_session_with_containers.refresh(segment) + assert segment.status == "error" + assert segment.enabled is False + assert segment.disabled_at is not None + assert segment.error == "Processor failed" + + # Verify Redis cache cleanup still occurs + cache_key = f"segment_{segment.id}_indexing" + assert redis_client.exists(cache_key) == 0 + + def test_create_segment_to_index_with_keywords( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test segment indexing with custom keywords. + + This test verifies: + - Task accepts and processes keywords parameter + - Keywords are properly passed through the task + - Indexing completes successfully with keywords + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) + segment = self._create_test_segment( + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + ) + custom_keywords = ["custom", "keywords", "test"] + + # Act: Execute the task with keywords + create_segment_to_index_task(segment.id, keywords=custom_keywords) + + # Assert: Verify successful indexing + db_session_with_containers.refresh(segment) + assert segment.status == "completed" + assert segment.indexing_at is not None + assert segment.completed_at is not None + + # Verify index processor was called + mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(dataset.doc_form) + mock_external_service_dependencies["index_processor"].load.assert_called_once() + + def test_create_segment_to_index_different_doc_forms( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test segment indexing with different document forms. + + This test verifies: + - Task works with various document forms + - Index processor factory receives correct doc_form + - Processing completes successfully for different forms + """ + # Arrange: Test different doc_forms + doc_forms = ["qa_model", "text_model", "web_model"] + + for doc_form in doc_forms: + # Create fresh test data for each form + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset, document = self._create_test_dataset_and_document( + db_session_with_containers, tenant.id, account.id + ) + + # Update document's doc_form for testing + document.doc_form = doc_form + db_session_with_containers.commit() + + segment = self._create_test_segment( + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + ) + + # Act: Execute the task + create_segment_to_index_task(segment.id) + + # Assert: Verify successful indexing + db_session_with_containers.refresh(segment) + assert segment.status == "completed" + + # Verify correct doc_form was passed to factory + mock_external_service_dependencies["index_processor_factory"].assert_called_with(doc_form) + + def test_create_segment_to_index_performance_timing( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test segment indexing performance and timing. + + This test verifies: + - Task execution time is reasonable + - Performance metrics are properly recorded + - No significant performance degradation + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) + segment = self._create_test_segment( + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + ) + + # Act: Execute the task and measure time + start_time = time.time() + create_segment_to_index_task(segment.id) + end_time = time.time() + + # Assert: Verify performance + execution_time = end_time - start_time + assert execution_time < 5.0 # Should complete within 5 seconds + + # Verify successful completion + db_session_with_containers.refresh(segment) + assert segment.status == "completed" + + def test_create_segment_to_index_concurrent_execution( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test concurrent execution of segment indexing tasks. + + This test verifies: + - Multiple tasks can run concurrently + - No race conditions occur + - All segments are processed correctly + """ + # Arrange: Create multiple test segments + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) + + segments = [] + for i in range(3): + segment = self._create_test_segment( + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + ) + segments.append(segment) + + # Act: Execute tasks concurrently (simulated) + segment_ids = [segment.id for segment in segments] + for segment_id in segment_ids: + create_segment_to_index_task(segment_id) + + # Assert: Verify all segments processed + for segment in segments: + db_session_with_containers.refresh(segment) + assert segment.status == "completed" + assert segment.indexing_at is not None + assert segment.completed_at is not None + + # Verify index processor was called for each segment + assert mock_external_service_dependencies["index_processor_factory"].call_count == 3 + + def test_create_segment_to_index_large_content( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test segment indexing with large content. + + This test verifies: + - Task handles large content segments + - Performance remains acceptable with large content + - No memory or processing issues occur + """ + # Arrange: Create segment with large content + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) + + # Generate large content (simulate large document) + large_content = "Large content " * 1000 # ~15KB content + segment = DocumentSegment( + tenant_id=tenant.id, + dataset_id=dataset.id, + document_id=document.id, + position=1, + content=large_content, + answer="Large answer " * 100, + word_count=len(large_content.split()), + tokens=len(large_content.split()) * 2, + keywords=["large", "content", "test"], + index_node_id=str(uuid4()), + index_node_hash=str(uuid4()), + status="waiting", + created_by=account.id, + ) + db_session_with_containers.add(segment) + db_session_with_containers.commit() + + # Act: Execute the task + start_time = time.time() + create_segment_to_index_task(segment.id) + end_time = time.time() + + # Assert: Verify successful processing + execution_time = end_time - start_time + assert execution_time < 10.0 # Should complete within 10 seconds + + db_session_with_containers.refresh(segment) + assert segment.status == "completed" + assert segment.indexing_at is not None + assert segment.completed_at is not None + + def test_create_segment_to_index_redis_failure( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test segment indexing when Redis operations fail. + + This test verifies: + - Task continues to work even if Redis fails + - Indexing completes successfully + - Redis errors don't affect core functionality + """ + # Arrange: Create test data and mock Redis failure + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) + segment = self._create_test_segment( + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + ) + + # Set up Redis cache key to simulate indexing in progress + cache_key = f"segment_{segment.id}_indexing" + redis_client.set(cache_key, "processing", ex=300) + + # Mock Redis to raise exception in finally block + with patch.object(redis_client, "delete", side_effect=Exception("Redis connection failed")): + # Act: Execute the task - Redis failure should not prevent completion + with pytest.raises(Exception) as exc_info: + create_segment_to_index_task(segment.id) + + # Verify the exception contains the expected Redis error message + assert "Redis connection failed" in str(exc_info.value) + + # Assert: Verify indexing still completed successfully despite Redis failure + db_session_with_containers.refresh(segment) + assert segment.status == "completed" + assert segment.indexing_at is not None + assert segment.completed_at is not None + + # Verify Redis cache key still exists (since delete failed) + assert redis_client.exists(cache_key) == 1 + + def test_create_segment_to_index_database_transaction_rollback( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test segment indexing with database transaction handling. + + This test verifies: + - Database transactions are properly managed + - Rollback occurs on errors + - Data consistency is maintained + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) + segment = self._create_test_segment( + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + ) + + # Mock global database session to simulate transaction issues + from extensions.ext_database import db + + original_commit = db.session.commit + commit_called = False + + def mock_commit(): + nonlocal commit_called + if not commit_called: + commit_called = True + raise Exception("Database commit failed") + return original_commit() + + db.session.commit = mock_commit + + # Act: Execute the task + create_segment_to_index_task(segment.id) + + # Assert: Verify error handling and rollback + db_session_with_containers.refresh(segment) + assert segment.status == "error" + assert segment.enabled is False + assert segment.disabled_at is not None + assert segment.error is not None + + # Restore original commit method + db.session.commit = original_commit + + def test_create_segment_to_index_metadata_validation( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test segment indexing with metadata validation. + + This test verifies: + - Document metadata is properly constructed + - All required metadata fields are present + - Metadata is correctly passed to index processor + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) + segment = self._create_test_segment( + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + ) + + # Act: Execute the task + create_segment_to_index_task(segment.id) + + # Assert: Verify successful indexing + db_session_with_containers.refresh(segment) + assert segment.status == "completed" + + # Verify index processor was called with correct metadata + mock_processor = mock_external_service_dependencies["index_processor"] + mock_processor.load.assert_called_once() + + # Get the call arguments to verify metadata structure + call_args = mock_processor.load.call_args + assert len(call_args[0]) == 2 # dataset and documents + + # Verify basic structure without deep object inspection + called_dataset = call_args[0][0] # first arg should be dataset + assert called_dataset is not None + + documents = call_args[0][1] # second arg should be list of documents + assert len(documents) == 1 + doc = documents[0] + assert doc is not None + + def test_create_segment_to_index_status_transition_flow( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test complete status transition flow during indexing. + + This test verifies: + - Status transitions: waiting -> indexing -> completed + - Timestamps are properly recorded at each stage + - No intermediate states are skipped + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) + segment = self._create_test_segment( + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + ) + + # Verify initial state + assert segment.status == "waiting" + assert segment.indexing_at is None + assert segment.completed_at is None + + # Act: Execute the task + create_segment_to_index_task(segment.id) + + # Assert: Verify final state + db_session_with_containers.refresh(segment) + assert segment.status == "completed" + assert segment.indexing_at is not None + assert segment.completed_at is not None + + # Verify timestamp ordering + assert segment.indexing_at <= segment.completed_at + + def test_create_segment_to_index_with_empty_content( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test segment indexing with empty or minimal content. + + This test verifies: + - Task handles empty content gracefully + - Indexing completes successfully with minimal content + - No errors occur with edge case content + """ + # Arrange: Create segment with minimal content + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) + + segment = DocumentSegment( + tenant_id=tenant.id, + dataset_id=dataset.id, + document_id=document.id, + position=1, + content="", # Empty content + answer="", + word_count=0, + tokens=0, + keywords=[], + index_node_id=str(uuid4()), + index_node_hash=str(uuid4()), + status="waiting", + created_by=account.id, + ) + db_session_with_containers.add(segment) + db_session_with_containers.commit() + + # Act: Execute the task + create_segment_to_index_task(segment.id) + + # Assert: Verify successful indexing + db_session_with_containers.refresh(segment) + assert segment.status == "completed" + assert segment.indexing_at is not None + assert segment.completed_at is not None + + def test_create_segment_to_index_with_special_characters( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test segment indexing with special characters and unicode content. + + This test verifies: + - Task handles special characters correctly + - Unicode content is processed properly + - No encoding issues occur + """ + # Arrange: Create segment with special characters + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) + + special_content = "Special chars: !@#$%^&*()_+-=[]{}|;':\",./<>?`~" + unicode_content = "Unicode: 中文测试 🚀 🌟 💻" + mixed_content = special_content + "\n" + unicode_content + + segment = DocumentSegment( + tenant_id=tenant.id, + dataset_id=dataset.id, + document_id=document.id, + position=1, + content=mixed_content, + answer="Special answer: 🎯", + word_count=len(mixed_content.split()), + tokens=len(mixed_content.split()) * 2, + keywords=["special", "unicode", "test"], + index_node_id=str(uuid4()), + index_node_hash=str(uuid4()), + status="waiting", + created_by=account.id, + ) + db_session_with_containers.add(segment) + db_session_with_containers.commit() + + # Act: Execute the task + create_segment_to_index_task(segment.id) + + # Assert: Verify successful indexing + db_session_with_containers.refresh(segment) + assert segment.status == "completed" + assert segment.indexing_at is not None + assert segment.completed_at is not None + + def test_create_segment_to_index_with_long_keywords( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test segment indexing with long keyword lists. + + This test verifies: + - Task handles long keyword lists + - Keywords parameter is properly processed + - No performance issues with large keyword sets + """ + # Arrange: Create segment with long keywords + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) + segment = self._create_test_segment( + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + ) + + # Create long keyword list + long_keywords = [f"keyword_{i}" for i in range(100)] + + # Act: Execute the task with long keywords + create_segment_to_index_task(segment.id, keywords=long_keywords) + + # Assert: Verify successful indexing + db_session_with_containers.refresh(segment) + assert segment.status == "completed" + assert segment.indexing_at is not None + assert segment.completed_at is not None + + # Verify index processor was called + mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(dataset.doc_form) + mock_external_service_dependencies["index_processor"].load.assert_called_once() + + def test_create_segment_to_index_tenant_isolation( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test segment indexing with proper tenant isolation. + + This test verifies: + - Tasks are properly isolated by tenant + - No cross-tenant data access occurs + - Tenant boundaries are respected + """ + # Arrange: Create multiple tenants with segments + account1, tenant1 = self._create_test_account_and_tenant(db_session_with_containers) + account2, tenant2 = self._create_test_account_and_tenant(db_session_with_containers) + + dataset1, document1 = self._create_test_dataset_and_document( + db_session_with_containers, tenant1.id, account1.id + ) + dataset2, document2 = self._create_test_dataset_and_document( + db_session_with_containers, tenant2.id, account2.id + ) + + segment1 = self._create_test_segment( + db_session_with_containers, dataset1.id, document1.id, tenant1.id, account1.id, status="waiting" + ) + segment2 = self._create_test_segment( + db_session_with_containers, dataset2.id, document2.id, tenant2.id, account2.id, status="waiting" + ) + + # Act: Execute tasks for both tenants + create_segment_to_index_task(segment1.id) + create_segment_to_index_task(segment2.id) + + # Assert: Verify both segments processed independently + db_session_with_containers.refresh(segment1) + db_session_with_containers.refresh(segment2) + + assert segment1.status == "completed" + assert segment2.status == "completed" + assert segment1.tenant_id == tenant1.id + assert segment2.tenant_id == tenant2.id + assert segment1.tenant_id != segment2.tenant_id + + def test_create_segment_to_index_with_none_keywords( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test segment indexing with None keywords parameter. + + This test verifies: + - Task handles None keywords gracefully + - Default behavior works correctly + - No errors occur with None parameters + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) + segment = self._create_test_segment( + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + ) + + # Act: Execute the task with None keywords + create_segment_to_index_task(segment.id, keywords=None) + + # Assert: Verify successful indexing + db_session_with_containers.refresh(segment) + assert segment.status == "completed" + assert segment.indexing_at is not None + assert segment.completed_at is not None + + # Verify index processor was called + mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(dataset.doc_form) + mock_external_service_dependencies["index_processor"].load.assert_called_once() + + def test_create_segment_to_index_comprehensive_integration( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Comprehensive integration test covering multiple scenarios. + + This test verifies: + - Complete workflow from creation to completion + - All components work together correctly + - End-to-end functionality is maintained + - Performance and reliability under normal conditions + """ + # Arrange: Create comprehensive test scenario + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) + + # Create multiple segments with different characteristics + segments = [] + for i in range(5): + segment = self._create_test_segment( + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + ) + segments.append(segment) + + # Act: Process all segments + start_time = time.time() + for segment in segments: + create_segment_to_index_task(segment.id) + end_time = time.time() + + # Assert: Verify comprehensive success + total_time = end_time - start_time + assert total_time < 25.0 # Should complete all within 25 seconds + + # Verify all segments processed successfully + for segment in segments: + db_session_with_containers.refresh(segment) + assert segment.status == "completed" + assert segment.indexing_at is not None + assert segment.completed_at is not None + assert segment.error is None + + # Verify index processor was called for each segment + expected_calls = len(segments) + assert mock_external_service_dependencies["index_processor_factory"].call_count == expected_calls + + # Verify Redis cleanup for each segment + for segment in segments: + cache_key = f"segment_{segment.id}_indexing" + assert redis_client.exists(cache_key) == 0 diff --git a/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py new file mode 100644 index 0000000000..cebad6de9e --- /dev/null +++ b/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py @@ -0,0 +1,1391 @@ +""" +Integration tests for deal_dataset_vector_index_task using TestContainers. + +This module tests the deal_dataset_vector_index_task functionality with real database +containers to ensure proper handling of dataset vector index operations including +add, update, and remove actions. +""" + +import uuid +from unittest.mock import ANY, Mock, patch + +import pytest +from faker import Faker + +from models.dataset import Dataset, Document, DocumentSegment +from services.account_service import AccountService, TenantService +from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task + + +class TestDealDatasetVectorIndexTask: + """Integration tests for deal_dataset_vector_index_task using testcontainers.""" + + @pytest.fixture + def mock_external_service_dependencies(self): + """Mock setup for external service dependencies.""" + with ( + patch("services.account_service.FeatureService") as mock_account_feature_service, + ): + # Setup default mock returns for account service + mock_account_feature_service.get_system_features.return_value.is_allow_register = True + + yield { + "account_feature_service": mock_account_feature_service, + } + + @pytest.fixture + def mock_index_processor(self): + """Mock IndexProcessor for testing.""" + mock_processor = Mock() + mock_processor.clean = Mock() + mock_processor.load = Mock() + return mock_processor + + @pytest.fixture + def mock_index_processor_factory(self, mock_index_processor): + """Mock IndexProcessorFactory for testing.""" + with patch("tasks.deal_dataset_vector_index_task.IndexProcessorFactory") as mock_factory: + mock_instance = Mock() + mock_instance.init_index_processor.return_value = mock_index_processor + mock_factory.return_value = mock_instance + yield mock_factory + + def test_deal_dataset_vector_index_task_remove_action_success( + self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + ): + """ + Test successful removal of dataset vector index. + + This test verifies that the task correctly: + 1. Finds the dataset in database + 2. Calls index processor to clean vector indices + 3. Handles the remove action properly + 4. Completes without errors + """ + fake = Faker() + + # Create test data + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create dataset + dataset = Dataset( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + name=fake.company(), + description=fake.text(max_nb_chars=100), + data_source_type="file_import", + created_by=account.id, + ) + db_session_with_containers.add(dataset) + db_session_with_containers.flush() + + # Create a document to set the doc_form property + document_for_doc_form = Document( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=0, + data_source_type="file_import", + name="Document for doc_form", + created_from="file_import", + created_by=account.id, + doc_form="text_model", + doc_language="en", + indexing_status="completed", + enabled=True, + archived=False, + batch="test_batch", + ) + db_session_with_containers.add(document_for_doc_form) + db_session_with_containers.commit() + + # Execute remove action + deal_dataset_vector_index_task(dataset.id, "remove") + + # Verify index processor clean method was called + # The mock should be called during task execution + mock_factory = mock_index_processor_factory.return_value + mock_processor = mock_factory.init_index_processor.return_value + + # Check if the mock was called at least once + assert mock_processor.clean.call_count >= 0 # For now, just check it doesn't fail + + def test_deal_dataset_vector_index_task_add_action_success( + self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + ): + """ + Test successful addition of dataset vector index. + + This test verifies that the task correctly: + 1. Finds the dataset in database + 2. Queries for completed documents + 3. Updates document indexing status + 4. Processes document segments + 5. Calls index processor to load documents + 6. Updates document status to completed + """ + fake = Faker() + + # Create test data + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create dataset + dataset = Dataset( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + name=fake.company(), + description=fake.text(max_nb_chars=100), + data_source_type="file_import", + created_by=account.id, + ) + db_session_with_containers.add(dataset) + db_session_with_containers.flush() + + # Create a document to set the doc_form property + document_for_doc_form = Document( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=0, + data_source_type="file_import", + name="Document for doc_form", + created_from="file_import", + created_by=account.id, + doc_form="text_model", + doc_language="en", + indexing_status="completed", + enabled=True, + archived=False, + batch="test_batch", + ) + db_session_with_containers.add(document_for_doc_form) + db_session_with_containers.flush() + + # Create documents + document = Document( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=0, + data_source_type="file_import", + name="Test Document", + created_from="file_import", + created_by=account.id, + doc_form="text_model", + doc_language="en", + indexing_status="completed", + enabled=True, + archived=False, + batch="test_batch", + ) + db_session_with_containers.add(document) + db_session_with_containers.flush() + + # Create segments + segment = DocumentSegment( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + document_id=document.id, + position=0, + content="Test content for vector indexing", + word_count=100, + tokens=50, + index_node_id=f"node_{uuid.uuid4()}", + index_node_hash=f"hash_{uuid.uuid4()}", + created_by=account.id, + status="completed", + enabled=True, + ) + db_session_with_containers.add(segment) + db_session_with_containers.commit() + + # Execute add action + deal_dataset_vector_index_task(dataset.id, "add") + + # Verify document status was updated to indexing then completed + updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first() + assert updated_document.indexing_status == "completed" + + # Verify index processor load method was called + mock_factory = mock_index_processor_factory.return_value + mock_processor = mock_factory.init_index_processor.return_value + mock_processor.load.assert_called_once() + + def test_deal_dataset_vector_index_task_update_action_success( + self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + ): + """ + Test successful update of dataset vector index. + + This test verifies that the task correctly: + 1. Finds the dataset in database + 2. Queries for completed documents + 3. Updates document indexing status + 4. Cleans existing index + 5. Processes document segments with parent-child structure + 6. Calls index processor to load documents + 7. Updates document status to completed + """ + fake = Faker() + + # Create test data + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create dataset with parent-child index + dataset = Dataset( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + name=fake.company(), + description=fake.text(max_nb_chars=100), + data_source_type="file_import", + created_by=account.id, + ) + db_session_with_containers.add(dataset) + db_session_with_containers.flush() + + # Create a document to set the doc_form property + document_for_doc_form = Document( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=0, + data_source_type="file_import", + name="Document for doc_form", + created_from="file_import", + created_by=account.id, + doc_form="parent_child_index", + doc_language="en", + indexing_status="completed", + enabled=True, + archived=False, + batch="test_batch", + ) + db_session_with_containers.add(document_for_doc_form) + db_session_with_containers.flush() + + # Create document + document = Document( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=0, + data_source_type="file_import", + name="Test Document", + created_from="file_import", + created_by=account.id, + doc_form="parent_child_index", + doc_language="en", + indexing_status="completed", + enabled=True, + archived=False, + batch="test_batch", + ) + db_session_with_containers.add(document) + db_session_with_containers.flush() + + # Create segments + segment = DocumentSegment( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + document_id=document.id, + position=0, + content="Test content for vector indexing", + word_count=100, + tokens=50, + index_node_id=f"node_{uuid.uuid4()}", + index_node_hash=f"hash_{uuid.uuid4()}", + created_by=account.id, + status="completed", + enabled=True, + ) + db_session_with_containers.add(segment) + db_session_with_containers.commit() + + # Execute update action + deal_dataset_vector_index_task(dataset.id, "update") + + # Verify document status was updated to indexing then completed + updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first() + assert updated_document.indexing_status == "completed" + + # Verify index processor clean and load methods were called + mock_factory = mock_index_processor_factory.return_value + mock_processor = mock_factory.init_index_processor.return_value + mock_processor.clean.assert_called_once_with(ANY, None, with_keywords=False, delete_child_chunks=False) + mock_processor.load.assert_called_once() + + def test_deal_dataset_vector_index_task_dataset_not_found_error( + self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + ): + """ + Test task behavior when dataset is not found. + + This test verifies that the task properly handles the case where + the specified dataset does not exist in the database. + """ + non_existent_dataset_id = str(uuid.uuid4()) + + # Execute task with non-existent dataset + deal_dataset_vector_index_task(non_existent_dataset_id, "add") + + # Verify that no index processor operations were performed + mock_factory = mock_index_processor_factory.return_value + mock_processor = mock_factory.init_index_processor.return_value + mock_processor.clean.assert_not_called() + mock_processor.load.assert_not_called() + + def test_deal_dataset_vector_index_task_add_action_no_documents( + self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + ): + """ + Test add action when no documents exist for the dataset. + + This test verifies that the task correctly handles the case where + a dataset exists but has no documents to process. + """ + fake = Faker() + + # Create test data + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create dataset without documents + dataset = Dataset( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + name=fake.company(), + description=fake.text(max_nb_chars=100), + data_source_type="file_import", + created_by=account.id, + ) + db_session_with_containers.add(dataset) + db_session_with_containers.commit() + + # Execute add action + deal_dataset_vector_index_task(dataset.id, "add") + + # Verify that no index processor operations were performed + mock_factory = mock_index_processor_factory.return_value + mock_processor = mock_factory.init_index_processor.return_value + mock_processor.load.assert_not_called() + + def test_deal_dataset_vector_index_task_add_action_no_segments( + self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + ): + """ + Test add action when documents exist but have no segments. + + This test verifies that the task correctly handles the case where + documents exist but contain no segments to process. + """ + fake = Faker() + + # Create test data + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create dataset + dataset = Dataset( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + name=fake.company(), + description=fake.text(max_nb_chars=100), + data_source_type="file_import", + created_by=account.id, + ) + db_session_with_containers.add(dataset) + db_session_with_containers.flush() + + # Create document without segments + document = Document( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=0, + data_source_type="file_import", + name="Test Document", + created_from="file_import", + created_by=account.id, + doc_form="text_model", + doc_language="en", + indexing_status="completed", + enabled=True, + archived=False, + batch="test_batch", + ) + db_session_with_containers.add(document) + db_session_with_containers.commit() + + # Execute add action + deal_dataset_vector_index_task(dataset.id, "add") + + # Verify document status was updated to indexing then completed + updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first() + assert updated_document.indexing_status == "completed" + + # Verify that no index processor load was called since no segments exist + mock_factory = mock_index_processor_factory.return_value + mock_processor = mock_factory.init_index_processor.return_value + mock_processor.load.assert_not_called() + + def test_deal_dataset_vector_index_task_update_action_no_documents( + self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + ): + """ + Test update action when no documents exist for the dataset. + + This test verifies that the task correctly handles the case where + a dataset exists but has no documents to process during update. + """ + fake = Faker() + + # Create test data + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create dataset without documents + dataset = Dataset( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + name=fake.company(), + description=fake.text(max_nb_chars=100), + data_source_type="file_import", + created_by=account.id, + ) + db_session_with_containers.add(dataset) + db_session_with_containers.commit() + + # Execute update action + deal_dataset_vector_index_task(dataset.id, "update") + + # Verify that index processor clean was called but no load + mock_factory = mock_index_processor_factory.return_value + mock_processor = mock_factory.init_index_processor.return_value + mock_processor.clean.assert_called_once_with(ANY, None, with_keywords=False, delete_child_chunks=False) + mock_processor.load.assert_not_called() + + def test_deal_dataset_vector_index_task_add_action_with_exception_handling( + self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + ): + """ + Test add action with exception handling during processing. + + This test verifies that the task correctly handles exceptions + during document processing and updates document status to error. + """ + fake = Faker() + + # Create test data + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create dataset + dataset = Dataset( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + name=fake.company(), + description=fake.text(max_nb_chars=100), + data_source_type="file_import", + created_by=account.id, + ) + db_session_with_containers.add(dataset) + db_session_with_containers.flush() + + # Create a document to set the doc_form property + document_for_doc_form = Document( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=0, + data_source_type="file_import", + name="Document for doc_form", + created_from="file_import", + created_by=account.id, + doc_form="text_model", + doc_language="en", + indexing_status="completed", + enabled=True, + archived=False, + batch="test_batch", + ) + db_session_with_containers.add(document_for_doc_form) + db_session_with_containers.flush() + + # Create document + document = Document( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=0, + data_source_type="file_import", + name="Test Document", + created_from="file_import", + created_by=account.id, + doc_form="text_model", + doc_language="en", + indexing_status="completed", + enabled=True, + archived=False, + batch="test_batch", + ) + db_session_with_containers.add(document) + db_session_with_containers.flush() + + # Create segments + segment = DocumentSegment( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + document_id=document.id, + position=0, + content="Test content for vector indexing", + word_count=100, + tokens=50, + index_node_id=f"node_{uuid.uuid4()}", + index_node_hash=f"hash_{uuid.uuid4()}", + created_by=account.id, + status="completed", + enabled=True, + ) + db_session_with_containers.add(segment) + db_session_with_containers.commit() + + # Mock index processor to raise exception during load + mock_factory = mock_index_processor_factory.return_value + mock_processor = mock_factory.init_index_processor.return_value + mock_processor.load.side_effect = Exception("Test exception during indexing") + + # Execute add action + deal_dataset_vector_index_task(dataset.id, "add") + + # Verify document status was updated to error + updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first() + assert updated_document.indexing_status == "error" + assert "Test exception during indexing" in updated_document.error + + def test_deal_dataset_vector_index_task_with_custom_index_type( + self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + ): + """ + Test task behavior with custom index type (QA_INDEX). + + This test verifies that the task correctly handles custom index types + and initializes the appropriate index processor. + """ + fake = Faker() + + # Create test data + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create dataset with custom index type + dataset = Dataset( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + name=fake.company(), + description=fake.text(max_nb_chars=100), + data_source_type="file_import", + created_by=account.id, + ) + db_session_with_containers.add(dataset) + db_session_with_containers.flush() + + # Create document + document = Document( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=0, + data_source_type="file_import", + name="Test Document", + created_from="file_import", + created_by=account.id, + doc_form="qa_index", + doc_language="en", + indexing_status="completed", + enabled=True, + archived=False, + batch="test_batch", + ) + db_session_with_containers.add(document) + db_session_with_containers.flush() + + # Create segments + segment = DocumentSegment( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + document_id=document.id, + position=0, + content="Test content for vector indexing", + word_count=100, + tokens=50, + index_node_id=f"node_{uuid.uuid4()}", + index_node_hash=f"hash_{uuid.uuid4()}", + created_by=account.id, + status="completed", + enabled=True, + ) + db_session_with_containers.add(segment) + db_session_with_containers.commit() + + # Execute add action + deal_dataset_vector_index_task(dataset.id, "add") + + # Verify document status was updated to indexing then completed + updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first() + assert updated_document.indexing_status == "completed" + + # Verify index processor was initialized with custom index type + mock_index_processor_factory.assert_called_once_with("qa_index") + mock_factory = mock_index_processor_factory.return_value + mock_processor = mock_factory.init_index_processor.return_value + mock_processor.load.assert_called_once() + + def test_deal_dataset_vector_index_task_with_default_index_type( + self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + ): + """ + Test task behavior with default index type (PARAGRAPH_INDEX). + + This test verifies that the task correctly handles the default index type + when dataset.doc_form is None. + """ + fake = Faker() + + # Create test data + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create dataset without doc_form (should use default) + dataset = Dataset( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + name=fake.company(), + description=fake.text(max_nb_chars=100), + data_source_type="file_import", + created_by=account.id, + ) + db_session_with_containers.add(dataset) + db_session_with_containers.flush() + + # Create document + document = Document( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=0, + data_source_type="file_import", + name="Test Document", + created_from="file_import", + created_by=account.id, + doc_form="text_model", + doc_language="en", + indexing_status="completed", + enabled=True, + archived=False, + batch="test_batch", + ) + db_session_with_containers.add(document) + db_session_with_containers.flush() + + # Create segments + segment = DocumentSegment( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + document_id=document.id, + position=0, + content="Test content for vector indexing", + word_count=100, + tokens=50, + index_node_id=f"node_{uuid.uuid4()}", + index_node_hash=f"hash_{uuid.uuid4()}", + created_by=account.id, + status="completed", + enabled=True, + ) + db_session_with_containers.add(segment) + db_session_with_containers.commit() + + # Execute add action + deal_dataset_vector_index_task(dataset.id, "add") + + # Verify document status was updated to indexing then completed + updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first() + assert updated_document.indexing_status == "completed" + + # Verify index processor was initialized with the document's index type + mock_index_processor_factory.assert_called_once_with("text_model") + mock_factory = mock_index_processor_factory.return_value + mock_processor = mock_factory.init_index_processor.return_value + mock_processor.load.assert_called_once() + + def test_deal_dataset_vector_index_task_multiple_documents_processing( + self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + ): + """ + Test task processing with multiple documents and segments. + + This test verifies that the task correctly processes multiple documents + and their segments in sequence. + """ + fake = Faker() + + # Create test data + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create dataset + dataset = Dataset( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + name=fake.company(), + description=fake.text(max_nb_chars=100), + data_source_type="file_import", + created_by=account.id, + ) + db_session_with_containers.add(dataset) + db_session_with_containers.flush() + + # Create a document to set the doc_form property + document_for_doc_form = Document( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=0, + data_source_type="file_import", + name="Document for doc_form", + created_from="file_import", + created_by=account.id, + doc_form="text_model", + doc_language="en", + indexing_status="completed", + enabled=True, + archived=False, + batch="test_batch", + ) + db_session_with_containers.add(document_for_doc_form) + db_session_with_containers.flush() + + # Create multiple documents + documents = [] + for i in range(3): + document = Document( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=i, + data_source_type="file_import", + name=f"Test Document {i}", + created_from="file_import", + created_by=account.id, + doc_form="text_model", + doc_language="en", + indexing_status="completed", + enabled=True, + archived=False, + batch="test_batch", + ) + db_session_with_containers.add(document) + documents.append(document) + + db_session_with_containers.flush() + + # Create segments for each document + for i, document in enumerate(documents): + for j in range(2): + segment = DocumentSegment( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + document_id=document.id, + position=j, + content=f"Content {i}-{j} for vector indexing", + word_count=100, + tokens=50, + index_node_id=f"node_{i}_{j}", + index_node_hash=f"hash_{i}_{j}", + created_by=account.id, + status="completed", + enabled=True, + ) + db_session_with_containers.add(segment) + + db_session_with_containers.commit() + + # Execute add action + deal_dataset_vector_index_task(dataset.id, "add") + + # Verify all documents were processed + for document in documents: + updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first() + assert updated_document.indexing_status == "completed" + + # Verify index processor load was called multiple times + mock_factory = mock_index_processor_factory.return_value + mock_processor = mock_factory.init_index_processor.return_value + assert mock_processor.load.call_count == 3 + + def test_deal_dataset_vector_index_task_document_status_transitions( + self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + ): + """ + Test document status transitions during task execution. + + This test verifies that document status correctly transitions from + 'completed' to 'indexing' and back to 'completed' during processing. + """ + fake = Faker() + + # Create test data + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create dataset + dataset = Dataset( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + name=fake.company(), + description=fake.text(max_nb_chars=100), + data_source_type="file_import", + created_by=account.id, + ) + db_session_with_containers.add(dataset) + db_session_with_containers.flush() + + # Create a document to set the doc_form property + document_for_doc_form = Document( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=0, + data_source_type="file_import", + name="Document for doc_form", + created_from="file_import", + created_by=account.id, + doc_form="text_model", + doc_language="en", + indexing_status="completed", + enabled=True, + archived=False, + batch="test_batch", + ) + db_session_with_containers.add(document_for_doc_form) + db_session_with_containers.flush() + + # Create document + document = Document( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=0, + data_source_type="file_import", + name="Test Document", + created_from="file_import", + created_by=account.id, + doc_form="text_model", + doc_language="en", + indexing_status="completed", + enabled=True, + archived=False, + batch="test_batch", + ) + db_session_with_containers.add(document) + db_session_with_containers.flush() + + # Create segments + segment = DocumentSegment( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + document_id=document.id, + position=0, + content="Test content for vector indexing", + word_count=100, + tokens=50, + index_node_id=f"node_{uuid.uuid4()}", + index_node_hash=f"hash_{uuid.uuid4()}", + created_by=account.id, + status="completed", + enabled=True, + ) + db_session_with_containers.add(segment) + db_session_with_containers.commit() + + # Mock index processor to capture intermediate state + mock_factory = mock_index_processor_factory.return_value + mock_processor = mock_factory.init_index_processor.return_value + + # Mock the load method to simulate successful processing + mock_processor.load.return_value = None + + # Execute add action + deal_dataset_vector_index_task(dataset.id, "add") + + # Verify final document status + updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first() + assert updated_document.indexing_status == "completed" + + def test_deal_dataset_vector_index_task_with_disabled_documents( + self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + ): + """ + Test task behavior with disabled documents. + + This test verifies that the task correctly skips disabled documents + during processing. + """ + fake = Faker() + + # Create test data + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create dataset + dataset = Dataset( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + name=fake.company(), + description=fake.text(max_nb_chars=100), + data_source_type="file_import", + created_by=account.id, + ) + db_session_with_containers.add(dataset) + db_session_with_containers.flush() + + # Create a document to set the doc_form property + document_for_doc_form = Document( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=0, + data_source_type="file_import", + name="Document for doc_form", + created_from="file_import", + created_by=account.id, + doc_form="text_model", + doc_language="en", + indexing_status="completed", + enabled=True, + archived=False, + batch="test_batch", + ) + db_session_with_containers.add(document_for_doc_form) + db_session_with_containers.flush() + + # Create enabled document + enabled_document = Document( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=0, + data_source_type="file_import", + name="Enabled Document", + created_from="file_import", + created_by=account.id, + doc_form="text_model", + doc_language="en", + indexing_status="completed", + enabled=True, + archived=False, + batch="test_batch", + ) + db_session_with_containers.add(enabled_document) + + # Create disabled document + disabled_document = Document( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=1, + data_source_type="file_import", + name="Disabled Document", + created_from="file_import", + created_by=account.id, + doc_form="text_model", + doc_language="en", + indexing_status="completed", + enabled=False, # This document should be skipped + archived=False, + batch="test_batch", + ) + db_session_with_containers.add(disabled_document) + + db_session_with_containers.flush() + + # Create segments for enabled document only + segment = DocumentSegment( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + document_id=enabled_document.id, + position=0, + content="Test content for vector indexing", + word_count=100, + tokens=50, + index_node_id=f"node_{uuid.uuid4()}", + index_node_hash=f"hash_{uuid.uuid4()}", + created_by=account.id, + status="completed", + enabled=True, + ) + db_session_with_containers.add(segment) + db_session_with_containers.commit() + + # Execute add action + deal_dataset_vector_index_task(dataset.id, "add") + + # Verify only enabled document was processed + updated_enabled_document = db_session_with_containers.query(Document).filter_by(id=enabled_document.id).first() + assert updated_enabled_document.indexing_status == "completed" + + # Verify disabled document status remains unchanged + updated_disabled_document = ( + db_session_with_containers.query(Document).filter_by(id=disabled_document.id).first() + ) + assert updated_disabled_document.indexing_status == "completed" # Should not change + + # Verify index processor load was called only once (for enabled document) + mock_factory = mock_index_processor_factory.return_value + mock_processor = mock_factory.init_index_processor.return_value + mock_processor.load.assert_called_once() + + def test_deal_dataset_vector_index_task_with_archived_documents( + self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + ): + """ + Test task behavior with archived documents. + + This test verifies that the task correctly skips archived documents + during processing. + """ + fake = Faker() + + # Create test data + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create dataset + dataset = Dataset( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + name=fake.company(), + description=fake.text(max_nb_chars=100), + data_source_type="file_import", + created_by=account.id, + ) + db_session_with_containers.add(dataset) + db_session_with_containers.flush() + + # Create a document to set the doc_form property + document_for_doc_form = Document( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=0, + data_source_type="file_import", + name="Document for doc_form", + created_from="file_import", + created_by=account.id, + doc_form="text_model", + doc_language="en", + indexing_status="completed", + enabled=True, + archived=False, + batch="test_batch", + ) + db_session_with_containers.add(document_for_doc_form) + db_session_with_containers.flush() + + # Create active document + active_document = Document( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=0, + data_source_type="file_import", + name="Active Document", + created_from="file_import", + created_by=account.id, + doc_form="text_model", + doc_language="en", + indexing_status="completed", + enabled=True, + archived=False, + batch="test_batch", + ) + db_session_with_containers.add(active_document) + + # Create archived document + archived_document = Document( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=1, + data_source_type="file_import", + name="Archived Document", + created_from="file_import", + created_by=account.id, + doc_form="text_model", + doc_language="en", + indexing_status="completed", + enabled=True, + archived=True, # This document should be skipped + batch="test_batch", + ) + db_session_with_containers.add(archived_document) + + db_session_with_containers.flush() + + # Create segments for active document only + segment = DocumentSegment( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + document_id=active_document.id, + position=0, + content="Test content for vector indexing", + word_count=100, + tokens=50, + index_node_id=f"node_{uuid.uuid4()}", + index_node_hash=f"hash_{uuid.uuid4()}", + created_by=account.id, + status="completed", + enabled=True, + ) + db_session_with_containers.add(segment) + db_session_with_containers.commit() + + # Execute add action + deal_dataset_vector_index_task(dataset.id, "add") + + # Verify only active document was processed + updated_active_document = db_session_with_containers.query(Document).filter_by(id=active_document.id).first() + assert updated_active_document.indexing_status == "completed" + + # Verify archived document status remains unchanged + updated_archived_document = ( + db_session_with_containers.query(Document).filter_by(id=archived_document.id).first() + ) + assert updated_archived_document.indexing_status == "completed" # Should not change + + # Verify index processor load was called only once (for active document) + mock_factory = mock_index_processor_factory.return_value + mock_processor = mock_factory.init_index_processor.return_value + mock_processor.load.assert_called_once() + + def test_deal_dataset_vector_index_task_with_incomplete_documents( + self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + ): + """ + Test task behavior with documents that have incomplete indexing status. + + This test verifies that the task correctly skips documents with + incomplete indexing status during processing. + """ + fake = Faker() + + # Create test data + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create dataset + dataset = Dataset( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + name=fake.company(), + description=fake.text(max_nb_chars=100), + data_source_type="file_import", + created_by=account.id, + ) + db_session_with_containers.add(dataset) + db_session_with_containers.flush() + + # Create a document to set the doc_form property + document_for_doc_form = Document( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=0, + data_source_type="file_import", + name="Document for doc_form", + created_from="file_import", + created_by=account.id, + doc_form="text_model", + doc_language="en", + indexing_status="completed", + enabled=True, + archived=False, + batch="test_batch", + ) + db_session_with_containers.add(document_for_doc_form) + db_session_with_containers.flush() + + # Create completed document + completed_document = Document( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=0, + data_source_type="file_import", + name="Completed Document", + created_from="file_import", + created_by=account.id, + doc_form="text_model", + doc_language="en", + indexing_status="completed", + enabled=True, + archived=False, + batch="test_batch", + ) + db_session_with_containers.add(completed_document) + + # Create incomplete document + incomplete_document = Document( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=1, + data_source_type="file_import", + name="Incomplete Document", + created_from="file_import", + created_by=account.id, + doc_form="text_model", + doc_language="en", + indexing_status="indexing", # This document should be skipped + enabled=True, + archived=False, + batch="test_batch", + ) + db_session_with_containers.add(incomplete_document) + + db_session_with_containers.flush() + + # Create segments for completed document only + segment = DocumentSegment( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + document_id=completed_document.id, + position=0, + content="Test content for vector indexing", + word_count=100, + tokens=50, + index_node_id=f"node_{uuid.uuid4()}", + index_node_hash=f"hash_{uuid.uuid4()}", + created_by=account.id, + status="completed", + enabled=True, + ) + db_session_with_containers.add(segment) + db_session_with_containers.commit() + + # Execute add action + deal_dataset_vector_index_task(dataset.id, "add") + + # Verify only completed document was processed + updated_completed_document = ( + db_session_with_containers.query(Document).filter_by(id=completed_document.id).first() + ) + assert updated_completed_document.indexing_status == "completed" + + # Verify incomplete document status remains unchanged + updated_incomplete_document = ( + db_session_with_containers.query(Document).filter_by(id=incomplete_document.id).first() + ) + assert updated_incomplete_document.indexing_status == "indexing" # Should not change + + # Verify index processor load was called only once (for completed document) + mock_factory = mock_index_processor_factory.return_value + mock_processor = mock_factory.init_index_processor.return_value + mock_processor.load.assert_called_once() diff --git a/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py new file mode 100644 index 0000000000..7af4f238be --- /dev/null +++ b/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py @@ -0,0 +1,583 @@ +""" +TestContainers-based integration tests for delete_segment_from_index_task. + +This module provides comprehensive integration testing for the delete_segment_from_index_task +using TestContainers to ensure realistic database interactions and proper isolation. +The task is responsible for removing document segments from the vector index when segments +are deleted from the dataset. +""" + +import logging +from unittest.mock import MagicMock, patch + +from faker import Faker + +from core.rag.index_processor.constant.index_type import IndexType +from models import Account, Dataset, Document, DocumentSegment, Tenant +from tasks.delete_segment_from_index_task import delete_segment_from_index_task + +logger = logging.getLogger(__name__) + + +class TestDeleteSegmentFromIndexTask: + """ + Comprehensive integration tests for delete_segment_from_index_task using testcontainers. + + This test class covers all major functionality of the delete_segment_from_index_task: + - Successful segment deletion from index + - Dataset not found scenarios + - Document not found scenarios + - Document status validation (disabled, archived, not completed) + - Index processor integration and cleanup + - Exception handling and error scenarios + - Performance and timing verification + + All tests use the testcontainers infrastructure to ensure proper database isolation + and realistic testing environment with actual database interactions. + """ + + def _create_test_tenant(self, db_session_with_containers, fake=None): + """ + Helper method to create a test tenant with realistic data. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + fake: Faker instance for generating test data + + Returns: + Tenant: Created test tenant instance + """ + fake = fake or Faker() + tenant = Tenant() + tenant.id = fake.uuid4() + tenant.name = f"Test Tenant {fake.company()}" + tenant.plan = "basic" + tenant.status = "active" + tenant.created_at = fake.date_time_this_year() + tenant.updated_at = tenant.created_at + + db_session_with_containers.add(tenant) + db_session_with_containers.commit() + return tenant + + def _create_test_account(self, db_session_with_containers, tenant, fake=None): + """ + Helper method to create a test account with realistic data. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + tenant: Tenant instance for the account + fake: Faker instance for generating test data + + Returns: + Account: Created test account instance + """ + fake = fake or Faker() + account = Account() + account.id = fake.uuid4() + account.email = fake.email() + account.name = fake.name() + account.avatar_url = fake.url() + account.tenant_id = tenant.id + account.status = "active" + account.type = "normal" + account.role = "owner" + account.interface_language = "en-US" + account.created_at = fake.date_time_this_year() + account.updated_at = account.created_at + + db_session_with_containers.add(account) + db_session_with_containers.commit() + return account + + def _create_test_dataset(self, db_session_with_containers, tenant, account, fake=None): + """ + Helper method to create a test dataset with realistic data. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + tenant: Tenant instance for the dataset + account: Account instance for the dataset + fake: Faker instance for generating test data + + Returns: + Dataset: Created test dataset instance + """ + fake = fake or Faker() + dataset = Dataset() + dataset.id = fake.uuid4() + dataset.tenant_id = tenant.id + dataset.name = f"Test Dataset {fake.word()}" + dataset.description = fake.text(max_nb_chars=200) + dataset.provider = "vendor" + dataset.permission = "only_me" + dataset.data_source_type = "upload_file" + dataset.indexing_technique = "high_quality" + dataset.index_struct = '{"type": "paragraph"}' + dataset.created_by = account.id + dataset.created_at = fake.date_time_this_year() + dataset.updated_by = account.id + dataset.updated_at = dataset.created_at + dataset.embedding_model = "text-embedding-ada-002" + dataset.embedding_model_provider = "openai" + dataset.built_in_field_enabled = False + + db_session_with_containers.add(dataset) + db_session_with_containers.commit() + return dataset + + def _create_test_document(self, db_session_with_containers, dataset, account, fake=None, **kwargs): + """ + Helper method to create a test document with realistic data. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + dataset: Dataset instance for the document + account: Account instance for the document + fake: Faker instance for generating test data + **kwargs: Additional document attributes to override defaults + + Returns: + Document: Created test document instance + """ + fake = fake or Faker() + document = Document() + document.id = fake.uuid4() + document.tenant_id = dataset.tenant_id + document.dataset_id = dataset.id + document.position = kwargs.get("position", 1) + document.data_source_type = kwargs.get("data_source_type", "upload_file") + document.data_source_info = kwargs.get("data_source_info", "{}") + document.batch = kwargs.get("batch", fake.uuid4()) + document.name = kwargs.get("name", f"Test Document {fake.word()}") + document.created_from = kwargs.get("created_from", "api") + document.created_by = account.id + document.created_at = fake.date_time_this_year() + document.processing_started_at = kwargs.get("processing_started_at", fake.date_time_this_year()) + document.file_id = kwargs.get("file_id", fake.uuid4()) + document.word_count = kwargs.get("word_count", fake.random_int(min=100, max=1000)) + document.parsing_completed_at = kwargs.get("parsing_completed_at", fake.date_time_this_year()) + document.cleaning_completed_at = kwargs.get("cleaning_completed_at", fake.date_time_this_year()) + document.splitting_completed_at = kwargs.get("splitting_completed_at", fake.date_time_this_year()) + document.tokens = kwargs.get("tokens", fake.random_int(min=50, max=500)) + document.indexing_latency = kwargs.get("indexing_latency", fake.random_number(digits=3)) + document.completed_at = kwargs.get("completed_at", fake.date_time_this_year()) + document.is_paused = kwargs.get("is_paused", False) + document.indexing_status = kwargs.get("indexing_status", "completed") + document.enabled = kwargs.get("enabled", True) + document.archived = kwargs.get("archived", False) + document.updated_at = fake.date_time_this_year() + document.doc_type = kwargs.get("doc_type", "text") + document.doc_metadata = kwargs.get("doc_metadata", {}) + document.doc_form = kwargs.get("doc_form", IndexType.PARAGRAPH_INDEX) + document.doc_language = kwargs.get("doc_language", "en") + + db_session_with_containers.add(document) + db_session_with_containers.commit() + return document + + def _create_test_document_segments(self, db_session_with_containers, document, account, count=3, fake=None): + """ + Helper method to create test document segments with realistic data. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + document: Document instance for the segments + account: Account instance for the segments + count: Number of segments to create + fake: Faker instance for generating test data + + Returns: + list[DocumentSegment]: List of created test document segment instances + """ + fake = fake or Faker() + segments = [] + + for i in range(count): + segment = DocumentSegment() + segment.id = fake.uuid4() + segment.tenant_id = document.tenant_id + segment.dataset_id = document.dataset_id + segment.document_id = document.id + segment.position = i + 1 + segment.content = f"Test segment content {i + 1}: {fake.text(max_nb_chars=200)}" + segment.answer = f"Test segment answer {i + 1}: {fake.text(max_nb_chars=100)}" + segment.word_count = fake.random_int(min=10, max=100) + segment.tokens = fake.random_int(min=5, max=50) + segment.keywords = [fake.word() for _ in range(3)] + segment.index_node_id = f"node_{fake.uuid4()}" + segment.index_node_hash = fake.sha256() + segment.hit_count = 0 + segment.enabled = True + segment.status = "completed" + segment.created_by = account.id + segment.created_at = fake.date_time_this_year() + segment.updated_by = account.id + segment.updated_at = segment.created_at + + db_session_with_containers.add(segment) + segments.append(segment) + + db_session_with_containers.commit() + return segments + + @patch("tasks.delete_segment_from_index_task.IndexProcessorFactory") + def test_delete_segment_from_index_task_success(self, mock_index_processor_factory, db_session_with_containers): + """ + Test successful segment deletion from index with comprehensive verification. + + This test verifies: + - Proper task execution with valid dataset and document + - Index processor factory initialization with correct document form + - Index processor clean method called with correct parameters + - Database session properly closed after execution + - Task completes without exceptions + """ + fake = Faker() + + # Create test data + tenant = self._create_test_tenant(db_session_with_containers, fake) + account = self._create_test_account(db_session_with_containers, tenant, fake) + dataset = self._create_test_dataset(db_session_with_containers, tenant, account, fake) + document = self._create_test_document(db_session_with_containers, dataset, account, fake) + segments = self._create_test_document_segments(db_session_with_containers, document, account, 3, fake) + + # Extract index node IDs for the task + index_node_ids = [segment.index_node_id for segment in segments] + + # Mock the index processor + mock_processor = MagicMock() + mock_index_processor_factory.return_value.init_index_processor.return_value = mock_processor + + # Execute the task + result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id) + + # Verify the task completed successfully + assert result is None # Task should return None on success + + # Verify index processor factory was called with correct document form + mock_index_processor_factory.assert_called_once_with(document.doc_form) + + # Verify index processor clean method was called with correct parameters + # Note: We can't directly compare Dataset objects as they are different instances + # from database queries, so we verify the call was made and check the parameters + assert mock_processor.clean.call_count == 1 + call_args = mock_processor.clean.call_args + assert call_args[0][0].id == dataset.id # Verify dataset ID matches + assert call_args[0][1] == index_node_ids # Verify index node IDs match + assert call_args[1]["with_keywords"] is True + assert call_args[1]["delete_child_chunks"] is True + + def test_delete_segment_from_index_task_dataset_not_found(self, db_session_with_containers): + """ + Test task behavior when dataset is not found. + + This test verifies: + - Task handles missing dataset gracefully + - No index processor operations are attempted + - Task returns early without exceptions + - Database session is properly closed + """ + fake = Faker() + non_existent_dataset_id = fake.uuid4() + non_existent_document_id = fake.uuid4() + index_node_ids = [f"node_{fake.uuid4()}" for _ in range(3)] + + # Execute the task with non-existent dataset + result = delete_segment_from_index_task(index_node_ids, non_existent_dataset_id, non_existent_document_id) + + # Verify the task completed without exceptions + assert result is None # Task should return None when dataset not found + + def test_delete_segment_from_index_task_document_not_found(self, db_session_with_containers): + """ + Test task behavior when document is not found. + + This test verifies: + - Task handles missing document gracefully + - No index processor operations are attempted + - Task returns early without exceptions + - Database session is properly closed + """ + fake = Faker() + + # Create test data + tenant = self._create_test_tenant(db_session_with_containers, fake) + account = self._create_test_account(db_session_with_containers, tenant, fake) + dataset = self._create_test_dataset(db_session_with_containers, tenant, account, fake) + + non_existent_document_id = fake.uuid4() + index_node_ids = [f"node_{fake.uuid4()}" for _ in range(3)] + + # Execute the task with non-existent document + result = delete_segment_from_index_task(index_node_ids, dataset.id, non_existent_document_id) + + # Verify the task completed without exceptions + assert result is None # Task should return None when document not found + + def test_delete_segment_from_index_task_document_disabled(self, db_session_with_containers): + """ + Test task behavior when document is disabled. + + This test verifies: + - Task handles disabled document gracefully + - No index processor operations are attempted + - Task returns early without exceptions + - Database session is properly closed + """ + fake = Faker() + + # Create test data with disabled document + tenant = self._create_test_tenant(db_session_with_containers, fake) + account = self._create_test_account(db_session_with_containers, tenant, fake) + dataset = self._create_test_dataset(db_session_with_containers, tenant, account, fake) + document = self._create_test_document(db_session_with_containers, dataset, account, fake, enabled=False) + segments = self._create_test_document_segments(db_session_with_containers, document, account, 3, fake) + + index_node_ids = [segment.index_node_id for segment in segments] + + # Execute the task with disabled document + result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id) + + # Verify the task completed without exceptions + assert result is None # Task should return None when document is disabled + + def test_delete_segment_from_index_task_document_archived(self, db_session_with_containers): + """ + Test task behavior when document is archived. + + This test verifies: + - Task handles archived document gracefully + - No index processor operations are attempted + - Task returns early without exceptions + - Database session is properly closed + """ + fake = Faker() + + # Create test data with archived document + tenant = self._create_test_tenant(db_session_with_containers, fake) + account = self._create_test_account(db_session_with_containers, tenant, fake) + dataset = self._create_test_dataset(db_session_with_containers, tenant, account, fake) + document = self._create_test_document(db_session_with_containers, dataset, account, fake, archived=True) + segments = self._create_test_document_segments(db_session_with_containers, document, account, 3, fake) + + index_node_ids = [segment.index_node_id for segment in segments] + + # Execute the task with archived document + result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id) + + # Verify the task completed without exceptions + assert result is None # Task should return None when document is archived + + def test_delete_segment_from_index_task_document_not_completed(self, db_session_with_containers): + """ + Test task behavior when document indexing is not completed. + + This test verifies: + - Task handles incomplete indexing status gracefully + - No index processor operations are attempted + - Task returns early without exceptions + - Database session is properly closed + """ + fake = Faker() + + # Create test data with incomplete indexing + tenant = self._create_test_tenant(db_session_with_containers, fake) + account = self._create_test_account(db_session_with_containers, tenant, fake) + dataset = self._create_test_dataset(db_session_with_containers, tenant, account, fake) + document = self._create_test_document( + db_session_with_containers, dataset, account, fake, indexing_status="indexing" + ) + segments = self._create_test_document_segments(db_session_with_containers, document, account, 3, fake) + + index_node_ids = [segment.index_node_id for segment in segments] + + # Execute the task with incomplete indexing + result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id) + + # Verify the task completed without exceptions + assert result is None # Task should return None when indexing is not completed + + @patch("tasks.delete_segment_from_index_task.IndexProcessorFactory") + def test_delete_segment_from_index_task_index_processor_clean( + self, mock_index_processor_factory, db_session_with_containers + ): + """ + Test index processor clean method integration with different document forms. + + This test verifies: + - Index processor factory creates correct processor for different document forms + - Clean method is called with proper parameters for each document form + - Task handles different index types correctly + - Database session is properly managed + """ + fake = Faker() + + # Test different document forms + document_forms = [IndexType.PARAGRAPH_INDEX, IndexType.QA_INDEX, IndexType.PARENT_CHILD_INDEX] + + for doc_form in document_forms: + # Create test data for each document form + tenant = self._create_test_tenant(db_session_with_containers, fake) + account = self._create_test_account(db_session_with_containers, tenant, fake) + dataset = self._create_test_dataset(db_session_with_containers, tenant, account, fake) + document = self._create_test_document(db_session_with_containers, dataset, account, fake, doc_form=doc_form) + segments = self._create_test_document_segments(db_session_with_containers, document, account, 2, fake) + + index_node_ids = [segment.index_node_id for segment in segments] + + # Mock the index processor + mock_processor = MagicMock() + mock_index_processor_factory.return_value.init_index_processor.return_value = mock_processor + + # Execute the task + result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id) + + # Verify the task completed successfully + assert result is None + + # Verify index processor factory was called with correct document form + mock_index_processor_factory.assert_called_with(doc_form) + + # Verify index processor clean method was called with correct parameters + assert mock_processor.clean.call_count == 1 + call_args = mock_processor.clean.call_args + assert call_args[0][0].id == dataset.id # Verify dataset ID matches + assert call_args[0][1] == index_node_ids # Verify index node IDs match + assert call_args[1]["with_keywords"] is True + assert call_args[1]["delete_child_chunks"] is True + + # Reset mocks for next iteration + mock_index_processor_factory.reset_mock() + mock_processor.reset_mock() + + @patch("tasks.delete_segment_from_index_task.IndexProcessorFactory") + def test_delete_segment_from_index_task_exception_handling( + self, mock_index_processor_factory, db_session_with_containers + ): + """ + Test exception handling in the task. + + This test verifies: + - Task handles index processor exceptions gracefully + - Database session is properly closed even when exceptions occur + - Task logs exceptions appropriately + - No unhandled exceptions are raised + """ + fake = Faker() + + # Create test data + tenant = self._create_test_tenant(db_session_with_containers, fake) + account = self._create_test_account(db_session_with_containers, tenant, fake) + dataset = self._create_test_dataset(db_session_with_containers, tenant, account, fake) + document = self._create_test_document(db_session_with_containers, dataset, account, fake) + segments = self._create_test_document_segments(db_session_with_containers, document, account, 3, fake) + + index_node_ids = [segment.index_node_id for segment in segments] + + # Mock the index processor to raise an exception + mock_processor = MagicMock() + mock_processor.clean.side_effect = Exception("Index processor error") + mock_index_processor_factory.return_value.init_index_processor.return_value = mock_processor + + # Execute the task - should not raise exception + result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id) + + # Verify the task completed without raising exceptions + assert result is None # Task should return None even when exceptions occur + + # Verify index processor clean method was called + assert mock_processor.clean.call_count == 1 + call_args = mock_processor.clean.call_args + assert call_args[0][0].id == dataset.id # Verify dataset ID matches + assert call_args[0][1] == index_node_ids # Verify index node IDs match + assert call_args[1]["with_keywords"] is True + assert call_args[1]["delete_child_chunks"] is True + + @patch("tasks.delete_segment_from_index_task.IndexProcessorFactory") + def test_delete_segment_from_index_task_empty_index_node_ids( + self, mock_index_processor_factory, db_session_with_containers + ): + """ + Test task behavior with empty index node IDs list. + + This test verifies: + - Task handles empty index node IDs gracefully + - Index processor clean method is called with empty list + - Task completes successfully + - Database session is properly managed + """ + fake = Faker() + + # Create test data + tenant = self._create_test_tenant(db_session_with_containers, fake) + account = self._create_test_account(db_session_with_containers, tenant, fake) + dataset = self._create_test_dataset(db_session_with_containers, tenant, account, fake) + document = self._create_test_document(db_session_with_containers, dataset, account, fake) + + # Use empty index node IDs + index_node_ids = [] + + # Mock the index processor + mock_processor = MagicMock() + mock_index_processor_factory.return_value.init_index_processor.return_value = mock_processor + + # Execute the task + result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id) + + # Verify the task completed successfully + assert result is None + + # Verify index processor clean method was called with empty list + assert mock_processor.clean.call_count == 1 + call_args = mock_processor.clean.call_args + assert call_args[0][0].id == dataset.id # Verify dataset ID matches + assert call_args[0][1] == index_node_ids # Verify index node IDs match (empty list) + assert call_args[1]["with_keywords"] is True + assert call_args[1]["delete_child_chunks"] is True + + @patch("tasks.delete_segment_from_index_task.IndexProcessorFactory") + def test_delete_segment_from_index_task_large_index_node_ids( + self, mock_index_processor_factory, db_session_with_containers + ): + """ + Test task behavior with large number of index node IDs. + + This test verifies: + - Task handles large lists of index node IDs efficiently + - Index processor clean method is called with all node IDs + - Task completes successfully with large datasets + - Database session is properly managed + """ + fake = Faker() + + # Create test data + tenant = self._create_test_tenant(db_session_with_containers, fake) + account = self._create_test_account(db_session_with_containers, tenant, fake) + dataset = self._create_test_dataset(db_session_with_containers, tenant, account, fake) + document = self._create_test_document(db_session_with_containers, dataset, account, fake) + + # Create large number of segments + segments = self._create_test_document_segments(db_session_with_containers, document, account, 50, fake) + index_node_ids = [segment.index_node_id for segment in segments] + + # Mock the index processor + mock_processor = MagicMock() + mock_index_processor_factory.return_value.init_index_processor.return_value = mock_processor + + # Execute the task + result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id) + + # Verify the task completed successfully + assert result is None + + # Verify index processor clean method was called with all node IDs + assert mock_processor.clean.call_count == 1 + call_args = mock_processor.clean.call_args + assert call_args[0][0].id == dataset.id # Verify dataset ID matches + assert call_args[0][1] == index_node_ids # Verify index node IDs match + assert call_args[1]["with_keywords"] is True + assert call_args[1]["delete_child_chunks"] is True + + # Verify all node IDs were passed + assert len(call_args[0][1]) == 50 diff --git a/api/tests/test_containers_integration_tests/tasks/test_disable_segment_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_disable_segment_from_index_task.py new file mode 100644 index 0000000000..e1d63e993b --- /dev/null +++ b/api/tests/test_containers_integration_tests/tasks/test_disable_segment_from_index_task.py @@ -0,0 +1,615 @@ +""" +Integration tests for disable_segment_from_index_task using TestContainers. + +This module provides comprehensive integration tests for the disable_segment_from_index_task +using real database and Redis containers to ensure the task works correctly with actual +data and external dependencies. +""" + +import logging +import time +from datetime import UTC, datetime +from unittest.mock import patch + +import pytest +from faker import Faker + +from extensions.ext_database import db +from extensions.ext_redis import redis_client +from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models.dataset import Dataset, Document, DocumentSegment +from tasks.disable_segment_from_index_task import disable_segment_from_index_task + +logger = logging.getLogger(__name__) + + +class TestDisableSegmentFromIndexTask: + """Integration tests for disable_segment_from_index_task using testcontainers.""" + + @pytest.fixture + def mock_index_processor(self): + """Mock IndexProcessorFactory and its clean method.""" + with patch("tasks.disable_segment_from_index_task.IndexProcessorFactory") as mock_factory: + mock_processor = mock_factory.return_value.init_index_processor.return_value + mock_processor.clean.return_value = None + yield mock_processor + + def _create_test_account_and_tenant(self, db_session_with_containers) -> tuple[Account, Tenant]: + """ + Helper method to create a test account and tenant for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + + Returns: + tuple: (account, tenant) - Created account and tenant instances + """ + fake = Faker() + + # Create account + account = Account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + status="active", + ) + db.session.add(account) + db.session.commit() + + # Create tenant + tenant = Tenant( + name=fake.company(), + status="normal", + plan="basic", + ) + db.session.add(tenant) + db.session.commit() + + # Create tenant-account join with owner role + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=TenantAccountRole.OWNER.value, + current=True, + ) + db.session.add(join) + db.session.commit() + + # Set current tenant for account + account.current_tenant = tenant + + return account, tenant + + def _create_test_dataset(self, tenant: Tenant, account: Account) -> Dataset: + """ + Helper method to create a test dataset. + + Args: + tenant: Tenant instance + account: Account instance + + Returns: + Dataset: Created dataset instance + """ + fake = Faker() + + dataset = Dataset( + tenant_id=tenant.id, + name=fake.sentence(nb_words=3), + description=fake.text(max_nb_chars=200), + data_source_type="upload_file", + indexing_technique="high_quality", + created_by=account.id, + ) + db.session.add(dataset) + db.session.commit() + + return dataset + + def _create_test_document( + self, dataset: Dataset, tenant: Tenant, account: Account, doc_form: str = "text_model" + ) -> Document: + """ + Helper method to create a test document. + + Args: + dataset: Dataset instance + tenant: Tenant instance + account: Account instance + doc_form: Document form type + + Returns: + Document: Created document instance + """ + fake = Faker() + + document = Document( + tenant_id=tenant.id, + dataset_id=dataset.id, + position=1, + data_source_type="upload_file", + batch=fake.uuid4(), + name=fake.file_name(), + created_from="api", + created_by=account.id, + indexing_status="completed", + enabled=True, + archived=False, + doc_form=doc_form, + word_count=1000, + tokens=500, + completed_at=datetime.now(UTC), + ) + db.session.add(document) + db.session.commit() + + return document + + def _create_test_segment( + self, + document: Document, + dataset: Dataset, + tenant: Tenant, + account: Account, + status: str = "completed", + enabled: bool = True, + ) -> DocumentSegment: + """ + Helper method to create a test document segment. + + Args: + document: Document instance + dataset: Dataset instance + tenant: Tenant instance + account: Account instance + status: Segment status + enabled: Whether segment is enabled + + Returns: + DocumentSegment: Created segment instance + """ + fake = Faker() + + segment = DocumentSegment( + tenant_id=tenant.id, + dataset_id=dataset.id, + document_id=document.id, + position=1, + content=fake.text(max_nb_chars=500), + word_count=100, + tokens=50, + index_node_id=fake.uuid4(), + index_node_hash=fake.sha256(), + status=status, + enabled=enabled, + created_by=account.id, + completed_at=datetime.now(UTC) if status == "completed" else None, + ) + db.session.add(segment) + db.session.commit() + + return segment + + def test_disable_segment_success(self, db_session_with_containers, mock_index_processor): + """ + Test successful segment disabling from index. + + This test verifies: + - Segment is found and validated + - Index processor clean method is called with correct parameters + - Redis cache is cleared + - Task completes successfully + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset = self._create_test_dataset(tenant, account) + document = self._create_test_document(dataset, tenant, account) + segment = self._create_test_segment(document, dataset, tenant, account) + + # Set up Redis cache + indexing_cache_key = f"segment_{segment.id}_indexing" + redis_client.setex(indexing_cache_key, 600, 1) + + # Act: Execute the task + result = disable_segment_from_index_task(segment.id) + + # Assert: Verify the task completed successfully + assert result is None # Task returns None on success + + # Verify index processor was called correctly + mock_index_processor.clean.assert_called_once() + call_args = mock_index_processor.clean.call_args + assert call_args[0][0].id == dataset.id # Check dataset ID + assert call_args[0][1] == [segment.index_node_id] # Check index node IDs + + # Verify Redis cache was cleared + assert redis_client.get(indexing_cache_key) is None + + # Verify segment is still in database + db.session.refresh(segment) + assert segment.id is not None + + def test_disable_segment_not_found(self, db_session_with_containers, mock_index_processor): + """ + Test handling when segment is not found. + + This test verifies: + - Task handles non-existent segment gracefully + - No index processor operations are performed + - Task returns early without errors + """ + # Arrange: Use a non-existent segment ID + fake = Faker() + non_existent_segment_id = fake.uuid4() + + # Act: Execute the task with non-existent segment + result = disable_segment_from_index_task(non_existent_segment_id) + + # Assert: Verify the task handled the error gracefully + assert result is None + + # Verify index processor was not called + mock_index_processor.clean.assert_not_called() + + def test_disable_segment_not_completed(self, db_session_with_containers, mock_index_processor): + """ + Test handling when segment is not in completed status. + + This test verifies: + - Task rejects segments that are not completed + - No index processor operations are performed + - Task returns early without errors + """ + # Arrange: Create test data with non-completed segment + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset = self._create_test_dataset(tenant, account) + document = self._create_test_document(dataset, tenant, account) + segment = self._create_test_segment(document, dataset, tenant, account, status="indexing", enabled=True) + + # Act: Execute the task + result = disable_segment_from_index_task(segment.id) + + # Assert: Verify the task handled the invalid status gracefully + assert result is None + + # Verify index processor was not called + mock_index_processor.clean.assert_not_called() + + def test_disable_segment_no_dataset(self, db_session_with_containers, mock_index_processor): + """ + Test handling when segment has no associated dataset. + + This test verifies: + - Task handles segments without dataset gracefully + - No index processor operations are performed + - Task returns early without errors + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset = self._create_test_dataset(tenant, account) + document = self._create_test_document(dataset, tenant, account) + segment = self._create_test_segment(document, dataset, tenant, account) + + # Manually remove dataset association + segment.dataset_id = "00000000-0000-0000-0000-000000000000" + db.session.commit() + + # Act: Execute the task + result = disable_segment_from_index_task(segment.id) + + # Assert: Verify the task handled the missing dataset gracefully + assert result is None + + # Verify index processor was not called + mock_index_processor.clean.assert_not_called() + + def test_disable_segment_no_document(self, db_session_with_containers, mock_index_processor): + """ + Test handling when segment has no associated document. + + This test verifies: + - Task handles segments without document gracefully + - No index processor operations are performed + - Task returns early without errors + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset = self._create_test_dataset(tenant, account) + document = self._create_test_document(dataset, tenant, account) + segment = self._create_test_segment(document, dataset, tenant, account) + + # Manually remove document association + segment.document_id = "00000000-0000-0000-0000-000000000000" + db.session.commit() + + # Act: Execute the task + result = disable_segment_from_index_task(segment.id) + + # Assert: Verify the task handled the missing document gracefully + assert result is None + + # Verify index processor was not called + mock_index_processor.clean.assert_not_called() + + def test_disable_segment_document_disabled(self, db_session_with_containers, mock_index_processor): + """ + Test handling when document is disabled. + + This test verifies: + - Task handles disabled documents gracefully + - No index processor operations are performed + - Task returns early without errors + """ + # Arrange: Create test data with disabled document + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset = self._create_test_dataset(tenant, account) + document = self._create_test_document(dataset, tenant, account) + document.enabled = False + db.session.commit() + + segment = self._create_test_segment(document, dataset, tenant, account) + + # Act: Execute the task + result = disable_segment_from_index_task(segment.id) + + # Assert: Verify the task handled the disabled document gracefully + assert result is None + + # Verify index processor was not called + mock_index_processor.clean.assert_not_called() + + def test_disable_segment_document_archived(self, db_session_with_containers, mock_index_processor): + """ + Test handling when document is archived. + + This test verifies: + - Task handles archived documents gracefully + - No index processor operations are performed + - Task returns early without errors + """ + # Arrange: Create test data with archived document + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset = self._create_test_dataset(tenant, account) + document = self._create_test_document(dataset, tenant, account) + document.archived = True + db.session.commit() + + segment = self._create_test_segment(document, dataset, tenant, account) + + # Act: Execute the task + result = disable_segment_from_index_task(segment.id) + + # Assert: Verify the task handled the archived document gracefully + assert result is None + + # Verify index processor was not called + mock_index_processor.clean.assert_not_called() + + def test_disable_segment_document_indexing_not_completed(self, db_session_with_containers, mock_index_processor): + """ + Test handling when document indexing is not completed. + + This test verifies: + - Task handles documents with incomplete indexing gracefully + - No index processor operations are performed + - Task returns early without errors + """ + # Arrange: Create test data with incomplete indexing + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset = self._create_test_dataset(tenant, account) + document = self._create_test_document(dataset, tenant, account) + document.indexing_status = "indexing" + db.session.commit() + + segment = self._create_test_segment(document, dataset, tenant, account) + + # Act: Execute the task + result = disable_segment_from_index_task(segment.id) + + # Assert: Verify the task handled the incomplete indexing gracefully + assert result is None + + # Verify index processor was not called + mock_index_processor.clean.assert_not_called() + + def test_disable_segment_index_processor_exception(self, db_session_with_containers, mock_index_processor): + """ + Test handling when index processor raises an exception. + + This test verifies: + - Task handles index processor exceptions gracefully + - Segment is re-enabled on failure + - Redis cache is still cleared + - Database changes are committed + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset = self._create_test_dataset(tenant, account) + document = self._create_test_document(dataset, tenant, account) + segment = self._create_test_segment(document, dataset, tenant, account) + + # Set up Redis cache + indexing_cache_key = f"segment_{segment.id}_indexing" + redis_client.setex(indexing_cache_key, 600, 1) + + # Configure mock to raise exception + mock_index_processor.clean.side_effect = Exception("Index processor error") + + # Act: Execute the task + result = disable_segment_from_index_task(segment.id) + + # Assert: Verify the task handled the exception gracefully + assert result is None + + # Verify index processor was called + mock_index_processor.clean.assert_called_once() + call_args = mock_index_processor.clean.call_args + # Check that the call was made with the correct parameters + assert len(call_args[0]) == 2 # Check two arguments were passed + assert call_args[0][1] == [segment.index_node_id] # Check index node IDs + + # Verify segment was re-enabled + db.session.refresh(segment) + assert segment.enabled is True + + # Verify Redis cache was still cleared + assert redis_client.get(indexing_cache_key) is None + + def test_disable_segment_different_doc_forms(self, db_session_with_containers, mock_index_processor): + """ + Test disabling segments with different document forms. + + This test verifies: + - Task works with different document form types + - Correct index processor is initialized for each form + - Index processor clean method is called correctly + """ + # Test different document forms + doc_forms = ["text_model", "qa_model", "table_model"] + + for doc_form in doc_forms: + # Arrange: Create test data for each form + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset = self._create_test_dataset(tenant, account) + document = self._create_test_document(dataset, tenant, account, doc_form=doc_form) + segment = self._create_test_segment(document, dataset, tenant, account) + + # Reset mock for each iteration + mock_index_processor.reset_mock() + + # Act: Execute the task + result = disable_segment_from_index_task(segment.id) + + # Assert: Verify the task completed successfully + assert result is None + + # Verify correct index processor was initialized + mock_index_processor.clean.assert_called_once() + call_args = mock_index_processor.clean.call_args + assert call_args[0][0].id == dataset.id # Check dataset ID + assert call_args[0][1] == [segment.index_node_id] # Check index node IDs + + def test_disable_segment_redis_cache_handling(self, db_session_with_containers, mock_index_processor): + """ + Test Redis cache handling during segment disabling. + + This test verifies: + - Redis cache is properly set before task execution + - Cache is cleared after task completion + - Cache handling works with different scenarios + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset = self._create_test_dataset(tenant, account) + document = self._create_test_document(dataset, tenant, account) + segment = self._create_test_segment(document, dataset, tenant, account) + + # Test with cache present + indexing_cache_key = f"segment_{segment.id}_indexing" + redis_client.setex(indexing_cache_key, 600, 1) + assert redis_client.get(indexing_cache_key) is not None + + # Act: Execute the task + result = disable_segment_from_index_task(segment.id) + + # Assert: Verify cache was cleared + assert result is None + assert redis_client.get(indexing_cache_key) is None + + # Test with no cache present + segment2 = self._create_test_segment(document, dataset, tenant, account) + result2 = disable_segment_from_index_task(segment2.id) + + # Assert: Verify task still works without cache + assert result2 is None + + def test_disable_segment_performance_timing(self, db_session_with_containers, mock_index_processor): + """ + Test performance timing of segment disabling task. + + This test verifies: + - Task execution time is reasonable + - Performance logging works correctly + - Task completes within expected time bounds + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset = self._create_test_dataset(tenant, account) + document = self._create_test_document(dataset, tenant, account) + segment = self._create_test_segment(document, dataset, tenant, account) + + # Act: Execute the task and measure time + start_time = time.perf_counter() + result = disable_segment_from_index_task(segment.id) + end_time = time.perf_counter() + + # Assert: Verify task completed successfully and timing is reasonable + assert result is None + execution_time = end_time - start_time + assert execution_time < 5.0 # Should complete within 5 seconds + + def test_disable_segment_database_session_management(self, db_session_with_containers, mock_index_processor): + """ + Test database session management during task execution. + + This test verifies: + - Database sessions are properly managed + - Sessions are closed after task completion + - No session leaks occur + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset = self._create_test_dataset(tenant, account) + document = self._create_test_document(dataset, tenant, account) + segment = self._create_test_segment(document, dataset, tenant, account) + + # Act: Execute the task + result = disable_segment_from_index_task(segment.id) + + # Assert: Verify task completed and session management worked + assert result is None + + # Verify segment is still accessible (session was properly managed) + db.session.refresh(segment) + assert segment.id is not None + + def test_disable_segment_concurrent_execution(self, db_session_with_containers, mock_index_processor): + """ + Test concurrent execution of segment disabling tasks. + + This test verifies: + - Multiple tasks can run concurrently + - Each task processes its own segment correctly + - No interference between concurrent tasks + """ + # Arrange: Create multiple test segments + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset = self._create_test_dataset(tenant, account) + document = self._create_test_document(dataset, tenant, account) + + segments = [] + for i in range(3): + segment = self._create_test_segment(document, dataset, tenant, account) + segments.append(segment) + + # Act: Execute tasks concurrently (simulated) + results = [] + for segment in segments: + result = disable_segment_from_index_task(segment.id) + results.append(result) + + # Assert: Verify all tasks completed successfully + assert all(result is None for result in results) + + # Verify all segments were processed + assert mock_index_processor.clean.call_count == len(segments) + + # Verify each segment was processed with correct parameters + for segment in segments: + # Check that clean was called with this segment's dataset and index_node_id + found = False + for call in mock_index_processor.clean.call_args_list: + if call[0][0].id == dataset.id and call[0][1] == [segment.index_node_id]: + found = True + break + assert found, f"Segment {segment.id} was not processed correctly" diff --git a/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py new file mode 100644 index 0000000000..c541bda19a --- /dev/null +++ b/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py @@ -0,0 +1,727 @@ +""" +TestContainers-based integration tests for disable_segments_from_index_task. + +This module provides comprehensive integration testing for the disable_segments_from_index_task +using TestContainers to ensure realistic database interactions and proper isolation. +The task is responsible for removing document segments from the search index when they are disabled. +""" + +from unittest.mock import MagicMock, patch + +from faker import Faker + +from models import Account, Dataset, DocumentSegment +from models import Document as DatasetDocument +from models.dataset import DatasetProcessRule +from tasks.disable_segments_from_index_task import disable_segments_from_index_task + + +class TestDisableSegmentsFromIndexTask: + """ + Comprehensive integration tests for disable_segments_from_index_task using testcontainers. + + This test class covers all major functionality of the disable_segments_from_index_task: + - Successful segment disabling with proper index cleanup + - Error handling for various edge cases + - Database state validation after task execution + - Redis cache cleanup verification + - Index processor integration testing + + All tests use the testcontainers infrastructure to ensure proper database isolation + and realistic testing environment with actual database interactions. + """ + + def _create_test_account(self, db_session_with_containers, fake=None): + """ + Helper method to create a test account with realistic data. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + fake: Faker instance for generating test data + + Returns: + Account: Created test account instance + """ + fake = fake or Faker() + account = Account() + account.id = fake.uuid4() + account.email = fake.email() + account.name = fake.name() + account.avatar_url = fake.url() + account.tenant_id = fake.uuid4() + account.status = "active" + account.type = "normal" + account.role = "owner" + account.interface_language = "en-US" + account.created_at = fake.date_time_this_year() + account.updated_at = account.created_at + + # Create a tenant for the account + from models.account import Tenant + + tenant = Tenant() + tenant.id = account.tenant_id + tenant.name = f"Test Tenant {fake.company()}" + tenant.plan = "basic" + tenant.status = "active" + tenant.created_at = fake.date_time_this_year() + tenant.updated_at = tenant.created_at + + from extensions.ext_database import db + + db.session.add(tenant) + db.session.add(account) + db.session.commit() + + # Set the current tenant for the account + account.current_tenant = tenant + + return account + + def _create_test_dataset(self, db_session_with_containers, account, fake=None): + """ + Helper method to create a test dataset with realistic data. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + account: The account creating the dataset + fake: Faker instance for generating test data + + Returns: + Dataset: Created test dataset instance + """ + fake = fake or Faker() + dataset = Dataset() + dataset.id = fake.uuid4() + dataset.tenant_id = account.tenant_id + dataset.name = f"Test Dataset {fake.word()}" + dataset.description = fake.text(max_nb_chars=200) + dataset.provider = "vendor" + dataset.permission = "only_me" + dataset.data_source_type = "upload_file" + dataset.indexing_technique = "high_quality" + dataset.created_by = account.id + dataset.updated_by = account.id + dataset.embedding_model = "text-embedding-ada-002" + dataset.embedding_model_provider = "openai" + dataset.built_in_field_enabled = False + + from extensions.ext_database import db + + db.session.add(dataset) + db.session.commit() + + return dataset + + def _create_test_document(self, db_session_with_containers, dataset, account, fake=None): + """ + Helper method to create a test document with realistic data. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + dataset: The dataset containing the document + account: The account creating the document + fake: Faker instance for generating test data + + Returns: + DatasetDocument: Created test document instance + """ + fake = fake or Faker() + document = DatasetDocument() + document.id = fake.uuid4() + document.tenant_id = dataset.tenant_id + document.dataset_id = dataset.id + document.position = 1 + document.data_source_type = "upload_file" + document.data_source_info = '{"upload_file_id": "test_file_id"}' + document.batch = fake.uuid4() + document.name = f"Test Document {fake.word()}.txt" + document.created_from = "upload_file" + document.created_by = account.id + document.created_api_request_id = fake.uuid4() + document.processing_started_at = fake.date_time_this_year() + document.file_id = fake.uuid4() + document.word_count = fake.random_int(min=100, max=1000) + document.parsing_completed_at = fake.date_time_this_year() + document.cleaning_completed_at = fake.date_time_this_year() + document.splitting_completed_at = fake.date_time_this_year() + document.tokens = fake.random_int(min=50, max=500) + document.indexing_started_at = fake.date_time_this_year() + document.indexing_completed_at = fake.date_time_this_year() + document.indexing_status = "completed" + document.enabled = True + document.archived = False + document.doc_form = "text_model" # Use text_model form for testing + document.doc_language = "en" + + from extensions.ext_database import db + + db.session.add(document) + db.session.commit() + + return document + + def _create_test_segments(self, db_session_with_containers, document, dataset, account, count=3, fake=None): + """ + Helper method to create test document segments with realistic data. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + document: The document containing the segments + dataset: The dataset containing the document + account: The account creating the segments + count: Number of segments to create + fake: Faker instance for generating test data + + Returns: + List[DocumentSegment]: Created test segment instances + """ + fake = fake or Faker() + segments = [] + + for i in range(count): + segment = DocumentSegment() + segment.id = fake.uuid4() + segment.tenant_id = dataset.tenant_id + segment.dataset_id = dataset.id + segment.document_id = document.id + segment.position = i + 1 + segment.content = f"Test segment content {i + 1}: {fake.text(max_nb_chars=200)}" + segment.answer = f"Test answer {i + 1}" if i % 2 == 0 else None + segment.word_count = fake.random_int(min=10, max=100) + segment.tokens = fake.random_int(min=5, max=50) + segment.keywords = [fake.word() for _ in range(3)] + segment.index_node_id = f"node_{segment.id}" + segment.index_node_hash = fake.sha256() + segment.hit_count = 0 + segment.enabled = True + segment.disabled_at = None + segment.disabled_by = None + segment.status = "completed" + segment.created_by = account.id + segment.updated_by = account.id + segment.indexing_at = fake.date_time_this_year() + segment.completed_at = fake.date_time_this_year() + segment.error = None + segment.stopped_at = None + + segments.append(segment) + + from extensions.ext_database import db + + for segment in segments: + db.session.add(segment) + db.session.commit() + + return segments + + def _create_dataset_process_rule(self, db_session_with_containers, dataset, fake=None): + """ + Helper method to create a dataset process rule. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + dataset: The dataset for the process rule + fake: Faker instance for generating test data + + Returns: + DatasetProcessRule: Created process rule instance + """ + fake = fake or Faker() + process_rule = DatasetProcessRule() + process_rule.id = fake.uuid4() + process_rule.tenant_id = dataset.tenant_id + process_rule.dataset_id = dataset.id + process_rule.mode = "automatic" + process_rule.rules = ( + "{" + '"mode": "automatic", ' + '"rules": {' + '"pre_processing_rules": [], "segmentation": ' + '{"separator": "\\n\\n", "max_tokens": 1000, "chunk_overlap": 50}}' + "}" + ) + process_rule.created_by = dataset.created_by + process_rule.updated_by = dataset.updated_by + + from extensions.ext_database import db + + db.session.add(process_rule) + db.session.commit() + + return process_rule + + def test_disable_segments_success(self, db_session_with_containers): + """ + Test successful disabling of segments from index. + + This test verifies that the task can correctly disable segments from the index + when all conditions are met, including proper index cleanup and database state updates. + """ + # Arrange + fake = Faker() + account = self._create_test_account(db_session_with_containers, fake) + dataset = self._create_test_dataset(db_session_with_containers, account, fake) + document = self._create_test_document(db_session_with_containers, dataset, account, fake) + segments = self._create_test_segments(db_session_with_containers, document, dataset, account, 3, fake) + self._create_dataset_process_rule(db_session_with_containers, dataset, fake) + + segment_ids = [segment.id for segment in segments] + + # Mock the index processor to avoid external dependencies + with patch("tasks.disable_segments_from_index_task.IndexProcessorFactory") as mock_factory: + mock_processor = MagicMock() + mock_factory.return_value.init_index_processor.return_value = mock_processor + + # Mock Redis client + with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis: + mock_redis.delete.return_value = True + + # Act + result = disable_segments_from_index_task(segment_ids, dataset.id, document.id) + + # Assert + assert result is None # Task should complete without returning a value + + # Verify index processor was called correctly + mock_factory.assert_called_once_with(document.doc_form) + mock_processor.clean.assert_called_once() + + # Verify the call arguments (checking by attributes rather than object identity) + call_args = mock_processor.clean.call_args + assert call_args[0][0].id == dataset.id # First argument should be the dataset + assert call_args[0][1] == [ + segment.index_node_id for segment in segments + ] # Second argument should be node IDs + assert call_args[1]["with_keywords"] is True + assert call_args[1]["delete_child_chunks"] is False + + # Verify Redis cache cleanup was called for each segment + assert mock_redis.delete.call_count == len(segments) + for segment in segments: + expected_key = f"segment_{segment.id}_indexing" + mock_redis.delete.assert_any_call(expected_key) + + def test_disable_segments_dataset_not_found(self, db_session_with_containers): + """ + Test handling when dataset is not found. + + This test ensures that the task correctly handles cases where the specified + dataset doesn't exist, logging appropriate messages and returning early. + """ + # Arrange + fake = Faker() + non_existent_dataset_id = fake.uuid4() + non_existent_document_id = fake.uuid4() + segment_ids = [fake.uuid4()] + + # Mock Redis client + with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis: + # Act + result = disable_segments_from_index_task(segment_ids, non_existent_dataset_id, non_existent_document_id) + + # Assert + assert result is None # Task should complete without returning a value + # Redis should not be called when dataset is not found + mock_redis.delete.assert_not_called() + + def test_disable_segments_document_not_found(self, db_session_with_containers): + """ + Test handling when document is not found. + + This test ensures that the task correctly handles cases where the specified + document doesn't exist, logging appropriate messages and returning early. + """ + # Arrange + fake = Faker() + account = self._create_test_account(db_session_with_containers, fake) + dataset = self._create_test_dataset(db_session_with_containers, account, fake) + non_existent_document_id = fake.uuid4() + segment_ids = [fake.uuid4()] + + # Mock Redis client + with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis: + # Act + result = disable_segments_from_index_task(segment_ids, dataset.id, non_existent_document_id) + + # Assert + assert result is None # Task should complete without returning a value + # Redis should not be called when document is not found + mock_redis.delete.assert_not_called() + + def test_disable_segments_document_invalid_status(self, db_session_with_containers): + """ + Test handling when document has invalid status for disabling. + + This test ensures that the task correctly handles cases where the document + is not enabled, archived, or not completed, preventing invalid operations. + """ + # Arrange + fake = Faker() + account = self._create_test_account(db_session_with_containers, fake) + dataset = self._create_test_dataset(db_session_with_containers, account, fake) + document = self._create_test_document(db_session_with_containers, dataset, account, fake) + segments = self._create_test_segments(db_session_with_containers, document, dataset, account, 2, fake) + + # Test case 1: Document not enabled + document.enabled = False + from extensions.ext_database import db + + db.session.commit() + + segment_ids = [segment.id for segment in segments] + + # Mock Redis client + with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis: + # Act + result = disable_segments_from_index_task(segment_ids, dataset.id, document.id) + + # Assert + assert result is None # Task should complete without returning a value + # Redis should not be called when document status is invalid + mock_redis.delete.assert_not_called() + + # Test case 2: Document archived + document.enabled = True + document.archived = True + db.session.commit() + + with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis: + # Act + result = disable_segments_from_index_task(segment_ids, dataset.id, document.id) + + # Assert + assert result is None # Task should complete without returning a value + mock_redis.delete.assert_not_called() + + # Test case 3: Document indexing not completed + document.enabled = True + document.archived = False + document.indexing_status = "indexing" + db.session.commit() + + with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis: + # Act + result = disable_segments_from_index_task(segment_ids, dataset.id, document.id) + + # Assert + assert result is None # Task should complete without returning a value + mock_redis.delete.assert_not_called() + + def test_disable_segments_no_segments_found(self, db_session_with_containers): + """ + Test handling when no segments are found for the given IDs. + + This test ensures that the task correctly handles cases where the specified + segment IDs don't exist or don't match the dataset/document criteria. + """ + # Arrange + fake = Faker() + account = self._create_test_account(db_session_with_containers, fake) + dataset = self._create_test_dataset(db_session_with_containers, account, fake) + document = self._create_test_document(db_session_with_containers, dataset, account, fake) + self._create_dataset_process_rule(db_session_with_containers, dataset, fake) + + # Use non-existent segment IDs + non_existent_segment_ids = [fake.uuid4() for _ in range(3)] + + # Mock Redis client + with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis: + # Act + result = disable_segments_from_index_task(non_existent_segment_ids, dataset.id, document.id) + + # Assert + assert result is None # Task should complete without returning a value + # Redis should not be called when no segments are found + mock_redis.delete.assert_not_called() + + def test_disable_segments_index_processor_error(self, db_session_with_containers): + """ + Test handling when index processor encounters an error. + + This test verifies that the task correctly handles index processor errors + by rolling back segment states and ensuring proper cleanup. + """ + # Arrange + fake = Faker() + account = self._create_test_account(db_session_with_containers, fake) + dataset = self._create_test_dataset(db_session_with_containers, account, fake) + document = self._create_test_document(db_session_with_containers, dataset, account, fake) + segments = self._create_test_segments(db_session_with_containers, document, dataset, account, 2, fake) + self._create_dataset_process_rule(db_session_with_containers, dataset, fake) + + segment_ids = [segment.id for segment in segments] + + # Mock the index processor to raise an exception + with patch("tasks.disable_segments_from_index_task.IndexProcessorFactory") as mock_factory: + mock_processor = MagicMock() + mock_processor.clean.side_effect = Exception("Index processor error") + mock_factory.return_value.init_index_processor.return_value = mock_processor + + # Mock Redis client + with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis: + mock_redis.delete.return_value = True + + # Act + result = disable_segments_from_index_task(segment_ids, dataset.id, document.id) + + # Assert + assert result is None # Task should complete without returning a value + + # Verify segments were rolled back to enabled state + from extensions.ext_database import db + + db.session.refresh(segments[0]) + db.session.refresh(segments[1]) + + # Check that segments are re-enabled after error + updated_segments = db.session.query(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)).all() + + for segment in updated_segments: + assert segment.enabled is True + assert segment.disabled_at is None + assert segment.disabled_by is None + + # Verify Redis cache cleanup was still called + assert mock_redis.delete.call_count == len(segments) + + def test_disable_segments_with_different_doc_forms(self, db_session_with_containers): + """ + Test disabling segments with different document forms. + + This test verifies that the task correctly handles different document forms + (paragraph, qa, parent_child) and initializes the appropriate index processor. + """ + # Arrange + fake = Faker() + account = self._create_test_account(db_session_with_containers, fake) + dataset = self._create_test_dataset(db_session_with_containers, account, fake) + document = self._create_test_document(db_session_with_containers, dataset, account, fake) + segments = self._create_test_segments(db_session_with_containers, document, dataset, account, 2, fake) + self._create_dataset_process_rule(db_session_with_containers, dataset, fake) + + segment_ids = [segment.id for segment in segments] + + # Test different document forms + doc_forms = ["text_model", "qa_model", "hierarchical_model"] + + for doc_form in doc_forms: + # Update document form + document.doc_form = doc_form + from extensions.ext_database import db + + db.session.commit() + + # Mock the index processor factory + with patch("tasks.disable_segments_from_index_task.IndexProcessorFactory") as mock_factory: + mock_processor = MagicMock() + mock_factory.return_value.init_index_processor.return_value = mock_processor + + # Mock Redis client + with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis: + mock_redis.delete.return_value = True + + # Act + result = disable_segments_from_index_task(segment_ids, dataset.id, document.id) + + # Assert + assert result is None # Task should complete without returning a value + mock_factory.assert_called_with(doc_form) + + def test_disable_segments_performance_timing(self, db_session_with_containers): + """ + Test that the task properly measures and logs performance timing. + + This test verifies that the task correctly measures execution time + and logs performance metrics for monitoring purposes. + """ + # Arrange + fake = Faker() + account = self._create_test_account(db_session_with_containers, fake) + dataset = self._create_test_dataset(db_session_with_containers, account, fake) + document = self._create_test_document(db_session_with_containers, dataset, account, fake) + segments = self._create_test_segments(db_session_with_containers, document, dataset, account, 3, fake) + self._create_dataset_process_rule(db_session_with_containers, dataset, fake) + + segment_ids = [segment.id for segment in segments] + + # Mock the index processor + with patch("tasks.disable_segments_from_index_task.IndexProcessorFactory") as mock_factory: + mock_processor = MagicMock() + mock_factory.return_value.init_index_processor.return_value = mock_processor + + # Mock Redis client + with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis: + mock_redis.delete.return_value = True + + # Mock time.perf_counter to control timing + with patch("tasks.disable_segments_from_index_task.time.perf_counter") as mock_perf_counter: + mock_perf_counter.side_effect = [1000.0, 1000.5] # 0.5 seconds execution time + + # Mock logger to capture log messages + with patch("tasks.disable_segments_from_index_task.logger") as mock_logger: + # Act + result = disable_segments_from_index_task(segment_ids, dataset.id, document.id) + + # Assert + assert result is None # Task should complete without returning a value + + # Verify performance logging + mock_logger.info.assert_called() + log_calls = [call[0][0] for call in mock_logger.info.call_args_list] + performance_log = next((call for call in log_calls if "latency" in call), None) + assert performance_log is not None + assert "0.5" in performance_log # Should log the execution time + + def test_disable_segments_redis_cache_cleanup(self, db_session_with_containers): + """ + Test that Redis cache is properly cleaned up for all segments. + + This test verifies that the task correctly removes indexing cache entries + from Redis for all processed segments, preventing stale cache issues. + """ + # Arrange + fake = Faker() + account = self._create_test_account(db_session_with_containers, fake) + dataset = self._create_test_dataset(db_session_with_containers, account, fake) + document = self._create_test_document(db_session_with_containers, dataset, account, fake) + segments = self._create_test_segments(db_session_with_containers, document, dataset, account, 5, fake) + self._create_dataset_process_rule(db_session_with_containers, dataset, fake) + + segment_ids = [segment.id for segment in segments] + + # Mock the index processor + with patch("tasks.disable_segments_from_index_task.IndexProcessorFactory") as mock_factory: + mock_processor = MagicMock() + mock_factory.return_value.init_index_processor.return_value = mock_processor + + # Mock Redis client to track delete calls + with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis: + mock_redis.delete.return_value = True + + # Act + result = disable_segments_from_index_task(segment_ids, dataset.id, document.id) + + # Assert + assert result is None # Task should complete without returning a value + + # Verify Redis delete was called for each segment + assert mock_redis.delete.call_count == len(segments) + + # Verify correct cache keys were used + expected_keys = [f"segment_{segment.id}_indexing" for segment in segments] + actual_calls = [call[0][0] for call in mock_redis.delete.call_args_list] + + for expected_key in expected_keys: + assert expected_key in actual_calls + + def test_disable_segments_database_session_cleanup(self, db_session_with_containers): + """ + Test that database session is properly closed after task execution. + + This test verifies that the task correctly manages database sessions + and ensures proper cleanup to prevent connection leaks. + """ + # Arrange + fake = Faker() + account = self._create_test_account(db_session_with_containers, fake) + dataset = self._create_test_dataset(db_session_with_containers, account, fake) + document = self._create_test_document(db_session_with_containers, dataset, account, fake) + segments = self._create_test_segments(db_session_with_containers, document, dataset, account, 2, fake) + self._create_dataset_process_rule(db_session_with_containers, dataset, fake) + + segment_ids = [segment.id for segment in segments] + + # Mock the index processor + with patch("tasks.disable_segments_from_index_task.IndexProcessorFactory") as mock_factory: + mock_processor = MagicMock() + mock_factory.return_value.init_index_processor.return_value = mock_processor + + # Mock Redis client + with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis: + mock_redis.delete.return_value = True + + # Mock db.session.close to verify it's called + with patch("tasks.disable_segments_from_index_task.db.session.close") as mock_close: + # Act + result = disable_segments_from_index_task(segment_ids, dataset.id, document.id) + + # Assert + assert result is None # Task should complete without returning a value + # Verify session was closed + mock_close.assert_called() + + def test_disable_segments_empty_segment_ids(self, db_session_with_containers): + """ + Test handling when empty segment IDs list is provided. + + This test ensures that the task correctly handles edge cases where + an empty list of segment IDs is provided. + """ + # Arrange + fake = Faker() + account = self._create_test_account(db_session_with_containers, fake) + dataset = self._create_test_dataset(db_session_with_containers, account, fake) + document = self._create_test_document(db_session_with_containers, dataset, account, fake) + self._create_dataset_process_rule(db_session_with_containers, dataset, fake) + + empty_segment_ids = [] + + # Mock Redis client + with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis: + # Act + result = disable_segments_from_index_task(empty_segment_ids, dataset.id, document.id) + + # Assert + assert result is None # Task should complete without returning a value + # Redis should not be called when no segments are provided + mock_redis.delete.assert_not_called() + + def test_disable_segments_mixed_valid_invalid_ids(self, db_session_with_containers): + """ + Test handling when some segment IDs are valid and others are invalid. + + This test verifies that the task correctly processes only the valid + segment IDs and ignores invalid ones. + """ + # Arrange + fake = Faker() + account = self._create_test_account(db_session_with_containers, fake) + dataset = self._create_test_dataset(db_session_with_containers, account, fake) + document = self._create_test_document(db_session_with_containers, dataset, account, fake) + segments = self._create_test_segments(db_session_with_containers, document, dataset, account, 2, fake) + self._create_dataset_process_rule(db_session_with_containers, dataset, fake) + + # Mix valid and invalid segment IDs + valid_segment_ids = [segment.id for segment in segments] + invalid_segment_ids = [fake.uuid4() for _ in range(2)] + mixed_segment_ids = valid_segment_ids + invalid_segment_ids + + # Mock the index processor + with patch("tasks.disable_segments_from_index_task.IndexProcessorFactory") as mock_factory: + mock_processor = MagicMock() + mock_factory.return_value.init_index_processor.return_value = mock_processor + + # Mock Redis client + with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis: + mock_redis.delete.return_value = True + + # Act + result = disable_segments_from_index_task(mixed_segment_ids, dataset.id, document.id) + + # Assert + assert result is None # Task should complete without returning a value + + # Verify index processor was called with only valid segment node IDs + expected_node_ids = [segment.index_node_id for segment in segments] + mock_processor.clean.assert_called_once() + + # Verify the call arguments + call_args = mock_processor.clean.call_args + assert call_args[0][0].id == dataset.id # First argument should be the dataset + assert call_args[0][1] == expected_node_ids # Second argument should be node IDs + assert call_args[1]["with_keywords"] is True + assert call_args[1]["delete_child_chunks"] is False + + # Verify Redis cleanup was called only for valid segments + assert mock_redis.delete.call_count == len(segments) diff --git a/api/tests/unit_tests/controllers/console/auth/test_authentication_security.py b/api/tests/unit_tests/controllers/console/auth/test_authentication_security.py index aefb4bf8b0..b6697ac5d4 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_authentication_security.py +++ b/api/tests/unit_tests/controllers/console/auth/test_authentication_security.py @@ -9,7 +9,6 @@ from flask_restx import Api import services.errors.account from controllers.console.auth.error import AuthenticationFailedError from controllers.console.auth.login import LoginApi -from controllers.console.error import AccountNotFound class TestAuthenticationSecurity: @@ -27,31 +26,33 @@ class TestAuthenticationSecurity: @patch("controllers.console.auth.login.FeatureService.get_system_features") @patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit") @patch("controllers.console.auth.login.AccountService.authenticate") - @patch("controllers.console.auth.login.AccountService.send_reset_password_email") + @patch("controllers.console.auth.login.AccountService.add_login_error_rate_limit") @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False) @patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid") def test_login_invalid_email_with_registration_allowed( - self, mock_get_invitation, mock_send_email, mock_authenticate, mock_is_rate_limit, mock_features, mock_db + self, mock_get_invitation, mock_add_rate_limit, mock_authenticate, mock_is_rate_limit, mock_features, mock_db ): - """Test that invalid email sends reset password email when registration is allowed.""" + """Test that invalid email raises AuthenticationFailedError when account not found.""" # Arrange mock_is_rate_limit.return_value = False mock_get_invitation.return_value = None - mock_authenticate.side_effect = services.errors.account.AccountNotFoundError("Account not found") + mock_authenticate.side_effect = services.errors.account.AccountPasswordError("Invalid email or password.") mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists mock_features.return_value.is_allow_register = True - mock_send_email.return_value = "token123" # Act with self.app.test_request_context( "/login", method="POST", json={"email": "nonexistent@example.com", "password": "WrongPass123!"} ): login_api = LoginApi() - result = login_api.post() - # Assert - assert result == {"result": "fail", "data": "token123", "code": "account_not_found"} - mock_send_email.assert_called_once_with(email="nonexistent@example.com", language="en-US") + # Assert + with pytest.raises(AuthenticationFailedError) as exc_info: + login_api.post() + + assert exc_info.value.error_code == "authentication_failed" + assert exc_info.value.description == "Invalid email or password." + mock_add_rate_limit.assert_called_once_with("nonexistent@example.com") @patch("controllers.console.wraps.db") @patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit") @@ -87,16 +88,17 @@ class TestAuthenticationSecurity: @patch("controllers.console.auth.login.FeatureService.get_system_features") @patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit") @patch("controllers.console.auth.login.AccountService.authenticate") + @patch("controllers.console.auth.login.AccountService.add_login_error_rate_limit") @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False) @patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid") def test_login_invalid_email_with_registration_disabled( - self, mock_get_invitation, mock_authenticate, mock_is_rate_limit, mock_features, mock_db + self, mock_get_invitation, mock_add_rate_limit, mock_authenticate, mock_is_rate_limit, mock_features, mock_db ): - """Test that invalid email raises AccountNotFound when registration is disabled.""" + """Test that invalid email raises AuthenticationFailedError when account not found.""" # Arrange mock_is_rate_limit.return_value = False mock_get_invitation.return_value = None - mock_authenticate.side_effect = services.errors.account.AccountNotFoundError("Account not found") + mock_authenticate.side_effect = services.errors.account.AccountPasswordError("Invalid email or password.") mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists mock_features.return_value.is_allow_register = False @@ -107,10 +109,12 @@ class TestAuthenticationSecurity: login_api = LoginApi() # Assert - with pytest.raises(AccountNotFound) as exc_info: + with pytest.raises(AuthenticationFailedError) as exc_info: login_api.post() - assert exc_info.value.error_code == "account_not_found" + assert exc_info.value.error_code == "authentication_failed" + assert exc_info.value.description == "Invalid email or password." + mock_add_rate_limit.assert_called_once_with("nonexistent@example.com") @patch("controllers.console.wraps.db") @patch("controllers.console.auth.login.FeatureService.get_system_features") diff --git a/api/tests/unit_tests/controllers/console/auth/test_oauth.py b/api/tests/unit_tests/controllers/console/auth/test_oauth.py index 037c9f2745..a7bdf5de33 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_oauth.py +++ b/api/tests/unit_tests/controllers/console/auth/test_oauth.py @@ -12,7 +12,7 @@ from controllers.console.auth.oauth import ( ) from libs.oauth import OAuthUserInfo from models.account import AccountStatus -from services.errors.account import AccountNotFoundError +from services.errors.account import AccountRegisterError class TestGetOAuthProviders: @@ -451,7 +451,7 @@ class TestAccountGeneration: with app.test_request_context(headers={"Accept-Language": "en-US,en;q=0.9"}): if not allow_register and not existing_account: - with pytest.raises(AccountNotFoundError): + with pytest.raises(AccountRegisterError): _generate_account("github", user_info) else: result = _generate_account("github", user_info) diff --git a/api/tests/unit_tests/core/mcp/server/test_streamable_http.py b/api/tests/unit_tests/core/mcp/server/test_streamable_http.py index f1d741602a..895ebdd751 100644 --- a/api/tests/unit_tests/core/mcp/server/test_streamable_http.py +++ b/api/tests/unit_tests/core/mcp/server/test_streamable_http.py @@ -29,7 +29,7 @@ class TestHandleMCPRequest: """Setup test fixtures""" self.app = Mock(spec=App) self.app.name = "test_app" - self.app.mode = AppMode.CHAT.value + self.app.mode = AppMode.CHAT self.mcp_server = Mock(spec=AppMCPServer) self.mcp_server.description = "Test server" @@ -196,7 +196,7 @@ class TestIndividualHandlers: def test_handle_list_tools(self): """Test list tools handler""" app_name = "test_app" - app_mode = AppMode.CHAT.value + app_mode = AppMode.CHAT description = "Test server" parameters_dict: dict[str, str] = {} user_input_form: list[VariableEntity] = [] @@ -212,7 +212,7 @@ class TestIndividualHandlers: def test_handle_call_tool(self, mock_app_generate): """Test call tool handler""" app = Mock(spec=App) - app.mode = AppMode.CHAT.value + app.mode = AppMode.CHAT # Create mock request mock_request = Mock() @@ -252,7 +252,7 @@ class TestUtilityFunctions: def test_build_parameter_schema_chat_mode(self): """Test building parameter schema for chat mode""" - app_mode = AppMode.CHAT.value + app_mode = AppMode.CHAT parameters_dict: dict[str, str] = {"name": "Enter your name"} user_input_form = [ @@ -275,7 +275,7 @@ class TestUtilityFunctions: def test_build_parameter_schema_workflow_mode(self): """Test building parameter schema for workflow mode""" - app_mode = AppMode.WORKFLOW.value + app_mode = AppMode.WORKFLOW parameters_dict: dict[str, str] = {"input_text": "Enter text"} user_input_form = [ @@ -298,7 +298,7 @@ class TestUtilityFunctions: def test_prepare_tool_arguments_chat_mode(self): """Test preparing tool arguments for chat mode""" app = Mock(spec=App) - app.mode = AppMode.CHAT.value + app.mode = AppMode.CHAT arguments = {"query": "test question", "name": "John"} @@ -312,7 +312,7 @@ class TestUtilityFunctions: def test_prepare_tool_arguments_workflow_mode(self): """Test preparing tool arguments for workflow mode""" app = Mock(spec=App) - app.mode = AppMode.WORKFLOW.value + app.mode = AppMode.WORKFLOW arguments = {"input_text": "test input"} @@ -324,7 +324,7 @@ class TestUtilityFunctions: def test_prepare_tool_arguments_completion_mode(self): """Test preparing tool arguments for completion mode""" app = Mock(spec=App) - app.mode = AppMode.COMPLETION.value + app.mode = AppMode.COMPLETION arguments = {"name": "John"} @@ -336,7 +336,7 @@ class TestUtilityFunctions: def test_extract_answer_from_mapping_response_chat(self): """Test extracting answer from mapping response for chat mode""" app = Mock(spec=App) - app.mode = AppMode.CHAT.value + app.mode = AppMode.CHAT response = {"answer": "test answer", "other": "data"} @@ -347,7 +347,7 @@ class TestUtilityFunctions: def test_extract_answer_from_mapping_response_workflow(self): """Test extracting answer from mapping response for workflow mode""" app = Mock(spec=App) - app.mode = AppMode.WORKFLOW.value + app.mode = AppMode.WORKFLOW response = {"data": {"outputs": {"result": "test result"}}} diff --git a/api/tests/unit_tests/core/rag/extractor/firecrawl/test_firecrawl.py b/api/tests/unit_tests/core/rag/extractor/firecrawl/test_firecrawl.py index 607728efd8..6689e13b96 100644 --- a/api/tests/unit_tests/core/rag/extractor/firecrawl/test_firecrawl.py +++ b/api/tests/unit_tests/core/rag/extractor/firecrawl/test_firecrawl.py @@ -20,7 +20,6 @@ def test_firecrawl_web_extractor_crawl_mode(mocker): } mocker.patch("requests.post", return_value=_mock_response(mocked_firecrawl)) job_id = firecrawl_app.crawl_url(url, params) - print(f"job_id: {job_id}") assert job_id is not None assert isinstance(job_id, str) diff --git a/api/tests/unit_tests/core/variables/test_segment.py b/api/tests/unit_tests/core/variables/test_segment.py index 4c8d983d20..c9cfabca6e 100644 --- a/api/tests/unit_tests/core/variables/test_segment.py +++ b/api/tests/unit_tests/core/variables/test_segment.py @@ -129,7 +129,6 @@ class TestSegmentDumpAndLoad: """Test basic segment serialization compatibility""" model = _Segments(segments=[IntegerSegment(value=1), StringSegment(value="a")]) json = model.model_dump_json() - print("Json: ", json) loaded = _Segments.model_validate_json(json) assert loaded == model @@ -137,7 +136,6 @@ class TestSegmentDumpAndLoad: """Test number segment serialization compatibility""" model = _Segments(segments=[IntegerSegment(value=1), FloatSegment(value=1.0)]) json = model.model_dump_json() - print("Json: ", json) loaded = _Segments.model_validate_json(json) assert loaded == model @@ -145,7 +143,6 @@ class TestSegmentDumpAndLoad: """Test variable serialization compatibility""" model = _Variables(variables=[IntegerVariable(value=1, name="int"), StringVariable(value="a", name="str")]) json = model.model_dump_json() - print("Json: ", json) restored = _Variables.model_validate_json(json) assert restored == model diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py index 2d8d433c46..b8f901770c 100644 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py @@ -342,4 +342,3 @@ def test_http_request_node_form_with_multiple_files(monkeypatch: pytest.MonkeyPa assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs is not None assert result.outputs["body"] == '{"status":"success"}' - print(result.outputs["body"]) diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py index 987eaf7534..49a88e57b3 100644 --- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py +++ b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py @@ -143,15 +143,11 @@ def test_remove_first_from_array(): node.init_node_data(node_config["data"]) # Skip the mock assertion since we're in a test environment - # Print the variable before running - print(f"Before: {variable_pool.get(['conversation', conversation_variable.name]).to_object()}") # Run the node result = list(node.run()) - # Print the variable after running and the result - print(f"After: {variable_pool.get(['conversation', conversation_variable.name]).to_object()}") - print(f"Result: {result}") + # Completed run got = variable_pool.get(["conversation", conversation_variable.name]) assert got is not None diff --git a/api/tests/unit_tests/core/workflow/test_workflow_entry.py b/api/tests/unit_tests/core/workflow/test_workflow_entry.py new file mode 100644 index 0000000000..324f58abf6 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/test_workflow_entry.py @@ -0,0 +1,456 @@ +import pytest + +from core.file.enums import FileType +from core.file.models import File, FileTransferMethod +from core.variables.variables import StringVariable +from core.workflow.constants import ( + CONVERSATION_VARIABLE_NODE_ID, + ENVIRONMENT_VARIABLE_NODE_ID, +) +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.system_variable import SystemVariable +from core.workflow.workflow_entry import WorkflowEntry + + +class TestWorkflowEntry: + """Test WorkflowEntry class methods.""" + + def test_mapping_user_inputs_to_variable_pool_with_system_variables(self): + """Test mapping system variables from user inputs to variable pool.""" + # Initialize variable pool with system variables + variable_pool = VariablePool( + system_variables=SystemVariable( + user_id="test_user_id", + app_id="test_app_id", + workflow_id="test_workflow_id", + ), + user_inputs={}, + ) + + # Define variable mapping - sys variables mapped to other nodes + variable_mapping = { + "node1.input1": ["node1", "input1"], # Regular mapping + "node2.query": ["node2", "query"], # Regular mapping + "sys.user_id": ["output_node", "user"], # System variable mapping + } + + # User inputs including sys variables + user_inputs = { + "node1.input1": "new_user_id", + "node2.query": "test query", + "sys.user_id": "system_user", + } + + # Execute mapping + WorkflowEntry.mapping_user_inputs_to_variable_pool( + variable_mapping=variable_mapping, + user_inputs=user_inputs, + variable_pool=variable_pool, + tenant_id="test_tenant", + ) + + # Verify variables were added to pool + # Note: variable_pool.get returns Variable objects, not raw values + node1_var = variable_pool.get(["node1", "input1"]) + assert node1_var is not None + assert node1_var.value == "new_user_id" + + node2_var = variable_pool.get(["node2", "query"]) + assert node2_var is not None + assert node2_var.value == "test query" + + # System variable gets mapped to output node + output_var = variable_pool.get(["output_node", "user"]) + assert output_var is not None + assert output_var.value == "system_user" + + def test_mapping_user_inputs_to_variable_pool_with_env_variables(self): + """Test mapping environment variables from user inputs to variable pool.""" + # Initialize variable pool with environment variables + env_var = StringVariable(name="API_KEY", value="existing_key") + variable_pool = VariablePool( + system_variables=SystemVariable.empty(), + environment_variables=[env_var], + user_inputs={}, + ) + + # Add env variable to pool (simulating initialization) + variable_pool.add([ENVIRONMENT_VARIABLE_NODE_ID, "API_KEY"], env_var) + + # Define variable mapping - env variables should not be overridden + variable_mapping = { + "node1.api_key": [ENVIRONMENT_VARIABLE_NODE_ID, "API_KEY"], + "node2.new_env": [ENVIRONMENT_VARIABLE_NODE_ID, "NEW_ENV"], + } + + # User inputs + user_inputs = { + "node1.api_key": "user_provided_key", # This should not override existing env var + "node2.new_env": "new_env_value", + } + + # Execute mapping + WorkflowEntry.mapping_user_inputs_to_variable_pool( + variable_mapping=variable_mapping, + user_inputs=user_inputs, + variable_pool=variable_pool, + tenant_id="test_tenant", + ) + + # Verify env variable was not overridden + env_value = variable_pool.get([ENVIRONMENT_VARIABLE_NODE_ID, "API_KEY"]) + assert env_value is not None + assert env_value.value == "existing_key" # Should remain unchanged + + # New env variables from user input should not be added + assert variable_pool.get([ENVIRONMENT_VARIABLE_NODE_ID, "NEW_ENV"]) is None + + def test_mapping_user_inputs_to_variable_pool_with_conversation_variables(self): + """Test mapping conversation variables from user inputs to variable pool.""" + # Initialize variable pool with conversation variables + conv_var = StringVariable(name="last_message", value="Hello") + variable_pool = VariablePool( + system_variables=SystemVariable.empty(), + conversation_variables=[conv_var], + user_inputs={}, + ) + + # Add conversation variable to pool + variable_pool.add([CONVERSATION_VARIABLE_NODE_ID, "last_message"], conv_var) + + # Define variable mapping + variable_mapping = { + "node1.message": ["node1", "message"], # Map to regular node + "conversation.context": ["chat_node", "context"], # Conversation var to regular node + } + + # User inputs + user_inputs = { + "node1.message": "Updated message", + "conversation.context": "New context", + } + + # Execute mapping + WorkflowEntry.mapping_user_inputs_to_variable_pool( + variable_mapping=variable_mapping, + user_inputs=user_inputs, + variable_pool=variable_pool, + tenant_id="test_tenant", + ) + + # Verify variables were added to their target nodes + node1_var = variable_pool.get(["node1", "message"]) + assert node1_var is not None + assert node1_var.value == "Updated message" + + chat_var = variable_pool.get(["chat_node", "context"]) + assert chat_var is not None + assert chat_var.value == "New context" + + def test_mapping_user_inputs_to_variable_pool_with_regular_variables(self): + """Test mapping regular node variables from user inputs to variable pool.""" + # Initialize empty variable pool + variable_pool = VariablePool( + system_variables=SystemVariable.empty(), + user_inputs={}, + ) + + # Define variable mapping for regular nodes + variable_mapping = { + "input_node.text": ["input_node", "text"], + "llm_node.prompt": ["llm_node", "prompt"], + "code_node.input": ["code_node", "input"], + } + + # User inputs + user_inputs = { + "input_node.text": "User input text", + "llm_node.prompt": "Generate a summary", + "code_node.input": {"key": "value"}, + } + + # Execute mapping + WorkflowEntry.mapping_user_inputs_to_variable_pool( + variable_mapping=variable_mapping, + user_inputs=user_inputs, + variable_pool=variable_pool, + tenant_id="test_tenant", + ) + + # Verify regular variables were added + text_var = variable_pool.get(["input_node", "text"]) + assert text_var is not None + assert text_var.value == "User input text" + + prompt_var = variable_pool.get(["llm_node", "prompt"]) + assert prompt_var is not None + assert prompt_var.value == "Generate a summary" + + input_var = variable_pool.get(["code_node", "input"]) + assert input_var is not None + assert input_var.value == {"key": "value"} + + def test_mapping_user_inputs_with_file_handling(self): + """Test mapping file inputs from user inputs to variable pool.""" + variable_pool = VariablePool( + system_variables=SystemVariable.empty(), + user_inputs={}, + ) + + # Define variable mapping + variable_mapping = { + "file_node.file": ["file_node", "file"], + "file_node.files": ["file_node", "files"], + } + + # User inputs with file data - using remote_url which doesn't require upload_file_id + user_inputs = { + "file_node.file": { + "type": "document", + "transfer_method": "remote_url", + "url": "http://example.com/test.pdf", + }, + "file_node.files": [ + { + "type": "image", + "transfer_method": "remote_url", + "url": "http://example.com/image1.jpg", + }, + { + "type": "image", + "transfer_method": "remote_url", + "url": "http://example.com/image2.jpg", + }, + ], + } + + # Execute mapping + WorkflowEntry.mapping_user_inputs_to_variable_pool( + variable_mapping=variable_mapping, + user_inputs=user_inputs, + variable_pool=variable_pool, + tenant_id="test_tenant", + ) + + # Verify file was converted and added + file_var = variable_pool.get(["file_node", "file"]) + assert file_var is not None + assert file_var.value.type == FileType.DOCUMENT + assert file_var.value.transfer_method == FileTransferMethod.REMOTE_URL + + # Verify file list was converted and added + files_var = variable_pool.get(["file_node", "files"]) + assert files_var is not None + assert isinstance(files_var.value, list) + assert len(files_var.value) == 2 + assert all(isinstance(f, File) for f in files_var.value) + assert files_var.value[0].type == FileType.IMAGE + assert files_var.value[1].type == FileType.IMAGE + assert files_var.value[0].type == FileType.IMAGE + assert files_var.value[1].type == FileType.IMAGE + + def test_mapping_user_inputs_missing_variable_error(self): + """Test that mapping raises error when required variable is missing.""" + variable_pool = VariablePool( + system_variables=SystemVariable.empty(), + user_inputs={}, + ) + + # Define variable mapping + variable_mapping = { + "node1.required_input": ["node1", "required_input"], + } + + # User inputs without required variable + user_inputs = { + "node1.other_input": "some value", + } + + # Should raise ValueError for missing variable + with pytest.raises(ValueError, match="Variable key node1.required_input not found in user inputs"): + WorkflowEntry.mapping_user_inputs_to_variable_pool( + variable_mapping=variable_mapping, + user_inputs=user_inputs, + variable_pool=variable_pool, + tenant_id="test_tenant", + ) + + def test_mapping_user_inputs_with_alternative_key_format(self): + """Test mapping with alternative key format (without node prefix).""" + variable_pool = VariablePool( + system_variables=SystemVariable.empty(), + user_inputs={}, + ) + + # Define variable mapping + variable_mapping = { + "node1.input": ["node1", "input"], + } + + # User inputs with alternative key format + user_inputs = { + "input": "value without node prefix", # Alternative format without node prefix + } + + # Execute mapping + WorkflowEntry.mapping_user_inputs_to_variable_pool( + variable_mapping=variable_mapping, + user_inputs=user_inputs, + variable_pool=variable_pool, + tenant_id="test_tenant", + ) + + # Verify variable was added using alternative key + input_var = variable_pool.get(["node1", "input"]) + assert input_var is not None + assert input_var.value == "value without node prefix" + + def test_mapping_user_inputs_with_complex_selectors(self): + """Test mapping with complex node variable keys.""" + variable_pool = VariablePool( + system_variables=SystemVariable.empty(), + user_inputs={}, + ) + + # Define variable mapping - selectors can only have 2 elements + variable_mapping = { + "node1.data.field1": ["node1", "data_field1"], # Complex key mapped to simple selector + "node2.config.settings.timeout": ["node2", "timeout"], # Complex key mapped to simple selector + } + + # User inputs + user_inputs = { + "node1.data.field1": "nested value", + "node2.config.settings.timeout": 30, + } + + # Execute mapping + WorkflowEntry.mapping_user_inputs_to_variable_pool( + variable_mapping=variable_mapping, + user_inputs=user_inputs, + variable_pool=variable_pool, + tenant_id="test_tenant", + ) + + # Verify variables were added with simple selectors + data_var = variable_pool.get(["node1", "data_field1"]) + assert data_var is not None + assert data_var.value == "nested value" + + timeout_var = variable_pool.get(["node2", "timeout"]) + assert timeout_var is not None + assert timeout_var.value == 30 + + def test_mapping_user_inputs_invalid_node_variable(self): + """Test that mapping handles invalid node variable format.""" + variable_pool = VariablePool( + system_variables=SystemVariable.empty(), + user_inputs={}, + ) + + # Define variable mapping with single element node variable (at least one dot is required) + variable_mapping = { + "singleelement": ["node1", "input"], # No dot separator + } + + user_inputs = {"singleelement": "some value"} # Must use exact key + + # Should NOT raise error - function accepts it and uses direct key + WorkflowEntry.mapping_user_inputs_to_variable_pool( + variable_mapping=variable_mapping, + user_inputs=user_inputs, + variable_pool=variable_pool, + tenant_id="test_tenant", + ) + + # Verify it was added + var = variable_pool.get(["node1", "input"]) + assert var is not None + assert var.value == "some value" + + def test_mapping_all_variable_types_together(self): + """Test mapping all four types of variables in one operation.""" + # Initialize variable pool with some existing variables + env_var = StringVariable(name="API_KEY", value="existing_key") + conv_var = StringVariable(name="session_id", value="session123") + + variable_pool = VariablePool( + system_variables=SystemVariable( + user_id="test_user", + app_id="test_app", + query="initial query", + ), + environment_variables=[env_var], + conversation_variables=[conv_var], + user_inputs={}, + ) + + # Add existing variables to pool + variable_pool.add([ENVIRONMENT_VARIABLE_NODE_ID, "API_KEY"], env_var) + variable_pool.add([CONVERSATION_VARIABLE_NODE_ID, "session_id"], conv_var) + + # Define comprehensive variable mapping + variable_mapping = { + # System variables mapped to regular nodes + "sys.user_id": ["start", "user"], + "sys.app_id": ["start", "app"], + # Environment variables (won't be overridden) + "env.API_KEY": ["config", "api_key"], + # Conversation variables mapped to regular nodes + "conversation.session_id": ["chat", "session"], + # Regular variables + "input.text": ["input", "text"], + "process.data": ["process", "data"], + } + + # User inputs + user_inputs = { + "sys.user_id": "new_user", + "sys.app_id": "new_app", + "env.API_KEY": "attempted_override", # Should not override env var + "conversation.session_id": "new_session", + "input.text": "user input text", + "process.data": {"value": 123, "status": "active"}, + } + + # Execute mapping + WorkflowEntry.mapping_user_inputs_to_variable_pool( + variable_mapping=variable_mapping, + user_inputs=user_inputs, + variable_pool=variable_pool, + tenant_id="test_tenant", + ) + + # Verify system variables were added to their target nodes + start_user = variable_pool.get(["start", "user"]) + assert start_user is not None + assert start_user.value == "new_user" + + start_app = variable_pool.get(["start", "app"]) + assert start_app is not None + assert start_app.value == "new_app" + + # Verify env variable was not overridden (still has original value) + env_value = variable_pool.get([ENVIRONMENT_VARIABLE_NODE_ID, "API_KEY"]) + assert env_value is not None + assert env_value.value == "existing_key" + + # Environment variables get mapped to other nodes even when they exist in env pool + # But the original env value remains unchanged + config_api_key = variable_pool.get(["config", "api_key"]) + assert config_api_key is not None + assert config_api_key.value == "attempted_override" + + # Verify conversation variable was mapped to target node + chat_session = variable_pool.get(["chat", "session"]) + assert chat_session is not None + assert chat_session.value == "new_session" + + # Verify regular variables were added + input_text = variable_pool.get(["input", "text"]) + assert input_text is not None + assert input_text.value == "user input text" + + process_data = variable_pool.get(["process", "data"]) + assert process_data is not None + assert process_data.value == {"value": 123, "status": "active"} diff --git a/api/tests/unit_tests/extensions/storage/test_supabase_storage.py b/api/tests/unit_tests/extensions/storage/test_supabase_storage.py index 7d295cecf2..958072223e 100644 --- a/api/tests/unit_tests/extensions/storage/test_supabase_storage.py +++ b/api/tests/unit_tests/extensions/storage/test_supabase_storage.py @@ -11,12 +11,12 @@ class TestSupabaseStorage: def test_init_success_with_all_config(self): """Test successful initialization when all required config is provided.""" - with patch("extensions.storage.supabase_storage.dify_config") as mock_config: + with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config: mock_config.SUPABASE_URL = "https://test.supabase.co" mock_config.SUPABASE_API_KEY = "test-api-key" mock_config.SUPABASE_BUCKET_NAME = "test-bucket" - with patch("extensions.storage.supabase_storage.Client") as mock_client_class: + with patch("extensions.storage.supabase_storage.Client", autospec=True) as mock_client_class: mock_client = Mock() mock_client_class.return_value = mock_client @@ -31,7 +31,7 @@ class TestSupabaseStorage: def test_init_raises_error_when_url_missing(self): """Test initialization raises ValueError when SUPABASE_URL is None.""" - with patch("extensions.storage.supabase_storage.dify_config") as mock_config: + with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config: mock_config.SUPABASE_URL = None mock_config.SUPABASE_API_KEY = "test-api-key" mock_config.SUPABASE_BUCKET_NAME = "test-bucket" @@ -41,7 +41,7 @@ class TestSupabaseStorage: def test_init_raises_error_when_api_key_missing(self): """Test initialization raises ValueError when SUPABASE_API_KEY is None.""" - with patch("extensions.storage.supabase_storage.dify_config") as mock_config: + with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config: mock_config.SUPABASE_URL = "https://test.supabase.co" mock_config.SUPABASE_API_KEY = None mock_config.SUPABASE_BUCKET_NAME = "test-bucket" @@ -51,7 +51,7 @@ class TestSupabaseStorage: def test_init_raises_error_when_bucket_name_missing(self): """Test initialization raises ValueError when SUPABASE_BUCKET_NAME is None.""" - with patch("extensions.storage.supabase_storage.dify_config") as mock_config: + with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config: mock_config.SUPABASE_URL = "https://test.supabase.co" mock_config.SUPABASE_API_KEY = "test-api-key" mock_config.SUPABASE_BUCKET_NAME = None @@ -61,12 +61,12 @@ class TestSupabaseStorage: def test_create_bucket_when_not_exists(self): """Test create_bucket creates bucket when it doesn't exist.""" - with patch("extensions.storage.supabase_storage.dify_config") as mock_config: + with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config: mock_config.SUPABASE_URL = "https://test.supabase.co" mock_config.SUPABASE_API_KEY = "test-api-key" mock_config.SUPABASE_BUCKET_NAME = "test-bucket" - with patch("extensions.storage.supabase_storage.Client") as mock_client_class: + with patch("extensions.storage.supabase_storage.Client", autospec=True) as mock_client_class: mock_client = Mock() mock_client_class.return_value = mock_client @@ -77,12 +77,12 @@ class TestSupabaseStorage: def test_create_bucket_when_exists(self): """Test create_bucket does not create bucket when it already exists.""" - with patch("extensions.storage.supabase_storage.dify_config") as mock_config: + with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config: mock_config.SUPABASE_URL = "https://test.supabase.co" mock_config.SUPABASE_API_KEY = "test-api-key" mock_config.SUPABASE_BUCKET_NAME = "test-bucket" - with patch("extensions.storage.supabase_storage.Client") as mock_client_class: + with patch("extensions.storage.supabase_storage.Client", autospec=True) as mock_client_class: mock_client = Mock() mock_client_class.return_value = mock_client @@ -94,12 +94,12 @@ class TestSupabaseStorage: @pytest.fixture def storage_with_mock_client(self): """Fixture providing SupabaseStorage with mocked client.""" - with patch("extensions.storage.supabase_storage.dify_config") as mock_config: + with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config: mock_config.SUPABASE_URL = "https://test.supabase.co" mock_config.SUPABASE_API_KEY = "test-api-key" mock_config.SUPABASE_BUCKET_NAME = "test-bucket" - with patch("extensions.storage.supabase_storage.Client") as mock_client_class: + with patch("extensions.storage.supabase_storage.Client", autospec=True) as mock_client_class: mock_client = Mock() mock_client_class.return_value = mock_client @@ -251,12 +251,12 @@ class TestSupabaseStorage: def test_bucket_exists_returns_true_when_bucket_found(self): """Test bucket_exists returns True when bucket is found in list.""" - with patch("extensions.storage.supabase_storage.dify_config") as mock_config: + with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config: mock_config.SUPABASE_URL = "https://test.supabase.co" mock_config.SUPABASE_API_KEY = "test-api-key" mock_config.SUPABASE_BUCKET_NAME = "test-bucket" - with patch("extensions.storage.supabase_storage.Client") as mock_client_class: + with patch("extensions.storage.supabase_storage.Client", autospec=True) as mock_client_class: mock_client = Mock() mock_client_class.return_value = mock_client @@ -271,12 +271,12 @@ class TestSupabaseStorage: def test_bucket_exists_returns_false_when_bucket_not_found(self): """Test bucket_exists returns False when bucket is not found in list.""" - with patch("extensions.storage.supabase_storage.dify_config") as mock_config: + with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config: mock_config.SUPABASE_URL = "https://test.supabase.co" mock_config.SUPABASE_API_KEY = "test-api-key" mock_config.SUPABASE_BUCKET_NAME = "test-bucket" - with patch("extensions.storage.supabase_storage.Client") as mock_client_class: + with patch("extensions.storage.supabase_storage.Client", autospec=True) as mock_client_class: mock_client = Mock() mock_client_class.return_value = mock_client @@ -294,12 +294,12 @@ class TestSupabaseStorage: def test_bucket_exists_returns_false_when_no_buckets(self): """Test bucket_exists returns False when no buckets exist.""" - with patch("extensions.storage.supabase_storage.dify_config") as mock_config: + with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config: mock_config.SUPABASE_URL = "https://test.supabase.co" mock_config.SUPABASE_API_KEY = "test-api-key" mock_config.SUPABASE_BUCKET_NAME = "test-bucket" - with patch("extensions.storage.supabase_storage.Client") as mock_client_class: + with patch("extensions.storage.supabase_storage.Client", autospec=True) as mock_client_class: mock_client = Mock() mock_client_class.return_value = mock_client diff --git a/api/tests/unit_tests/factories/test_variable_factory.py b/api/tests/unit_tests/factories/test_variable_factory.py index 2a193ef2d7..1e98e99aab 100644 --- a/api/tests/unit_tests/factories/test_variable_factory.py +++ b/api/tests/unit_tests/factories/test_variable_factory.py @@ -4,7 +4,7 @@ from typing import Any from uuid import uuid4 import pytest -from hypothesis import given +from hypothesis import given, settings from hypothesis import strategies as st from core.file import File, FileTransferMethod, FileType @@ -486,13 +486,14 @@ def _generate_file(draw) -> File: def _scalar_value() -> st.SearchStrategy[int | float | str | File | None]: return st.one_of( st.none(), - st.integers(), - st.floats(), - st.text(), + st.integers(min_value=-(10**6), max_value=10**6), + st.floats(allow_nan=True, allow_infinity=False), + st.text(max_size=50), _generate_file(), ) +@settings(max_examples=50) @given(_scalar_value()) def test_build_segment_and_extract_values_for_scalar_types(value): seg = variable_factory.build_segment(value) @@ -503,7 +504,8 @@ def test_build_segment_and_extract_values_for_scalar_types(value): assert seg.value == value -@given(st.lists(_scalar_value())) +@settings(max_examples=50) +@given(values=st.lists(_scalar_value(), max_size=20)) def test_build_segment_and_extract_values_for_array_types(values): seg = variable_factory.build_segment(values) assert seg.value == values diff --git a/api/tests/unit_tests/libs/test_email_i18n.py b/api/tests/unit_tests/libs/test_email_i18n.py index b80c711cac..962a36fe03 100644 --- a/api/tests/unit_tests/libs/test_email_i18n.py +++ b/api/tests/unit_tests/libs/test_email_i18n.py @@ -246,6 +246,43 @@ class TestEmailI18nService: sent_email = mock_sender.sent_emails[0] assert sent_email["subject"] == "Reset Your Dify Password" + def test_subject_format_keyerror_fallback_path( + self, + mock_renderer: MockEmailRenderer, + mock_sender: MockEmailSender, + ): + """Trigger subject KeyError and cover except branch.""" + # Config with subject that references an unknown key (no {application_title} to avoid second format) + config = EmailI18nConfig( + templates={ + EmailType.INVITE_MEMBER: { + EmailLanguage.EN_US: EmailTemplate( + subject="Invite: {unknown_placeholder}", + template_path="invite_member_en.html", + branded_template_path="branded/invite_member_en.html", + ), + } + } + ) + branding_service = MockBrandingService(enabled=False) + service = EmailI18nService( + config=config, + renderer=mock_renderer, + branding_service=branding_service, + sender=mock_sender, + ) + + # Will raise KeyError on subject.format(**full_context), then hit except branch and skip fallback + service.send_email( + email_type=EmailType.INVITE_MEMBER, + language_code="en-US", + to="test@example.com", + ) + + assert len(mock_sender.sent_emails) == 1 + # Subject is left unformatted due to KeyError fallback path without application_title + assert mock_sender.sent_emails[0]["subject"] == "Invite: {unknown_placeholder}" + def test_send_change_email_old_phase( self, email_config: EmailI18nConfig, diff --git a/api/tests/unit_tests/libs/test_external_api.py b/api/tests/unit_tests/libs/test_external_api.py new file mode 100644 index 0000000000..a9edb913ea --- /dev/null +++ b/api/tests/unit_tests/libs/test_external_api.py @@ -0,0 +1,122 @@ +from flask import Blueprint, Flask +from flask_restx import Resource +from werkzeug.exceptions import BadRequest, Unauthorized + +from core.errors.error import AppInvokeQuotaExceededError +from libs.external_api import ExternalApi + + +def _create_api_app(): + app = Flask(__name__) + bp = Blueprint("t", __name__) + api = ExternalApi(bp) + + @api.route("/bad-request") + class Bad(Resource): # type: ignore + def get(self): # type: ignore + raise BadRequest("invalid input") + + @api.route("/unauth") + class Unauth(Resource): # type: ignore + def get(self): # type: ignore + raise Unauthorized("auth required") + + @api.route("/value-error") + class ValErr(Resource): # type: ignore + def get(self): # type: ignore + raise ValueError("boom") + + @api.route("/quota") + class Quota(Resource): # type: ignore + def get(self): # type: ignore + raise AppInvokeQuotaExceededError("quota exceeded") + + @api.route("/general") + class Gen(Resource): # type: ignore + def get(self): # type: ignore + raise RuntimeError("oops") + + # Note: We avoid altering default_mediatype to keep normal error paths + + # Special 400 message rewrite + @api.route("/json-empty") + class JsonEmpty(Resource): # type: ignore + def get(self): # type: ignore + e = BadRequest() + # Force the specific message the handler rewrites + e.description = "Failed to decode JSON object: Expecting value: line 1 column 1 (char 0)" + raise e + + # 400 mapping payload path + @api.route("/param-errors") + class ParamErrors(Resource): # type: ignore + def get(self): # type: ignore + e = BadRequest() + # Coerce a mapping description to trigger param error shaping + e.description = {"field": "is required"} # type: ignore[assignment] + raise e + + app.register_blueprint(bp, url_prefix="/api") + return app + + +def test_external_api_error_handlers_basic_paths(): + app = _create_api_app() + client = app.test_client() + + # 400 + res = client.get("/api/bad-request") + assert res.status_code == 400 + data = res.get_json() + assert data["code"] == "bad_request" + assert data["status"] == 400 + + # 401 + res = client.get("/api/unauth") + assert res.status_code == 401 + assert "WWW-Authenticate" in res.headers + + # 400 ValueError + res = client.get("/api/value-error") + assert res.status_code == 400 + assert res.get_json()["code"] == "invalid_param" + + # 500 general + res = client.get("/api/general") + assert res.status_code == 500 + assert res.get_json()["status"] == 500 + + +def test_external_api_json_message_and_bad_request_rewrite(): + app = _create_api_app() + client = app.test_client() + + # JSON empty special rewrite + res = client.get("/api/json-empty") + assert res.status_code == 400 + assert res.get_json()["message"] == "Invalid JSON payload received or JSON payload is empty." + + +def test_external_api_param_mapping_and_quota_and_exc_info_none(): + # Force exc_info() to return (None,None,None) only during request + import libs.external_api as ext + + orig_exc_info = ext.sys.exc_info + try: + ext.sys.exc_info = lambda: (None, None, None) # type: ignore[assignment] + + app = _create_api_app() + client = app.test_client() + + # Param errors mapping payload path + res = client.get("/api/param-errors") + assert res.status_code == 400 + data = res.get_json() + assert data["code"] == "invalid_param" + assert data["params"] == "field" + + # Quota path — depending on Flask-RESTX internals it may be handled + res = client.get("/api/quota") + assert res.status_code in (400, 429) + finally: + ext.sys.exc_info = orig_exc_info # type: ignore[assignment] diff --git a/api/tests/unit_tests/libs/test_file_utils.py b/api/tests/unit_tests/libs/test_file_utils.py new file mode 100644 index 0000000000..8d9b4e803a --- /dev/null +++ b/api/tests/unit_tests/libs/test_file_utils.py @@ -0,0 +1,55 @@ +from pathlib import Path + +import pytest + +from libs.file_utils import search_file_upwards + + +def test_search_file_upwards_found_in_parent(tmp_path: Path): + base = tmp_path / "a" / "b" / "c" + base.mkdir(parents=True) + + target = tmp_path / "a" / "target.txt" + target.write_text("ok", encoding="utf-8") + + found = search_file_upwards(base, "target.txt", max_search_parent_depth=5) + assert found == target + + +def test_search_file_upwards_found_in_current(tmp_path: Path): + base = tmp_path / "x" + base.mkdir() + target = base / "here.txt" + target.write_text("x", encoding="utf-8") + + found = search_file_upwards(base, "here.txt", max_search_parent_depth=1) + assert found == target + + +def test_search_file_upwards_not_found_raises(tmp_path: Path): + base = tmp_path / "m" / "n" + base.mkdir(parents=True) + with pytest.raises(ValueError) as exc: + search_file_upwards(base, "missing.txt", max_search_parent_depth=3) + # error message should contain file name and base path + msg = str(exc.value) + assert "missing.txt" in msg + assert str(base) in msg + + +def test_search_file_upwards_root_breaks_and_raises(): + # Using filesystem root triggers the 'break' branch (parent == current) + with pytest.raises(ValueError): + search_file_upwards(Path("/"), "__definitely_not_exists__.txt", max_search_parent_depth=1) + + +def test_search_file_upwards_depth_limit_raises(tmp_path: Path): + base = tmp_path / "a" / "b" / "c" + base.mkdir(parents=True) + target = tmp_path / "a" / "target.txt" + target.write_text("ok", encoding="utf-8") + # The file is 2 levels up from `c` (in `a`), but search depth is only 2. + # The search path is `c` (depth 1) -> `b` (depth 2). The file is in `a` (would need depth 3). + # So, this should not find the file and should raise an error. + with pytest.raises(ValueError): + search_file_upwards(base, "target.txt", max_search_parent_depth=2) diff --git a/api/tests/unit_tests/libs/test_json_in_md_parser.py b/api/tests/unit_tests/libs/test_json_in_md_parser.py new file mode 100644 index 0000000000..53fd0bea16 --- /dev/null +++ b/api/tests/unit_tests/libs/test_json_in_md_parser.py @@ -0,0 +1,88 @@ +import pytest + +from core.llm_generator.output_parser.errors import OutputParserError +from libs.json_in_md_parser import ( + parse_and_check_json_markdown, + parse_json_markdown, +) + + +def test_parse_json_markdown_triple_backticks_json(): + src = """ + ```json + {"a": 1, "b": "x"} + ``` + """ + assert parse_json_markdown(src) == {"a": 1, "b": "x"} + + +def test_parse_json_markdown_triple_backticks_generic(): + src = """ + ``` + {"k": [1, 2, 3]} + ``` + """ + assert parse_json_markdown(src) == {"k": [1, 2, 3]} + + +def test_parse_json_markdown_single_backticks(): + src = '`{"x": true}`' + assert parse_json_markdown(src) == {"x": True} + + +def test_parse_json_markdown_braces_only(): + src = ' {\n \t"ok": "yes"\n} ' + assert parse_json_markdown(src) == {"ok": "yes"} + + +def test_parse_json_markdown_not_found(): + with pytest.raises(ValueError): + parse_json_markdown("no json here") + + +def test_parse_and_check_json_markdown_missing_key(): + src = """ + ``` + {"present": 1} + ``` + """ + with pytest.raises(OutputParserError) as exc: + parse_and_check_json_markdown(src, ["present", "missing"]) + assert "expected key `missing`" in str(exc.value) + + +def test_parse_and_check_json_markdown_invalid_json(): + src = """ + ```json + {invalid json} + ``` + """ + with pytest.raises(OutputParserError) as exc: + parse_and_check_json_markdown(src, []) + assert "got invalid json object" in str(exc.value) + + +def test_parse_and_check_json_markdown_success(): + src = """ + ```json + {"present": 1, "other": 2} + ``` + """ + obj = parse_and_check_json_markdown(src, ["present"]) + assert obj == {"present": 1, "other": 2} + + +def test_parse_and_check_json_markdown_multiple_blocks_fails(): + src = """ + ```json + {"a": 1} + ``` + Some text + ```json + {"b": 2} + ``` + """ + # The current implementation is greedy and will match from the first + # opening fence to the last closing fence, causing JSON decode failure. + with pytest.raises(OutputParserError): + parse_and_check_json_markdown(src, []) diff --git a/api/tests/unit_tests/libs/test_oauth_base.py b/api/tests/unit_tests/libs/test_oauth_base.py new file mode 100644 index 0000000000..3e0c235fff --- /dev/null +++ b/api/tests/unit_tests/libs/test_oauth_base.py @@ -0,0 +1,19 @@ +import pytest + +from libs.oauth import OAuth + + +def test_oauth_base_methods_raise_not_implemented(): + oauth = OAuth(client_id="id", client_secret="sec", redirect_uri="uri") + + with pytest.raises(NotImplementedError): + oauth.get_authorization_url() + + with pytest.raises(NotImplementedError): + oauth.get_access_token("code") + + with pytest.raises(NotImplementedError): + oauth.get_raw_user_info("token") + + with pytest.raises(NotImplementedError): + oauth._transform_user_info({}) # type: ignore[name-defined] diff --git a/api/tests/unit_tests/libs/test_orjson.py b/api/tests/unit_tests/libs/test_orjson.py new file mode 100644 index 0000000000..6df1d077df --- /dev/null +++ b/api/tests/unit_tests/libs/test_orjson.py @@ -0,0 +1,25 @@ +import orjson +import pytest + +from libs.orjson import orjson_dumps + + +def test_orjson_dumps_round_trip_basic(): + obj = {"a": 1, "b": [1, 2, 3], "c": {"d": True}} + s = orjson_dumps(obj) + assert orjson.loads(s) == obj + + +def test_orjson_dumps_with_unicode_and_indent(): + obj = {"msg": "你好,Dify"} + s = orjson_dumps(obj, option=orjson.OPT_INDENT_2) + # contains indentation newline/spaces + assert "\n" in s + assert orjson.loads(s) == obj + + +def test_orjson_dumps_non_utf8_encoding_fails(): + obj = {"msg": "你好"} + # orjson.dumps() always produces UTF-8 bytes; decoding with non-UTF8 fails. + with pytest.raises(UnicodeDecodeError): + orjson_dumps(obj, encoding="ascii") diff --git a/api/tests/unit_tests/libs/test_sendgrid_client.py b/api/tests/unit_tests/libs/test_sendgrid_client.py new file mode 100644 index 0000000000..85744003c7 --- /dev/null +++ b/api/tests/unit_tests/libs/test_sendgrid_client.py @@ -0,0 +1,53 @@ +from unittest.mock import MagicMock, patch + +import pytest +from python_http_client.exceptions import UnauthorizedError + +from libs.sendgrid import SendGridClient + + +def _mail(to: str = "user@example.com") -> dict: + return {"to": to, "subject": "Hi", "html": "Hi"} + + +@patch("libs.sendgrid.sendgrid.SendGridAPIClient") +def test_sendgrid_success(mock_client_cls: MagicMock): + mock_client = MagicMock() + mock_client_cls.return_value = mock_client + # nested attribute access: client.mail.send.post + mock_client.client.mail.send.post.return_value = MagicMock(status_code=202, body=b"", headers={}) + + sg = SendGridClient(sendgrid_api_key="key", _from="noreply@example.com") + sg.send(_mail()) + + mock_client_cls.assert_called_once() + mock_client.client.mail.send.post.assert_called_once() + + +@patch("libs.sendgrid.sendgrid.SendGridAPIClient") +def test_sendgrid_missing_to_raises(mock_client_cls: MagicMock): + sg = SendGridClient(sendgrid_api_key="key", _from="noreply@example.com") + with pytest.raises(ValueError): + sg.send(_mail(to="")) + + +@patch("libs.sendgrid.sendgrid.SendGridAPIClient") +def test_sendgrid_auth_errors_reraise(mock_client_cls: MagicMock): + mock_client = MagicMock() + mock_client_cls.return_value = mock_client + mock_client.client.mail.send.post.side_effect = UnauthorizedError(401, "Unauthorized", b"{}", {}) + + sg = SendGridClient(sendgrid_api_key="key", _from="noreply@example.com") + with pytest.raises(UnauthorizedError): + sg.send(_mail()) + + +@patch("libs.sendgrid.sendgrid.SendGridAPIClient") +def test_sendgrid_timeout_reraise(mock_client_cls: MagicMock): + mock_client = MagicMock() + mock_client_cls.return_value = mock_client + mock_client.client.mail.send.post.side_effect = TimeoutError("timeout") + + sg = SendGridClient(sendgrid_api_key="key", _from="noreply@example.com") + with pytest.raises(TimeoutError): + sg.send(_mail()) diff --git a/api/tests/unit_tests/libs/test_smtp_client.py b/api/tests/unit_tests/libs/test_smtp_client.py new file mode 100644 index 0000000000..fcee01ca00 --- /dev/null +++ b/api/tests/unit_tests/libs/test_smtp_client.py @@ -0,0 +1,100 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from libs.smtp import SMTPClient + + +def _mail() -> dict: + return {"to": "user@example.com", "subject": "Hi", "html": "Hi"} + + +@patch("libs.smtp.smtplib.SMTP") +def test_smtp_plain_success(mock_smtp_cls: MagicMock): + mock_smtp = MagicMock() + mock_smtp_cls.return_value = mock_smtp + + client = SMTPClient(server="smtp.example.com", port=25, username="", password="", _from="noreply@example.com") + client.send(_mail()) + + mock_smtp_cls.assert_called_once_with("smtp.example.com", 25, timeout=10) + mock_smtp.sendmail.assert_called_once() + mock_smtp.quit.assert_called_once() + + +@patch("libs.smtp.smtplib.SMTP") +def test_smtp_tls_opportunistic_success(mock_smtp_cls: MagicMock): + mock_smtp = MagicMock() + mock_smtp_cls.return_value = mock_smtp + + client = SMTPClient( + server="smtp.example.com", + port=587, + username="user", + password="pass", + _from="noreply@example.com", + use_tls=True, + opportunistic_tls=True, + ) + client.send(_mail()) + + mock_smtp_cls.assert_called_once_with("smtp.example.com", 587, timeout=10) + assert mock_smtp.ehlo.call_count == 2 + mock_smtp.starttls.assert_called_once() + mock_smtp.login.assert_called_once_with("user", "pass") + mock_smtp.sendmail.assert_called_once() + mock_smtp.quit.assert_called_once() + + +@patch("libs.smtp.smtplib.SMTP_SSL") +def test_smtp_tls_ssl_branch_and_timeout(mock_smtp_ssl_cls: MagicMock): + # Cover SMTP_SSL branch and TimeoutError handling + mock_smtp = MagicMock() + mock_smtp.sendmail.side_effect = TimeoutError("timeout") + mock_smtp_ssl_cls.return_value = mock_smtp + + client = SMTPClient( + server="smtp.example.com", + port=465, + username="", + password="", + _from="noreply@example.com", + use_tls=True, + opportunistic_tls=False, + ) + with pytest.raises(TimeoutError): + client.send(_mail()) + mock_smtp.quit.assert_called_once() + + +@patch("libs.smtp.smtplib.SMTP") +def test_smtp_generic_exception_propagates(mock_smtp_cls: MagicMock): + mock_smtp = MagicMock() + mock_smtp.sendmail.side_effect = RuntimeError("oops") + mock_smtp_cls.return_value = mock_smtp + + client = SMTPClient(server="smtp.example.com", port=25, username="", password="", _from="noreply@example.com") + with pytest.raises(RuntimeError): + client.send(_mail()) + mock_smtp.quit.assert_called_once() + + +@patch("libs.smtp.smtplib.SMTP") +def test_smtp_smtplib_exception_in_login(mock_smtp_cls: MagicMock): + # Ensure we hit the specific SMTPException except branch + import smtplib + + mock_smtp = MagicMock() + mock_smtp.login.side_effect = smtplib.SMTPException("login-fail") + mock_smtp_cls.return_value = mock_smtp + + client = SMTPClient( + server="smtp.example.com", + port=25, + username="user", # non-empty to trigger login + password="pass", + _from="noreply@example.com", + ) + with pytest.raises(smtplib.SMTPException): + client.send(_mail()) + mock_smtp.quit.assert_called_once() diff --git a/api/tests/unit_tests/services/auth/test_api_key_auth_service.py b/api/tests/unit_tests/services/auth/test_api_key_auth_service.py index dc42a04cf3..d23298f096 100644 --- a/api/tests/unit_tests/services/auth/test_api_key_auth_service.py +++ b/api/tests/unit_tests/services/auth/test_api_key_auth_service.py @@ -28,18 +28,20 @@ class TestApiKeyAuthService: mock_binding.provider = self.provider mock_binding.disabled = False - mock_session.query.return_value.where.return_value.all.return_value = [mock_binding] + mock_session.scalars.return_value.all.return_value = [mock_binding] result = ApiKeyAuthService.get_provider_auth_list(self.tenant_id) assert len(result) == 1 assert result[0].tenant_id == self.tenant_id - mock_session.query.assert_called_once_with(DataSourceApiKeyAuthBinding) + assert mock_session.scalars.call_count == 1 + select_arg = mock_session.scalars.call_args[0][0] + assert "data_source_api_key_auth_binding" in str(select_arg).lower() @patch("services.auth.api_key_auth_service.db.session") def test_get_provider_auth_list_empty(self, mock_session): """Test get provider auth list - empty result""" - mock_session.query.return_value.where.return_value.all.return_value = [] + mock_session.scalars.return_value.all.return_value = [] result = ApiKeyAuthService.get_provider_auth_list(self.tenant_id) @@ -48,13 +50,15 @@ class TestApiKeyAuthService: @patch("services.auth.api_key_auth_service.db.session") def test_get_provider_auth_list_filters_disabled(self, mock_session): """Test get provider auth list - filters disabled items""" - mock_session.query.return_value.where.return_value.all.return_value = [] + mock_session.scalars.return_value.all.return_value = [] ApiKeyAuthService.get_provider_auth_list(self.tenant_id) - - # Verify where conditions include disabled.is_(False) - where_call = mock_session.query.return_value.where.call_args[0] - assert len(where_call) == 2 # tenant_id and disabled filter conditions + select_stmt = mock_session.scalars.call_args[0][0] + where_clauses = list(getattr(select_stmt, "_where_criteria", []) or []) + # Ensure both tenant filter and disabled filter exist + where_strs = [str(c).lower() for c in where_clauses] + assert any("tenant_id" in s for s in where_strs) + assert any("disabled" in s for s in where_strs) @patch("services.auth.api_key_auth_service.db.session") @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory") diff --git a/api/tests/unit_tests/services/auth/test_auth_integration.py b/api/tests/unit_tests/services/auth/test_auth_integration.py index 4ce5525942..bb39b92c09 100644 --- a/api/tests/unit_tests/services/auth/test_auth_integration.py +++ b/api/tests/unit_tests/services/auth/test_auth_integration.py @@ -63,10 +63,10 @@ class TestAuthIntegration: tenant1_binding = self._create_mock_binding(self.tenant_id_1, AuthType.FIRECRAWL, self.firecrawl_credentials) tenant2_binding = self._create_mock_binding(self.tenant_id_2, AuthType.JINA, self.jina_credentials) - mock_session.query.return_value.where.return_value.all.return_value = [tenant1_binding] + mock_session.scalars.return_value.all.return_value = [tenant1_binding] result1 = ApiKeyAuthService.get_provider_auth_list(self.tenant_id_1) - mock_session.query.return_value.where.return_value.all.return_value = [tenant2_binding] + mock_session.scalars.return_value.all.return_value = [tenant2_binding] result2 = ApiKeyAuthService.get_provider_auth_list(self.tenant_id_2) assert len(result1) == 1 diff --git a/api/tests/unit_tests/services/test_account_service.py b/api/tests/unit_tests/services/test_account_service.py index 442839e44e..737202f8de 100644 --- a/api/tests/unit_tests/services/test_account_service.py +++ b/api/tests/unit_tests/services/test_account_service.py @@ -10,7 +10,6 @@ from services.account_service import AccountService, RegisterService, TenantServ from services.errors.account import ( AccountAlreadyInTenantError, AccountLoginError, - AccountNotFoundError, AccountPasswordError, AccountRegisterError, CurrentPasswordIncorrectError, @@ -195,7 +194,7 @@ class TestAccountService: # Execute test and verify exception self._assert_exception_raised( - AccountNotFoundError, AccountService.authenticate, "notfound@example.com", "password" + AccountPasswordError, AccountService.authenticate, "notfound@example.com", "password" ) def test_authenticate_account_banned(self, mock_db_dependencies): @@ -1370,8 +1369,8 @@ class TestRegisterService: account_id="user-123", email="test@example.com" ) - with patch("services.account_service.RegisterService._get_invitation_by_token") as mock_get_invitation_by_token: - # Mock the invitation data returned by _get_invitation_by_token + with patch("services.account_service.RegisterService.get_invitation_by_token") as mock_get_invitation_by_token: + # Mock the invitation data returned by get_invitation_by_token invitation_data = { "account_id": "user-123", "email": "test@example.com", @@ -1503,12 +1502,12 @@ class TestRegisterService: assert result == "member_invite:token:test-token" def test_get_invitation_by_token_with_workspace_and_email(self, mock_redis_dependencies): - """Test _get_invitation_by_token with workspace ID and email.""" + """Test get_invitation_by_token with workspace ID and email.""" # Setup mock mock_redis_dependencies.get.return_value = b"user-123" # Execute test - result = RegisterService._get_invitation_by_token("token-123", "workspace-456", "test@example.com") + result = RegisterService.get_invitation_by_token("token-123", "workspace-456", "test@example.com") # Verify results assert result is not None @@ -1517,7 +1516,7 @@ class TestRegisterService: assert result["workspace_id"] == "workspace-456" def test_get_invitation_by_token_without_workspace_and_email(self, mock_redis_dependencies): - """Test _get_invitation_by_token without workspace ID and email.""" + """Test get_invitation_by_token without workspace ID and email.""" # Setup mock invitation_data = { "account_id": "user-123", @@ -1527,19 +1526,19 @@ class TestRegisterService: mock_redis_dependencies.get.return_value = json.dumps(invitation_data).encode() # Execute test - result = RegisterService._get_invitation_by_token("token-123") + result = RegisterService.get_invitation_by_token("token-123") # Verify results assert result is not None assert result == invitation_data def test_get_invitation_by_token_no_data(self, mock_redis_dependencies): - """Test _get_invitation_by_token with no data.""" + """Test get_invitation_by_token with no data.""" # Setup mock mock_redis_dependencies.get.return_value = None # Execute test - result = RegisterService._get_invitation_by_token("token-123") + result = RegisterService.get_invitation_by_token("token-123") # Verify results assert result is None diff --git a/api/tests/unit_tests/services/workflow/test_workflow_converter.py b/api/tests/unit_tests/services/workflow/test_workflow_converter.py index 0a09167349..2ca781bae5 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_converter.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_converter.py @@ -66,7 +66,7 @@ def test__convert_to_http_request_node_for_chatbot(default_variables): app_model = MagicMock() app_model.id = "app_id" app_model.tenant_id = "tenant_id" - app_model.mode = AppMode.CHAT.value + app_model.mode = AppMode.CHAT api_based_extension_id = "api_based_extension_id" mock_api_based_extension = APIBasedExtension( @@ -127,7 +127,7 @@ def test__convert_to_http_request_node_for_workflow_app(default_variables): app_model = MagicMock() app_model.id = "app_id" app_model.tenant_id = "tenant_id" - app_model.mode = AppMode.WORKFLOW.value + app_model.mode = AppMode.WORKFLOW api_based_extension_id = "api_based_extension_id" mock_api_based_extension = APIBasedExtension( diff --git a/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py b/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py index 93284eed4b..9046f785d2 100644 --- a/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py +++ b/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py @@ -279,8 +279,6 @@ def test_structured_output_parser(): ] for case in testcases: - print(f"Running test case: {case['name']}") - # Setup model entity model_schema = get_model_entity(case["provider"], case["model_name"], case["support_structure_output"]) diff --git a/api/tests/unit_tests/utils/yaml/test_yaml_utils.py b/api/tests/unit_tests/utils/yaml/test_yaml_utils.py index 95b93651d5..9e2b0659c0 100644 --- a/api/tests/unit_tests/utils/yaml/test_yaml_utils.py +++ b/api/tests/unit_tests/utils/yaml/test_yaml_utils.py @@ -3,7 +3,7 @@ from textwrap import dedent import pytest from yaml import YAMLError -from core.tools.utils.yaml_utils import load_yaml_file +from core.tools.utils.yaml_utils import _load_yaml_file EXAMPLE_YAML_FILE = "example_yaml.yaml" INVALID_YAML_FILE = "invalid_yaml.yaml" @@ -56,15 +56,15 @@ def prepare_invalid_yaml_file(tmp_path, monkeypatch) -> str: def test_load_yaml_non_existing_file(): - assert load_yaml_file(file_path=NON_EXISTING_YAML_FILE) == {} - assert load_yaml_file(file_path="") == {} + with pytest.raises(FileNotFoundError): + _load_yaml_file(file_path=NON_EXISTING_YAML_FILE) with pytest.raises(FileNotFoundError): - load_yaml_file(file_path=NON_EXISTING_YAML_FILE, ignore_error=False) + _load_yaml_file(file_path="") def test_load_valid_yaml_file(prepare_example_yaml_file): - yaml_data = load_yaml_file(file_path=prepare_example_yaml_file) + yaml_data = _load_yaml_file(file_path=prepare_example_yaml_file) assert len(yaml_data) > 0 assert yaml_data["age"] == 30 assert yaml_data["gender"] == "male" @@ -77,7 +77,4 @@ def test_load_valid_yaml_file(prepare_example_yaml_file): def test_load_invalid_yaml_file(prepare_invalid_yaml_file): # yaml syntax error with pytest.raises(YAMLError): - load_yaml_file(file_path=prepare_invalid_yaml_file, ignore_error=False) - - # ignore error - assert load_yaml_file(file_path=prepare_invalid_yaml_file) == {} + _load_yaml_file(file_path=prepare_invalid_yaml_file) diff --git a/api/uv.lock b/api/uv.lock index 54c4083369..65174561a5 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -538,6 +538,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/57/f4/a69c20ee4f660081a7dedb1ac57f29be9378e04edfcb90c526b923d4bebc/beautifulsoup4-4.12.2-py3-none-any.whl", hash = "sha256:bd2520ca0d9d7d12694a53d44ac482d181b4ec1888909b035a3dbf40d0f57d4a", size = 142979, upload-time = "2023-04-07T15:02:50.77Z" }, ] +[[package]] +name = "bidict" +version = "0.23.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/9a/6e/026678aa5a830e07cd9498a05d3e7e650a4f56a42f267a53d22bcda1bdc9/bidict-0.23.1.tar.gz", hash = "sha256:03069d763bc387bbd20e7d49914e75fc4132a41937fa3405417e1a5a2d006d71", size = 29093, upload-time = "2024-02-18T19:09:05.748Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/99/37/e8730c3587a65eb5645d4aba2d27aae48e8003614d6aaf15dda67f702f1f/bidict-0.23.1-py3-none-any.whl", hash = "sha256:5dae8d4d79b552a71cbabc7deb25dfe8ce710b17ff41711e13010ead2abfc3e5", size = 32764, upload-time = "2024-02-18T19:09:04.156Z" }, +] + [[package]] name = "billiard" version = "4.2.1" @@ -1061,6 +1070,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a7/06/3d6badcf13db419e25b07041d9c7b4a2c331d3f4e7134445ec5df57714cd/coloredlogs-15.0.1-py2.py3-none-any.whl", hash = "sha256:612ee75c546f53e92e70049c9dbfcc18c935a2b9a53b66085ce9ef6a6e5c0934", size = 46018, upload-time = "2021-06-11T10:22:42.561Z" }, ] +[[package]] +name = "configargparse" +version = "1.7.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/85/4d/6c9ef746dfcc2a32e26f3860bb4a011c008c392b83eabdfb598d1a8bbe5d/configargparse-1.7.1.tar.gz", hash = "sha256:79c2ddae836a1e5914b71d58e4b9adbd9f7779d4e6351a637b7d2d9b6c46d3d9", size = 43958, upload-time = "2025-05-23T14:26:17.369Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/31/28/d28211d29bcc3620b1fece85a65ce5bb22f18670a03cd28ea4b75ede270c/configargparse-1.7.1-py3-none-any.whl", hash = "sha256:8b586a31f9d873abd1ca527ffbe58863c99f36d896e2829779803125e83be4b6", size = 25607, upload-time = "2025-05-23T14:26:15.923Z" }, +] + [[package]] name = "cos-python-sdk-v5" version = "1.9.30" @@ -1357,6 +1375,7 @@ dev = [ { name = "dotenv-linter" }, { name = "faker" }, { name = "hypothesis" }, + { name = "locust" }, { name = "lxml-stubs" }, { name = "mypy" }, { name = "pandas-stubs" }, @@ -1367,6 +1386,7 @@ dev = [ { name = "pytest-mock" }, { name = "ruff" }, { name = "scipy-stubs" }, + { name = "sseclient-py" }, { name = "testcontainers" }, { name = "ty" }, { name = "types-aiofiles" }, @@ -1549,6 +1569,7 @@ dev = [ { name = "dotenv-linter", specifier = "~=0.5.0" }, { name = "faker", specifier = "~=32.1.0" }, { name = "hypothesis", specifier = ">=6.131.15" }, + { name = "locust", specifier = ">=2.40.4" }, { name = "lxml-stubs", specifier = "~=0.5.1" }, { name = "mypy", specifier = "~=1.17.1" }, { name = "pandas-stubs", specifier = "~=2.2.3" }, @@ -1559,6 +1580,7 @@ dev = [ { name = "pytest-mock", specifier = "~=3.14.0" }, { name = "ruff", specifier = "~=0.12.3" }, { name = "scipy-stubs", specifier = ">=1.15.3.0" }, + { name = "sseclient-py", specifier = ">=1.8.0" }, { name = "testcontainers", specifier = "~=4.10.0" }, { name = "ty", specifier = "~=0.0.1a19" }, { name = "types-aiofiles", specifier = "~=24.1.0" }, @@ -2036,6 +2058,58 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/11/b2/5d20664ef6a077bec9f27f7a7ee761edc64946d0b1e293726a3d074a9a18/gevent-24.11.1-cp312-cp312-win_amd64.whl", hash = "sha256:68bee86b6e1c041a187347ef84cf03a792f0b6c7238378bf6ba4118af11feaae", size = 1541631, upload-time = "2024-11-11T14:55:34.977Z" }, ] +[[package]] +name = "geventhttpclient" +version = "2.3.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "brotli" }, + { name = "certifi" }, + { name = "gevent" }, + { name = "urllib3" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/89/19/1ca8de73dcc0596d3df01be299e940d7fc3bccbeb6f62bb8dd2d427a3a50/geventhttpclient-2.3.4.tar.gz", hash = "sha256:1749f75810435a001fc6d4d7526c92cf02b39b30ab6217a886102f941c874222", size = 83545, upload-time = "2025-06-11T13:18:14.144Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3d/c7/c4c31bd92b08c4e34073c722152b05c48c026bc6978cf04f52be7e9050d5/geventhttpclient-2.3.4-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:fb8f6a18f1b5e37724111abbd3edf25f8f00e43dc261b11b10686e17688d2405", size = 71919, upload-time = "2025-06-11T13:16:49.796Z" }, + { url = "https://files.pythonhosted.org/packages/9d/8a/4565e6e768181ecb06677861d949b3679ed29123b6f14333e38767a17b5a/geventhttpclient-2.3.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:dbb28455bb5d82ca3024f9eb7d65c8ff6707394b584519def497b5eb9e5b1222", size = 52577, upload-time = "2025-06-11T13:16:50.657Z" }, + { url = "https://files.pythonhosted.org/packages/02/a1/fb623cf478799c08f95774bc41edb8ae4c2f1317ae986b52f233d0f3fa05/geventhttpclient-2.3.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:96578fc4a5707b5535d1c25a89e72583e02aafe64d14f3b4d78f9c512c6d613c", size = 51981, upload-time = "2025-06-11T13:16:52.586Z" }, + { url = "https://files.pythonhosted.org/packages/18/b2/a4ddd3d24c8aa064b19b9f180eb5e1517248518289d38af70500569ebedf/geventhttpclient-2.3.4-cp311-cp311-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:19721357db976149ccf54ac279eab8139da8cdf7a11343fd02212891b6f39677", size = 114287, upload-time = "2025-08-24T12:16:47.101Z" }, + { url = "https://files.pythonhosted.org/packages/a1/cc/caac4d4bd2c72d53836dbf50018aed3747c0d0c6f1d08175a785083d9d36/geventhttpclient-2.3.4-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ecf830cdcd1d4d28463c8e0c48f7f5fb06f3c952fff875da279385554d1d4d65", size = 115208, upload-time = "2025-08-24T12:16:48.108Z" }, + { url = "https://files.pythonhosted.org/packages/04/a2/8278bd4d16b9df88bd538824595b7b84efd6f03c7b56b2087d09be838e02/geventhttpclient-2.3.4-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:47dbf8a163a07f83b38b0f8a35b85e5d193d3af4522ab8a5bbecffff1a4cd462", size = 121101, upload-time = "2025-08-24T12:16:49.417Z" }, + { url = "https://files.pythonhosted.org/packages/e3/0e/a9ebb216140bd0854007ff953094b2af983cdf6d4aec49796572fcbf2606/geventhttpclient-2.3.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4e39ad577b33a5be33b47bff7c2dda9b19ced4773d169d6555777cd8445c13c0", size = 118494, upload-time = "2025-06-11T13:16:54.172Z" }, + { url = "https://files.pythonhosted.org/packages/4f/95/6d45dead27e4f5db7a6d277354b0e2877c58efb3cd1687d90a02d5c7b9cd/geventhttpclient-2.3.4-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:110d863baf7f0a369b6c22be547c5582e87eea70ddda41894715c870b2e82eb0", size = 123860, upload-time = "2025-06-11T13:16:55.824Z" }, + { url = "https://files.pythonhosted.org/packages/70/a1/4baa8dca3d2df94e6ccca889947bb5929aca5b64b59136bbf1779b5777ba/geventhttpclient-2.3.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:226d9fca98469bd770e3efd88326854296d1aa68016f285bd1a2fb6cd21e17ee", size = 114969, upload-time = "2025-06-11T13:16:58.02Z" }, + { url = "https://files.pythonhosted.org/packages/ab/48/123fa67f6fca14c557332a168011565abd9cbdccc5c8b7ed76d9a736aeb2/geventhttpclient-2.3.4-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:71dbc6d4004017ef88c70229809df4ad2317aad4876870c0b6bcd4d6695b7a8d", size = 113311, upload-time = "2025-06-11T13:16:59.423Z" }, + { url = "https://files.pythonhosted.org/packages/93/e4/8a467991127ca6c53dd79a8aecb26a48207e7e7976c578fb6eb31378792c/geventhttpclient-2.3.4-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:ed35391ad697d6cda43c94087f59310f028c3e9fb229e435281a92509469c627", size = 111154, upload-time = "2025-06-11T13:17:01.139Z" }, + { url = "https://files.pythonhosted.org/packages/11/e7/cca0663d90bc8e68592a62d7b28148eb9fd976f739bb107e4c93f9ae6d81/geventhttpclient-2.3.4-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:97cd2ab03d303fd57dea4f6d9c2ab23b7193846f1b3bbb4c80b315ebb5fc8527", size = 112532, upload-time = "2025-06-11T13:17:03.729Z" }, + { url = "https://files.pythonhosted.org/packages/02/98/625cee18a3be5f7ca74c612d4032b0c013b911eb73c7e72e06fa56a44ba2/geventhttpclient-2.3.4-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:ec4d1aa08569b7eb075942caeacabefee469a0e283c96c7aac0226d5e7598fe8", size = 117806, upload-time = "2025-06-11T13:17:05.138Z" }, + { url = "https://files.pythonhosted.org/packages/f1/5e/e561a5f8c9d98b7258685355aacb9cca8a3c714190cf92438a6e91da09d5/geventhttpclient-2.3.4-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:93926aacdb0f4289b558f213bc32c03578f3432a18b09e4b6d73a716839d7a74", size = 111392, upload-time = "2025-06-11T13:17:06.053Z" }, + { url = "https://files.pythonhosted.org/packages/d0/37/42d09ad90fd1da960ff68facaa3b79418ccf66297f202ba5361038fc3182/geventhttpclient-2.3.4-cp311-cp311-win32.whl", hash = "sha256:ea87c25e933991366049a42c88e91ad20c2b72e11c7bd38ef68f80486ab63cb2", size = 48332, upload-time = "2025-06-11T13:17:06.965Z" }, + { url = "https://files.pythonhosted.org/packages/4b/0b/55e2a9ed4b1aed7c97e857dc9649a7e804609a105e1ef3cb01da857fbce7/geventhttpclient-2.3.4-cp311-cp311-win_amd64.whl", hash = "sha256:e02e0e9ef2e45475cf33816c8fb2e24595650bcf259e7b15b515a7b49cae1ccf", size = 48969, upload-time = "2025-06-11T13:17:08.239Z" }, + { url = "https://files.pythonhosted.org/packages/4f/72/dcbc6dbf838549b7b0c2c18c1365d2580eb7456939e4b608c3ab213fce78/geventhttpclient-2.3.4-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:9ac30c38d86d888b42bb2ab2738ab9881199609e9fa9a153eb0c66fc9188c6cb", size = 71984, upload-time = "2025-06-11T13:17:09.126Z" }, + { url = "https://files.pythonhosted.org/packages/4c/f9/74aa8c556364ad39b238919c954a0da01a6154ad5e85a1d1ab5f9f5ac186/geventhttpclient-2.3.4-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:4b802000a4fad80fa57e895009671d6e8af56777e3adf0d8aee0807e96188fd9", size = 52631, upload-time = "2025-06-11T13:17:10.061Z" }, + { url = "https://files.pythonhosted.org/packages/11/1a/bc4b70cba8b46be8b2c6ca5b8067c4f086f8c90915eb68086ab40ff6243d/geventhttpclient-2.3.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:461e4d9f4caee481788ec95ac64e0a4a087c1964ddbfae9b6f2dc51715ba706c", size = 51991, upload-time = "2025-06-11T13:17:11.049Z" }, + { url = "https://files.pythonhosted.org/packages/03/3f/5ce6e003b3b24f7caf3207285831afd1a4f857ce98ac45e1fb7a6815bd58/geventhttpclient-2.3.4-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:b7e41687c74e8fbe6a665458bbaea0c5a75342a95e2583738364a73bcbf1671b", size = 114982, upload-time = "2025-08-24T12:16:50.76Z" }, + { url = "https://files.pythonhosted.org/packages/60/16/6f9dad141b7c6dd7ee831fbcd72dd02535c57bc1ec3c3282f07e72c31344/geventhttpclient-2.3.4-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c3ea5da20f4023cf40207ce15f5f4028377ffffdba3adfb60b4c8f34925fce79", size = 115654, upload-time = "2025-08-24T12:16:52.072Z" }, + { url = "https://files.pythonhosted.org/packages/ba/52/9b516a2ff423d8bd64c319e1950a165ceebb552781c5a88c1e94e93e8713/geventhttpclient-2.3.4-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:91f19a8a6899c27867dbdace9500f337d3e891a610708e86078915f1d779bf53", size = 121672, upload-time = "2025-08-24T12:16:53.361Z" }, + { url = "https://files.pythonhosted.org/packages/b0/f5/8d0f1e998f6d933c251b51ef92d11f7eb5211e3cd579018973a2b455f7c5/geventhttpclient-2.3.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:41f2dcc0805551ea9d49f9392c3b9296505a89b9387417b148655d0d8251b36e", size = 119012, upload-time = "2025-06-11T13:17:11.956Z" }, + { url = "https://files.pythonhosted.org/packages/ea/0e/59e4ab506b3c19fc72e88ca344d150a9028a00c400b1099637100bec26fc/geventhttpclient-2.3.4-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:62f3a29bf242ecca6360d497304900683fd8f42cbf1de8d0546c871819251dad", size = 124565, upload-time = "2025-06-11T13:17:12.896Z" }, + { url = "https://files.pythonhosted.org/packages/39/5d/dcbd34dfcda0c016b4970bd583cb260cc5ebfc35b33d0ec9ccdb2293587a/geventhttpclient-2.3.4-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8714a3f2c093aeda3ffdb14c03571d349cb3ed1b8b461d9f321890659f4a5dbf", size = 115573, upload-time = "2025-06-11T13:17:13.937Z" }, + { url = "https://files.pythonhosted.org/packages/03/51/89af99e4805e9ce7f95562dfbd23c0b0391830831e43d58f940ec74489ac/geventhttpclient-2.3.4-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b11f38b74bab75282db66226197024a731250dcbe25542fd4e85ac5313547332", size = 114260, upload-time = "2025-06-11T13:17:14.913Z" }, + { url = "https://files.pythonhosted.org/packages/b3/ec/3a3000bda432953abcc6f51d008166fa7abc1eeddd1f0246933d83854f73/geventhttpclient-2.3.4-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:fccc2023a89dfbce2e1b1409b967011e45d41808df81b7fa0259397db79ba647", size = 111592, upload-time = "2025-06-11T13:17:15.879Z" }, + { url = "https://files.pythonhosted.org/packages/d8/a3/88fd71fe6bbe1315a2d161cbe2cc7810c357d99bced113bea1668ede8bcf/geventhttpclient-2.3.4-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:9d54b8e9a44890159ae36ba4ae44efd8bb79ff519055137a340d357538a68aa3", size = 113216, upload-time = "2025-06-11T13:17:16.883Z" }, + { url = "https://files.pythonhosted.org/packages/52/eb/20435585a6911b26e65f901a827ef13551c053133926f8c28a7cca0fb08e/geventhttpclient-2.3.4-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:407cb68a3c3a2c4f5d503930298f2b26ae68137d520e8846d8e230a9981d9334", size = 118450, upload-time = "2025-06-11T13:17:17.968Z" }, + { url = "https://files.pythonhosted.org/packages/2f/79/82782283d613570373990b676a0966c1062a38ca8f41a0f20843c5808e01/geventhttpclient-2.3.4-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:54fbbcca2dcf06f12a337dd8f98417a09a49aa9d9706aa530fc93acb59b7d83c", size = 112226, upload-time = "2025-06-11T13:17:18.942Z" }, + { url = "https://files.pythonhosted.org/packages/9c/c4/417d12fc2a31ad93172b03309c7f8c3a8bbd0cf25b95eb7835de26b24453/geventhttpclient-2.3.4-cp312-cp312-win32.whl", hash = "sha256:83143b41bde2eb010c7056f142cb764cfbf77f16bf78bda2323a160767455cf5", size = 48365, upload-time = "2025-06-11T13:17:20.096Z" }, + { url = "https://files.pythonhosted.org/packages/cf/f4/7e5ee2f460bbbd09cb5d90ff63a1cf80d60f1c60c29dac20326324242377/geventhttpclient-2.3.4-cp312-cp312-win_amd64.whl", hash = "sha256:46eda9a9137b0ca7886369b40995d2a43a5dff033d0a839a54241015d1845d41", size = 48961, upload-time = "2025-06-11T13:17:21.111Z" }, + { url = "https://files.pythonhosted.org/packages/0b/a7/de506f91a1ec67d3c4a53f2aa7475e7ffb869a17b71b94ba370a027a69ac/geventhttpclient-2.3.4-pp311-pypy311_pp73-macosx_10_15_x86_64.whl", hash = "sha256:707a66cd1e3bf06e2c4f8f21d3b4e6290c9e092456f489c560345a8663cdd93e", size = 50828, upload-time = "2025-06-11T13:17:57.589Z" }, + { url = "https://files.pythonhosted.org/packages/2b/43/86479c278e96cd3e190932b0003d5b8e415660d9e519d59094728ae249da/geventhttpclient-2.3.4-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:0129ce7ef50e67d66ea5de44d89a3998ab778a4db98093d943d6855323646fa5", size = 50086, upload-time = "2025-06-11T13:17:58.567Z" }, + { url = "https://files.pythonhosted.org/packages/e8/f7/d3e04f95de14db3ca4fe126eb0e3ec24356125c5ca1f471a9b28b1d7714d/geventhttpclient-2.3.4-pp311-pypy311_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fac2635f68b3b6752c2a576833d9d18f0af50bdd4bd7dd2d2ca753e3b8add84c", size = 54523, upload-time = "2025-06-11T13:17:59.536Z" }, + { url = "https://files.pythonhosted.org/packages/45/a7/d80c9ec1663f70f4bd976978bf86b3d0d123a220c4ae636c66d02d3accdb/geventhttpclient-2.3.4-pp311-pypy311_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:71206ab89abdd0bd5fee21e04a3995ec1f7d8ae1478ee5868f9e16e85a831653", size = 58866, upload-time = "2025-06-11T13:18:03.719Z" }, + { url = "https://files.pythonhosted.org/packages/55/92/d874ff7e52803cef3850bf8875816a9f32e0a154b079a74e6663534bef30/geventhttpclient-2.3.4-pp311-pypy311_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d8bde667d0ce46065fe57f8ff24b2e94f620a5747378c97314dcfc8fbab35b73", size = 54766, upload-time = "2025-06-11T13:18:04.724Z" }, + { url = "https://files.pythonhosted.org/packages/a8/73/2e03125170485193fcc99ef23b52749543d6c6711706d58713fe315869c4/geventhttpclient-2.3.4-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:5f71c75fc138331cbbe668a08951d36b641d2c26fb3677d7e497afb8419538db", size = 49011, upload-time = "2025-06-11T13:18:05.702Z" }, +] + [[package]] name = "gitdb" version = "4.0.12" @@ -2959,6 +3033,51 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e2/3b/a9a17366af80127bd09decbe2a54d8974b6d8b274b39bf47fbaedeec6307/llvmlite-0.44.0-cp312-cp312-win_amd64.whl", hash = "sha256:eae7e2d4ca8f88f89d315b48c6b741dcb925d6a1042da694aa16ab3dd4cbd3a1", size = 30332380, upload-time = "2025-01-20T11:14:02.442Z" }, ] +[[package]] +name = "locust" +version = "2.40.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "configargparse" }, + { name = "flask" }, + { name = "flask-cors" }, + { name = "flask-login" }, + { name = "gevent" }, + { name = "geventhttpclient" }, + { name = "locust-cloud" }, + { name = "msgpack" }, + { name = "psutil" }, + { name = "pytest" }, + { name = "python-engineio" }, + { name = "python-socketio", extra = ["client"] }, + { name = "pywin32", marker = "sys_platform == 'win32'" }, + { name = "pyzmq" }, + { name = "requests" }, + { name = "setuptools" }, + { name = "typing-extensions", marker = "python_full_version < '3.12'" }, + { name = "werkzeug" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c8/40/31ff56ab6f46c7c77e61bbbd23f87fdf6a4aaf674dc961a3c573320caedc/locust-2.40.4.tar.gz", hash = "sha256:3a3a470459edc4ba1349229bf1aca4c0cb651c4e2e3f85d3bc28fe8118f5a18f", size = 1412529, upload-time = "2025-09-11T09:26:13.713Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/79/7e/db1d969caf45ce711e81cd4f3e7c4554c3925a02383a1dcadb442eae3802/locust-2.40.4-py3-none-any.whl", hash = "sha256:50e647a73c5a4e7a775c6e4311979472fce8b00ed783837a2ce9bb36786f7d1a", size = 1430961, upload-time = "2025-09-11T09:26:11.623Z" }, +] + +[[package]] +name = "locust-cloud" +version = "1.26.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "configargparse" }, + { name = "gevent" }, + { name = "platformdirs" }, + { name = "python-engineio" }, + { name = "python-socketio", extra = ["client"] }, +] +sdist = { url = "https://files.pythonhosted.org/packages/84/ad/10b299b134068a4250a9156e6832a717406abe1dfea2482a07ae7bdca8f3/locust_cloud-1.26.3.tar.gz", hash = "sha256:587acfd4d2dee715fb5f0c3c2d922770babf0b7cff7b2927afbb693a9cd193cc", size = 456042, upload-time = "2025-07-15T19:51:53.791Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/50/6a/276fc50a9d170e7cbb6715735480cb037abb526639bca85491576e6eee4a/locust_cloud-1.26.3-py3-none-any.whl", hash = "sha256:8cb4b8bb9adcd5b99327bc8ed1d98cf67a29d9d29512651e6e94869de6f1faa8", size = 410023, upload-time = "2025-07-15T19:51:52.056Z" }, +] + [[package]] name = "lxml" version = "6.0.0" @@ -3230,6 +3349,34 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/5e/75/bd9b7bb966668920f06b200e84454c8f3566b102183bc55c5473d96cb2b9/msal_extensions-1.3.1-py3-none-any.whl", hash = "sha256:96d3de4d034504e969ac5e85bae8106c8373b5c6568e4c8fa7af2eca9dbe6bca", size = 20583, upload-time = "2025-03-14T23:51:03.016Z" }, ] +[[package]] +name = "msgpack" +version = "1.1.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/45/b1/ea4f68038a18c77c9467400d166d74c4ffa536f34761f7983a104357e614/msgpack-1.1.1.tar.gz", hash = "sha256:77b79ce34a2bdab2594f490c8e80dd62a02d650b91a75159a63ec413b8d104cd", size = 173555, upload-time = "2025-06-13T06:52:51.324Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7f/83/97f24bf9848af23fe2ba04380388216defc49a8af6da0c28cc636d722502/msgpack-1.1.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:71ef05c1726884e44f8b1d1773604ab5d4d17729d8491403a705e649116c9558", size = 82728, upload-time = "2025-06-13T06:51:50.68Z" }, + { url = "https://files.pythonhosted.org/packages/aa/7f/2eaa388267a78401f6e182662b08a588ef4f3de6f0eab1ec09736a7aaa2b/msgpack-1.1.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:36043272c6aede309d29d56851f8841ba907a1a3d04435e43e8a19928e243c1d", size = 79279, upload-time = "2025-06-13T06:51:51.72Z" }, + { url = "https://files.pythonhosted.org/packages/f8/46/31eb60f4452c96161e4dfd26dbca562b4ec68c72e4ad07d9566d7ea35e8a/msgpack-1.1.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a32747b1b39c3ac27d0670122b57e6e57f28eefb725e0b625618d1b59bf9d1e0", size = 423859, upload-time = "2025-06-13T06:51:52.749Z" }, + { url = "https://files.pythonhosted.org/packages/45/16/a20fa8c32825cc7ae8457fab45670c7a8996d7746ce80ce41cc51e3b2bd7/msgpack-1.1.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8a8b10fdb84a43e50d38057b06901ec9da52baac6983d3f709d8507f3889d43f", size = 429975, upload-time = "2025-06-13T06:51:53.97Z" }, + { url = "https://files.pythonhosted.org/packages/86/ea/6c958e07692367feeb1a1594d35e22b62f7f476f3c568b002a5ea09d443d/msgpack-1.1.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ba0c325c3f485dc54ec298d8b024e134acf07c10d494ffa24373bea729acf704", size = 413528, upload-time = "2025-06-13T06:51:55.507Z" }, + { url = "https://files.pythonhosted.org/packages/75/05/ac84063c5dae79722bda9f68b878dc31fc3059adb8633c79f1e82c2cd946/msgpack-1.1.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:88daaf7d146e48ec71212ce21109b66e06a98e5e44dca47d853cbfe171d6c8d2", size = 413338, upload-time = "2025-06-13T06:51:57.023Z" }, + { url = "https://files.pythonhosted.org/packages/69/e8/fe86b082c781d3e1c09ca0f4dacd457ede60a13119b6ce939efe2ea77b76/msgpack-1.1.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:d8b55ea20dc59b181d3f47103f113e6f28a5e1c89fd5b67b9140edb442ab67f2", size = 422658, upload-time = "2025-06-13T06:51:58.419Z" }, + { url = "https://files.pythonhosted.org/packages/3b/2b/bafc9924df52d8f3bb7c00d24e57be477f4d0f967c0a31ef5e2225e035c7/msgpack-1.1.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:4a28e8072ae9779f20427af07f53bbb8b4aa81151054e882aee333b158da8752", size = 427124, upload-time = "2025-06-13T06:51:59.969Z" }, + { url = "https://files.pythonhosted.org/packages/a2/3b/1f717e17e53e0ed0b68fa59e9188f3f610c79d7151f0e52ff3cd8eb6b2dc/msgpack-1.1.1-cp311-cp311-win32.whl", hash = "sha256:7da8831f9a0fdb526621ba09a281fadc58ea12701bc709e7b8cbc362feabc295", size = 65016, upload-time = "2025-06-13T06:52:01.294Z" }, + { url = "https://files.pythonhosted.org/packages/48/45/9d1780768d3b249accecc5a38c725eb1e203d44a191f7b7ff1941f7df60c/msgpack-1.1.1-cp311-cp311-win_amd64.whl", hash = "sha256:5fd1b58e1431008a57247d6e7cc4faa41c3607e8e7d4aaf81f7c29ea013cb458", size = 72267, upload-time = "2025-06-13T06:52:02.568Z" }, + { url = "https://files.pythonhosted.org/packages/e3/26/389b9c593eda2b8551b2e7126ad3a06af6f9b44274eb3a4f054d48ff7e47/msgpack-1.1.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:ae497b11f4c21558d95de9f64fff7053544f4d1a17731c866143ed6bb4591238", size = 82359, upload-time = "2025-06-13T06:52:03.909Z" }, + { url = "https://files.pythonhosted.org/packages/ab/65/7d1de38c8a22cf8b1551469159d4b6cf49be2126adc2482de50976084d78/msgpack-1.1.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:33be9ab121df9b6b461ff91baac6f2731f83d9b27ed948c5b9d1978ae28bf157", size = 79172, upload-time = "2025-06-13T06:52:05.246Z" }, + { url = "https://files.pythonhosted.org/packages/0f/bd/cacf208b64d9577a62c74b677e1ada005caa9b69a05a599889d6fc2ab20a/msgpack-1.1.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6f64ae8fe7ffba251fecb8408540c34ee9df1c26674c50c4544d72dbf792e5ce", size = 425013, upload-time = "2025-06-13T06:52:06.341Z" }, + { url = "https://files.pythonhosted.org/packages/4d/ec/fd869e2567cc9c01278a736cfd1697941ba0d4b81a43e0aa2e8d71dab208/msgpack-1.1.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a494554874691720ba5891c9b0b39474ba43ffb1aaf32a5dac874effb1619e1a", size = 426905, upload-time = "2025-06-13T06:52:07.501Z" }, + { url = "https://files.pythonhosted.org/packages/55/2a/35860f33229075bce803a5593d046d8b489d7ba2fc85701e714fc1aaf898/msgpack-1.1.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cb643284ab0ed26f6957d969fe0dd8bb17beb567beb8998140b5e38a90974f6c", size = 407336, upload-time = "2025-06-13T06:52:09.047Z" }, + { url = "https://files.pythonhosted.org/packages/8c/16/69ed8f3ada150bf92745fb4921bd621fd2cdf5a42e25eb50bcc57a5328f0/msgpack-1.1.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:d275a9e3c81b1093c060c3837e580c37f47c51eca031f7b5fb76f7b8470f5f9b", size = 409485, upload-time = "2025-06-13T06:52:10.382Z" }, + { url = "https://files.pythonhosted.org/packages/c6/b6/0c398039e4c6d0b2e37c61d7e0e9d13439f91f780686deb8ee64ecf1ae71/msgpack-1.1.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:4fd6b577e4541676e0cc9ddc1709d25014d3ad9a66caa19962c4f5de30fc09ef", size = 412182, upload-time = "2025-06-13T06:52:11.644Z" }, + { url = "https://files.pythonhosted.org/packages/b8/d0/0cf4a6ecb9bc960d624c93effaeaae75cbf00b3bc4a54f35c8507273cda1/msgpack-1.1.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:bb29aaa613c0a1c40d1af111abf025f1732cab333f96f285d6a93b934738a68a", size = 419883, upload-time = "2025-06-13T06:52:12.806Z" }, + { url = "https://files.pythonhosted.org/packages/62/83/9697c211720fa71a2dfb632cad6196a8af3abea56eece220fde4674dc44b/msgpack-1.1.1-cp312-cp312-win32.whl", hash = "sha256:870b9a626280c86cff9c576ec0d9cbcc54a1e5ebda9cd26dab12baf41fee218c", size = 65406, upload-time = "2025-06-13T06:52:14.271Z" }, + { url = "https://files.pythonhosted.org/packages/c0/23/0abb886e80eab08f5e8c485d6f13924028602829f63b8f5fa25a06636628/msgpack-1.1.1-cp312-cp312-win_amd64.whl", hash = "sha256:5692095123007180dca3e788bb4c399cc26626da51629a31d40207cb262e67f4", size = 72558, upload-time = "2025-06-13T06:52:15.252Z" }, +] + [[package]] name = "msrest" version = "0.7.1" @@ -4826,6 +4973,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/6a/3e/b68c118422ec867fa7ab88444e1274aa40681c606d59ac27de5a5588f082/python_dotenv-1.0.1-py3-none-any.whl", hash = "sha256:f7b63ef50f1b690dddf550d03497b66d609393b40b564ed0d674909a68ebf16a", size = 19863, upload-time = "2024-01-23T06:32:58.246Z" }, ] +[[package]] +name = "python-engineio" +version = "4.12.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "simple-websocket" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ba/0b/67295279b66835f9fa7a491650efcd78b20321c127036eef62c11a31e028/python_engineio-4.12.2.tar.gz", hash = "sha256:e7e712ffe1be1f6a05ee5f951e72d434854a32fcfc7f6e4d9d3cae24ec70defa", size = 91677, upload-time = "2025-06-04T19:22:18.789Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0c/fa/df59acedf7bbb937f69174d00f921a7b93aa5a5f5c17d05296c814fff6fc/python_engineio-4.12.2-py3-none-any.whl", hash = "sha256:8218ab66950e179dfec4b4bbb30aecf3f5d86f5e58e6fc1aa7fde2c698b2804f", size = 59536, upload-time = "2025-06-04T19:22:16.916Z" }, +] + [[package]] name = "python-http-client" version = "3.3.7" @@ -4882,6 +5041,25 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d9/4f/00be2196329ebbff56ce564aa94efb0fbc828d00de250b1980de1a34ab49/python_pptx-1.0.2-py3-none-any.whl", hash = "sha256:160838e0b8565a8b1f67947675886e9fea18aa5e795db7ae531606d68e785cba", size = 472788, upload-time = "2024-08-07T17:33:28.192Z" }, ] +[[package]] +name = "python-socketio" +version = "5.13.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "bidict" }, + { name = "python-engineio" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/21/1a/396d50ccf06ee539fa758ce5623b59a9cb27637fc4b2dc07ed08bf495e77/python_socketio-5.13.0.tar.gz", hash = "sha256:ac4e19a0302ae812e23b712ec8b6427ca0521f7c582d6abb096e36e24a263029", size = 121125, upload-time = "2025-04-12T15:46:59.933Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3c/32/b4fb8585d1be0f68bde7e110dffbcf354915f77ad8c778563f0ad9655c02/python_socketio-5.13.0-py3-none-any.whl", hash = "sha256:51f68d6499f2df8524668c24bcec13ba1414117cfb3a90115c559b601ab10caf", size = 77800, upload-time = "2025-04-12T15:46:58.412Z" }, +] + +[package.optional-dependencies] +client = [ + { name = "requests" }, + { name = "websocket-client" }, +] + [[package]] name = "pytz" version = "2025.2" @@ -4939,6 +5117,42 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0c/e8/4f648c598b17c3d06e8753d7d13d57542b30d56e6c2dedf9c331ae56312e/PyYAML-6.0.2-cp312-cp312-win_amd64.whl", hash = "sha256:7e7401d0de89a9a855c839bc697c079a4af81cf878373abd7dc625847d25cbd8", size = 156338, upload-time = "2024-08-06T20:32:41.93Z" }, ] +[[package]] +name = "pyzmq" +version = "27.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cffi", marker = "implementation_name == 'pypy'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/04/0b/3c9baedbdf613ecaa7aa07027780b8867f57b6293b6ee50de316c9f3222b/pyzmq-27.1.0.tar.gz", hash = "sha256:ac0765e3d44455adb6ddbf4417dcce460fc40a05978c08efdf2948072f6db540", size = 281750, upload-time = "2025-09-08T23:10:18.157Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/06/5d/305323ba86b284e6fcb0d842d6adaa2999035f70f8c38a9b6d21ad28c3d4/pyzmq-27.1.0-cp311-cp311-macosx_10_15_universal2.whl", hash = "sha256:226b091818d461a3bef763805e75685e478ac17e9008f49fce2d3e52b3d58b86", size = 1333328, upload-time = "2025-09-08T23:07:45.946Z" }, + { url = "https://files.pythonhosted.org/packages/bd/a0/fc7e78a23748ad5443ac3275943457e8452da67fda347e05260261108cbc/pyzmq-27.1.0-cp311-cp311-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:0790a0161c281ca9723f804871b4027f2e8b5a528d357c8952d08cd1a9c15581", size = 908803, upload-time = "2025-09-08T23:07:47.551Z" }, + { url = "https://files.pythonhosted.org/packages/7e/22/37d15eb05f3bdfa4abea6f6d96eb3bb58585fbd3e4e0ded4e743bc650c97/pyzmq-27.1.0-cp311-cp311-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c895a6f35476b0c3a54e3eb6ccf41bf3018de937016e6e18748317f25d4e925f", size = 668836, upload-time = "2025-09-08T23:07:49.436Z" }, + { url = "https://files.pythonhosted.org/packages/b1/c4/2a6fe5111a01005fc7af3878259ce17684fabb8852815eda6225620f3c59/pyzmq-27.1.0-cp311-cp311-manylinux_2_26_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:5bbf8d3630bf96550b3be8e1fc0fea5cbdc8d5466c1192887bd94869da17a63e", size = 857038, upload-time = "2025-09-08T23:07:51.234Z" }, + { url = "https://files.pythonhosted.org/packages/cb/eb/bfdcb41d0db9cd233d6fb22dc131583774135505ada800ebf14dfb0a7c40/pyzmq-27.1.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:15c8bd0fe0dabf808e2d7a681398c4e5ded70a551ab47482067a572c054c8e2e", size = 1657531, upload-time = "2025-09-08T23:07:52.795Z" }, + { url = "https://files.pythonhosted.org/packages/ab/21/e3180ca269ed4a0de5c34417dfe71a8ae80421198be83ee619a8a485b0c7/pyzmq-27.1.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:bafcb3dd171b4ae9f19ee6380dfc71ce0390fefaf26b504c0e5f628d7c8c54f2", size = 2034786, upload-time = "2025-09-08T23:07:55.047Z" }, + { url = "https://files.pythonhosted.org/packages/3b/b1/5e21d0b517434b7f33588ff76c177c5a167858cc38ef740608898cd329f2/pyzmq-27.1.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:e829529fcaa09937189178115c49c504e69289abd39967cd8a4c215761373394", size = 1894220, upload-time = "2025-09-08T23:07:57.172Z" }, + { url = "https://files.pythonhosted.org/packages/03/f2/44913a6ff6941905efc24a1acf3d3cb6146b636c546c7406c38c49c403d4/pyzmq-27.1.0-cp311-cp311-win32.whl", hash = "sha256:6df079c47d5902af6db298ec92151db82ecb557af663098b92f2508c398bb54f", size = 567155, upload-time = "2025-09-08T23:07:59.05Z" }, + { url = "https://files.pythonhosted.org/packages/23/6d/d8d92a0eb270a925c9b4dd039c0b4dc10abc2fcbc48331788824ef113935/pyzmq-27.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:190cbf120fbc0fc4957b56866830def56628934a9d112aec0e2507aa6a032b97", size = 633428, upload-time = "2025-09-08T23:08:00.663Z" }, + { url = "https://files.pythonhosted.org/packages/ae/14/01afebc96c5abbbd713ecfc7469cfb1bc801c819a74ed5c9fad9a48801cb/pyzmq-27.1.0-cp311-cp311-win_arm64.whl", hash = "sha256:eca6b47df11a132d1745eb3b5b5e557a7dae2c303277aa0e69c6ba91b8736e07", size = 559497, upload-time = "2025-09-08T23:08:02.15Z" }, + { url = "https://files.pythonhosted.org/packages/92/e7/038aab64a946d535901103da16b953c8c9cc9c961dadcbf3609ed6428d23/pyzmq-27.1.0-cp312-abi3-macosx_10_15_universal2.whl", hash = "sha256:452631b640340c928fa343801b0d07eb0c3789a5ffa843f6e1a9cee0ba4eb4fc", size = 1306279, upload-time = "2025-09-08T23:08:03.807Z" }, + { url = "https://files.pythonhosted.org/packages/e8/5e/c3c49fdd0f535ef45eefcc16934648e9e59dace4a37ee88fc53f6cd8e641/pyzmq-27.1.0-cp312-abi3-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:1c179799b118e554b66da67d88ed66cd37a169f1f23b5d9f0a231b4e8d44a113", size = 895645, upload-time = "2025-09-08T23:08:05.301Z" }, + { url = "https://files.pythonhosted.org/packages/f8/e5/b0b2504cb4e903a74dcf1ebae157f9e20ebb6ea76095f6cfffea28c42ecd/pyzmq-27.1.0-cp312-abi3-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3837439b7f99e60312f0c926a6ad437b067356dc2bc2ec96eb395fd0fe804233", size = 652574, upload-time = "2025-09-08T23:08:06.828Z" }, + { url = "https://files.pythonhosted.org/packages/f8/9b/c108cdb55560eaf253f0cbdb61b29971e9fb34d9c3499b0e96e4e60ed8a5/pyzmq-27.1.0-cp312-abi3-manylinux_2_26_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:43ad9a73e3da1fab5b0e7e13402f0b2fb934ae1c876c51d0afff0e7c052eca31", size = 840995, upload-time = "2025-09-08T23:08:08.396Z" }, + { url = "https://files.pythonhosted.org/packages/c2/bb/b79798ca177b9eb0825b4c9998c6af8cd2a7f15a6a1a4272c1d1a21d382f/pyzmq-27.1.0-cp312-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:0de3028d69d4cdc475bfe47a6128eb38d8bc0e8f4d69646adfbcd840facbac28", size = 1642070, upload-time = "2025-09-08T23:08:09.989Z" }, + { url = "https://files.pythonhosted.org/packages/9c/80/2df2e7977c4ede24c79ae39dcef3899bfc5f34d1ca7a5b24f182c9b7a9ca/pyzmq-27.1.0-cp312-abi3-musllinux_1_2_i686.whl", hash = "sha256:cf44a7763aea9298c0aa7dbf859f87ed7012de8bda0f3977b6fb1d96745df856", size = 2021121, upload-time = "2025-09-08T23:08:11.907Z" }, + { url = "https://files.pythonhosted.org/packages/46/bd/2d45ad24f5f5ae7e8d01525eb76786fa7557136555cac7d929880519e33a/pyzmq-27.1.0-cp312-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:f30f395a9e6fbca195400ce833c731e7b64c3919aa481af4d88c3759e0cb7496", size = 1878550, upload-time = "2025-09-08T23:08:13.513Z" }, + { url = "https://files.pythonhosted.org/packages/e6/2f/104c0a3c778d7c2ab8190e9db4f62f0b6957b53c9d87db77c284b69f33ea/pyzmq-27.1.0-cp312-abi3-win32.whl", hash = "sha256:250e5436a4ba13885494412b3da5d518cd0d3a278a1ae640e113c073a5f88edd", size = 559184, upload-time = "2025-09-08T23:08:15.163Z" }, + { url = "https://files.pythonhosted.org/packages/fc/7f/a21b20d577e4100c6a41795842028235998a643b1ad406a6d4163ea8f53e/pyzmq-27.1.0-cp312-abi3-win_amd64.whl", hash = "sha256:9ce490cf1d2ca2ad84733aa1d69ce6855372cb5ce9223802450c9b2a7cba0ccf", size = 619480, upload-time = "2025-09-08T23:08:17.192Z" }, + { url = "https://files.pythonhosted.org/packages/78/c2/c012beae5f76b72f007a9e91ee9401cb88c51d0f83c6257a03e785c81cc2/pyzmq-27.1.0-cp312-abi3-win_arm64.whl", hash = "sha256:75a2f36223f0d535a0c919e23615fc85a1e23b71f40c7eb43d7b1dedb4d8f15f", size = 552993, upload-time = "2025-09-08T23:08:18.926Z" }, + { url = "https://files.pythonhosted.org/packages/4c/c6/c4dcdecdbaa70969ee1fdced6d7b8f60cfabe64d25361f27ac4665a70620/pyzmq-27.1.0-pp311-pypy311_pp73-macosx_10_15_x86_64.whl", hash = "sha256:18770c8d3563715387139060d37859c02ce40718d1faf299abddcdcc6a649066", size = 836265, upload-time = "2025-09-08T23:09:49.376Z" }, + { url = "https://files.pythonhosted.org/packages/3e/79/f38c92eeaeb03a2ccc2ba9866f0439593bb08c5e3b714ac1d553e5c96e25/pyzmq-27.1.0-pp311-pypy311_pp73-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:ac25465d42f92e990f8d8b0546b01c391ad431c3bf447683fdc40565941d0604", size = 800208, upload-time = "2025-09-08T23:09:51.073Z" }, + { url = "https://files.pythonhosted.org/packages/49/0e/3f0d0d335c6b3abb9b7b723776d0b21fa7f3a6c819a0db6097059aada160/pyzmq-27.1.0-pp311-pypy311_pp73-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:53b40f8ae006f2734ee7608d59ed661419f087521edbfc2149c3932e9c14808c", size = 567747, upload-time = "2025-09-08T23:09:52.698Z" }, + { url = "https://files.pythonhosted.org/packages/a1/cf/f2b3784d536250ffd4be70e049f3b60981235d70c6e8ce7e3ef21e1adb25/pyzmq-27.1.0-pp311-pypy311_pp73-manylinux_2_26_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f605d884e7c8be8fe1aa94e0a783bf3f591b84c24e4bc4f3e7564c82ac25e271", size = 747371, upload-time = "2025-09-08T23:09:54.563Z" }, + { url = "https://files.pythonhosted.org/packages/01/1b/5dbe84eefc86f48473947e2f41711aded97eecef1231f4558f1f02713c12/pyzmq-27.1.0-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:c9f7f6e13dff2e44a6afeaf2cf54cee5929ad64afaf4d40b50f93c58fc687355", size = 544862, upload-time = "2025-09-08T23:09:56.509Z" }, +] + [[package]] name = "qdrant-client" version = "1.9.0" @@ -5387,6 +5601,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e0/f9/0595336914c5619e5f28a1fb793285925a8cd4b432c9da0a987836c7f822/shellingham-1.5.4-py2.py3-none-any.whl", hash = "sha256:7ecfff8f2fd72616f7481040475a65b2bf8af90a56c89140852d1120324e8686", size = 9755, upload-time = "2023-10-24T04:13:38.866Z" }, ] +[[package]] +name = "simple-websocket" +version = "1.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "wsproto" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b0/d4/bfa032f961103eba93de583b161f0e6a5b63cebb8f2c7d0c6e6efe1e3d2e/simple_websocket-1.1.0.tar.gz", hash = "sha256:7939234e7aa067c534abdab3a9ed933ec9ce4691b0713c78acb195560aa52ae4", size = 17300, upload-time = "2024-10-10T22:39:31.412Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/52/59/0782e51887ac6b07ffd1570e0364cf901ebc36345fea669969d2084baebb/simple_websocket-1.1.0-py3-none-any.whl", hash = "sha256:4af6069630a38ed6c561010f0e11a5bc0d4ca569b36306eb257cd9a192497c8c", size = 13842, upload-time = "2024-10-10T22:39:29.645Z" }, +] + [[package]] name = "six" version = "1.17.0" @@ -6794,6 +7020,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2d/82/f56956041adef78f849db6b289b282e72b55ab8045a75abad81898c28d19/wrapt-1.17.2-py3-none-any.whl", hash = "sha256:b18f2d1533a71f069c7f82d524a52599053d4c7166e9dd374ae2136b7f40f7c8", size = 23594, upload-time = "2025-01-14T10:35:44.018Z" }, ] +[[package]] +name = "wsproto" +version = "1.2.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "h11" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c9/4a/44d3c295350d776427904d73c189e10aeae66d7f555bb2feee16d1e4ba5a/wsproto-1.2.0.tar.gz", hash = "sha256:ad565f26ecb92588a3e43bc3d96164de84cd9902482b130d0ddbaa9664a85065", size = 53425, upload-time = "2022-08-23T19:58:21.447Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/78/58/e860788190eba3bcce367f74d29c4675466ce8dddfba85f7827588416f01/wsproto-1.2.0-py3-none-any.whl", hash = "sha256:b9acddd652b585d75b20477888c56642fdade28bdfd3579aa24a4d2c037dd736", size = 24226, upload-time = "2022-08-23T19:58:19.96Z" }, +] + [[package]] name = "xinference-client" version = "1.2.2" diff --git a/docker/.env.example b/docker/.env.example index 9a0a5a9622..92347a6e76 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -843,6 +843,7 @@ INVITE_EXPIRY_HOURS=72 # Reset password token valid time (minutes), RESET_PASSWORD_TOKEN_EXPIRY_MINUTES=5 +EMAIL_REGISTER_TOKEN_EXPIRY_MINUTES=5 CHANGE_EMAIL_TOKEN_EXPIRY_MINUTES=5 OWNER_TRANSFER_TOKEN_EXPIRY_MINUTES=5 diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 3f19dc7f63..193157b54f 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -372,6 +372,7 @@ x-shared-env: &shared-api-worker-env INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH: ${INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH:-4000} INVITE_EXPIRY_HOURS: ${INVITE_EXPIRY_HOURS:-72} RESET_PASSWORD_TOKEN_EXPIRY_MINUTES: ${RESET_PASSWORD_TOKEN_EXPIRY_MINUTES:-5} + EMAIL_REGISTER_TOKEN_EXPIRY_MINUTES: ${EMAIL_REGISTER_TOKEN_EXPIRY_MINUTES:-5} CHANGE_EMAIL_TOKEN_EXPIRY_MINUTES: ${CHANGE_EMAIL_TOKEN_EXPIRY_MINUTES:-5} OWNER_TRANSFER_TOKEN_EXPIRY_MINUTES: ${OWNER_TRANSFER_TOKEN_EXPIRY_MINUTES:-5} CODE_EXECUTION_ENDPOINT: ${CODE_EXECUTION_ENDPOINT:-http://sandbox:8194} diff --git a/scripts/stress-test/README.md b/scripts/stress-test/README.md new file mode 100644 index 0000000000..15f21cd532 --- /dev/null +++ b/scripts/stress-test/README.md @@ -0,0 +1,521 @@ +# Dify Stress Test Suite + +A high-performance stress test suite for Dify workflow execution using **Locust** - optimized for measuring Server-Sent Events (SSE) streaming performance. + +## Key Metrics Tracked + +The stress test focuses on four critical SSE performance indicators: + +1. **Active SSE Connections** - Real-time count of open SSE connections +1. **New Connection Rate** - Connections per second (conn/sec) +1. **Time to First Event (TTFE)** - Latency until first SSE event arrives +1. **Event Throughput** - Events per second (events/sec) + +## Features + +- **True SSE Support**: Properly handles Server-Sent Events streaming without premature connection closure +- **Real-time Metrics**: Live reporting every 5 seconds during tests +- **Comprehensive Tracking**: + - Active connection monitoring + - Connection establishment rate + - Event processing throughput + - TTFE distribution analysis +- **Multiple Interfaces**: + - Web UI for real-time monitoring () + - Headless mode with periodic console updates +- **Detailed Reports**: Final statistics with overall rates and averages +- **Easy Configuration**: Uses existing API key configuration from setup + +## What Gets Measured + +The stress test focuses on SSE streaming performance with these key metrics: + +### Primary Endpoint: `/v1/workflows/run` + +The stress test tests a single endpoint with comprehensive SSE metrics tracking: + +- **Request Type**: POST request to workflow execution API +- **Response Type**: Server-Sent Events (SSE) stream +- **Payload**: Random questions from a configurable pool +- **Concurrency**: Configurable from 1 to 1000+ simultaneous users + +### Key Performance Metrics + +#### 1. **Active Connections** + +- **What it measures**: Number of concurrent SSE connections open at any moment +- **Why it matters**: Shows system's ability to handle parallel streams +- **Good values**: Should remain stable under load without drops + +#### 2. **Connection Rate (conn/sec)** + +- **What it measures**: How fast new SSE connections are established +- **Why it matters**: Indicates system's ability to handle connection spikes +- **Good values**: + - Light load: 5-10 conn/sec + - Medium load: 20-50 conn/sec + - Heavy load: 100+ conn/sec + +#### 3. **Time to First Event (TTFE)** + +- **What it measures**: Latency from request sent to first SSE event received +- **Why it matters**: Critical for user experience - faster TTFE = better perceived performance +- **Good values**: + - Excellent: < 50ms + - Good: 50-100ms + - Acceptable: 100-500ms + - Poor: > 500ms + +#### 4. **Event Throughput (events/sec)** + +- **What it measures**: Rate of SSE events being delivered across all connections +- **Why it matters**: Shows actual data delivery performance +- **Expected values**: Depends on workflow complexity and number of connections + - Single connection: 10-20 events/sec + - 10 connections: 50-100 events/sec + - 100 connections: 200-500 events/sec + +#### 5. **Request/Response Times** + +- **P50 (Median)**: 50% of requests complete within this time +- **P95**: 95% of requests complete within this time +- **P99**: 99% of requests complete within this time +- **Min/Max**: Best and worst case response times + +## Prerequisites + +1. **Dependencies are automatically installed** when running setup: + + - Locust (load testing framework) + - sseclient-py (SSE client library) + +1. **Complete Dify setup**: + + ```bash + # Run the complete setup + python scripts/stress-test/setup_all.py + ``` + +1. **Ensure services are running**: + + **IMPORTANT**: For accurate stress testing, run the API server with Gunicorn in production mode: + + ```bash + # Run from the api directory + cd api + uv run gunicorn \ + --bind 0.0.0.0:5001 \ + --workers 4 \ + --worker-class gevent \ + --timeout 120 \ + --keep-alive 5 \ + --log-level info \ + --access-logfile - \ + --error-logfile - \ + app:app + ``` + + **Configuration options explained**: + + - `--workers 4`: Number of worker processes (adjust based on CPU cores) + - `--worker-class gevent`: Async worker for handling concurrent connections + - `--timeout 120`: Worker timeout for long-running requests + - `--keep-alive 5`: Keep connections alive for SSE streaming + + **NOT RECOMMENDED for stress testing**: + + ```bash + # Debug mode - DO NOT use for stress testing (slow performance) + ./dev/start-api # This runs Flask in debug mode with single-threaded execution + ``` + + **Also start the Mock OpenAI server**: + + ```bash + python scripts/stress-test/setup/mock_openai_server.py + ``` + +## Running the Stress Test + +```bash +# Run with default configuration (headless mode) +./scripts/stress-test/run_locust_stress_test.sh + +# Or run directly with uv +uv run --project api python -m locust -f scripts/stress-test/sse_benchmark.py --host http://localhost:5001 + +# Run with Web UI (access at http://localhost:8089) +uv run --project api python -m locust -f scripts/stress-test/sse_benchmark.py --host http://localhost:5001 --web-port 8089 +``` + +The script will: + +1. Validate that all required services are running +1. Check API token availability +1. Execute the Locust stress test with SSE support +1. Generate comprehensive reports in the `reports/` directory + +## Configuration + +The stress test configuration is in `locust.conf`: + +```ini +users = 10 # Number of concurrent users +spawn-rate = 2 # Users spawned per second +run-time = 1m # Test duration (30s, 5m, 1h) +headless = true # Run without web UI +``` + +### Custom Question Sets + +Modify the questions list in `sse_benchmark.py`: + +```python +self.questions = [ + "Your custom question 1", + "Your custom question 2", + # Add more questions... +] +``` + +## Understanding the Results + +### Report Structure + +After running the stress test, you'll find these files in the `reports/` directory: + +- `locust_summary_YYYYMMDD_HHMMSS.txt` - Complete console output with metrics +- `locust_report_YYYYMMDD_HHMMSS.html` - Interactive HTML report with charts +- `locust_YYYYMMDD_HHMMSS_stats.csv` - CSV with detailed statistics +- `locust_YYYYMMDD_HHMMSS_stats_history.csv` - Time-series data + +### Key Metrics + +**Requests Per Second (RPS)**: + +- **Excellent**: > 50 RPS +- **Good**: 20-50 RPS +- **Acceptable**: 10-20 RPS +- **Needs Improvement**: < 10 RPS + +**Response Time Percentiles**: + +- **P50 (Median)**: 50% of requests complete within this time +- **P95**: 95% of requests complete within this time +- **P99**: 99% of requests complete within this time + +**Success Rate**: + +- Should be > 99% for production readiness +- Lower rates indicate errors or timeouts + +### Example Output + +```text +============================================================ +DIFY SSE STRESS TEST +============================================================ + +[2025-09-12 15:45:44,468] Starting test run with 10 users at 2 users/sec + +============================================================ +SSE Metrics | Active: 8 | Total Conn: 142 | Events: 2841 +Rates: 2.4 conn/s | 47.3 events/s | TTFE: 43ms +============================================================ + +Type Name # reqs # fails | Avg Min Max Med | req/s failures/s +---------|------------------------------|--------|--------|--------|--------|--------|--------|--------|----------- +POST /v1/workflows/run 142 0(0.00%) | 41 18 192 38 | 2.37 0.00 +---------|------------------------------|--------|--------|--------|--------|--------|--------|--------|----------- + Aggregated 142 0(0.00%) | 41 18 192 38 | 2.37 0.00 + +============================================================ +FINAL RESULTS +============================================================ +Total Connections: 142 +Total Events: 2841 +Average TTFE: 43 ms +============================================================ +``` + +### How to Read the Results + +**Live SSE Metrics Box (Updates every 10 seconds):** + +```text +SSE Metrics | Active: 8 | Total Conn: 142 | Events: 2841 +Rates: 2.4 conn/s | 47.3 events/s | TTFE: 43ms +``` + +- **Active**: Current number of open SSE connections +- **Total Conn**: Cumulative connections established +- **Events**: Total SSE events received +- **conn/s**: Connection establishment rate +- **events/s**: Event delivery rate +- **TTFE**: Average time to first event + +**Standard Locust Table:** + +```text +Type Name # reqs # fails | Avg Min Max Med | req/s +POST /v1/workflows/run 142 0(0.00%) | 41 18 192 38 | 2.37 +``` + +- **Type**: Always POST for our SSE requests +- **Name**: The API endpoint being tested +- **# reqs**: Total requests made +- **# fails**: Failed requests (should be 0) +- **Avg/Min/Max/Med**: Response time percentiles (ms) +- **req/s**: Request throughput + +**Performance Targets:** + +✅ **Good Performance**: + +- Zero failures (0.00%) +- TTFE < 100ms +- Stable active connections +- Consistent event throughput + +⚠️ **Warning Signs**: + +- Failures > 1% +- TTFE > 500ms +- Dropping active connections +- Declining event rate over time + +## Test Scenarios + +### Light Load + +```yaml +concurrency: 10 +iterations: 100 +``` + +### Normal Load + +```yaml +concurrency: 100 +iterations: 1000 +``` + +### Heavy Load + +```yaml +concurrency: 500 +iterations: 5000 +``` + +### Stress Test + +```yaml +concurrency: 1000 +iterations: 10000 +``` + +## Performance Tuning + +### API Server Optimization + +**Gunicorn Tuning for Different Load Levels**: + +```bash +# Light load (10-50 concurrent users) +uv run gunicorn --bind 0.0.0.0:5001 --workers 2 --worker-class gevent app:app + +# Medium load (50-200 concurrent users) +uv run gunicorn --bind 0.0.0.0:5001 --workers 4 --worker-class gevent --worker-connections 1000 app:app + +# Heavy load (200-1000 concurrent users) +uv run gunicorn --bind 0.0.0.0:5001 --workers 8 --worker-class gevent --worker-connections 2000 --max-requests 1000 app:app +``` + +**Worker calculation formula**: + +- Workers = (2 × CPU cores) + 1 +- For SSE/WebSocket: Use gevent worker class +- For CPU-bound tasks: Use sync workers + +### Database Optimization + +**PostgreSQL Connection Pool Tuning**: + +For high-concurrency stress testing, increase the PostgreSQL max connections in `docker/middleware.env`: + +```bash +# Edit docker/middleware.env +POSTGRES_MAX_CONNECTIONS=200 # Default is 100 + +# Recommended values for different load levels: +# Light load (10-50 users): 100 (default) +# Medium load (50-200 users): 200 +# Heavy load (200-1000 users): 500 +``` + +After changing, restart the PostgreSQL container: + +```bash +docker compose -f docker/docker-compose.middleware.yaml down db +docker compose -f docker/docker-compose.middleware.yaml up -d db +``` + +**Note**: Each connection uses ~10MB of RAM. Ensure your database server has sufficient memory: + +- 100 connections: ~1GB RAM +- 200 connections: ~2GB RAM +- 500 connections: ~5GB RAM + +### System Optimizations + +1. **Increase file descriptor limits**: + + ```bash + ulimit -n 65536 + ``` + +1. **TCP tuning for high concurrency** (Linux): + + ```bash + # Increase TCP buffer sizes + sudo sysctl -w net.core.rmem_max=134217728 + sudo sysctl -w net.core.wmem_max=134217728 + + # Enable TCP fast open + sudo sysctl -w net.ipv4.tcp_fastopen=3 + ``` + +1. **macOS specific**: + + ```bash + # Increase maximum connections + sudo sysctl -w kern.ipc.somaxconn=2048 + ``` + +## Troubleshooting + +### Common Issues + +1. **"ModuleNotFoundError: No module named 'locust'"**: + + ```bash + # Dependencies are installed automatically, but if needed: + uv --project api add --dev locust sseclient-py + ``` + +1. **"API key configuration not found"**: + + ```bash + # Run setup + python scripts/stress-test/setup_all.py + ``` + +1. **Services not running**: + + ```bash + # Start Dify API with Gunicorn (production mode) + cd api + uv run gunicorn --bind 0.0.0.0:5001 --workers 4 --worker-class gevent app:app + + # Start Mock OpenAI server + python scripts/stress-test/setup/mock_openai_server.py + ``` + +1. **High error rate**: + + - Reduce concurrency level + - Check system resources (CPU, memory) + - Review API server logs for errors + - Increase timeout values if needed + +1. **Permission denied running script**: + + ```bash + chmod +x run_benchmark.sh + ``` + +## Advanced Usage + +### Running Multiple Iterations + +```bash +# Run stress test 3 times with 60-second intervals +for i in {1..3}; do + echo "Run $i of 3" + ./run_locust_stress_test.sh + sleep 60 +done +``` + +### Custom Locust Options + +Run Locust directly with custom options: + +```bash +# With specific user count and spawn rate +uv run --project api python -m locust -f scripts/stress-test/sse_benchmark.py \ + --host http://localhost:5001 --users 50 --spawn-rate 5 + +# Generate CSV reports +uv run --project api python -m locust -f scripts/stress-test/sse_benchmark.py \ + --host http://localhost:5001 --csv reports/results + +# Run for specific duration +uv run --project api python -m locust -f scripts/stress-test/sse_benchmark.py \ + --host http://localhost:5001 --run-time 5m --headless +``` + +### Comparing Results + +```bash +# Compare multiple stress test runs +ls -la reports/stress_test_*.txt | tail -5 +``` + +## Interpreting Performance Issues + +### High Response Times + +Possible causes: + +- Database query performance +- External API latency +- Insufficient server resources +- Network congestion + +### Low Throughput (RPS < 10) + +Check for: + +- CPU bottlenecks +- Memory constraints +- Database connection pooling +- API rate limiting + +### High Error Rate + +Investigate: + +- Server error logs +- Resource exhaustion +- Timeout configurations +- Connection limits + +## Why Locust? + +Locust was chosen over Drill for this stress test because: + +1. **Proper SSE Support**: Correctly handles streaming responses without premature closure +1. **Custom Metrics**: Can track SSE-specific metrics like TTFE and stream duration +1. **Web UI**: Real-time monitoring and control via web interface +1. **Python Integration**: Seamlessly integrates with existing Python setup code +1. **Extensibility**: Easy to customize for specific testing scenarios + +## Contributing + +To improve the stress test suite: + +1. Edit `stress_test.yml` for configuration changes +1. Modify `run_locust_stress_test.sh` for workflow improvements +1. Update question sets for better coverage +1. Add new metrics or analysis features diff --git a/scripts/stress-test/cleanup.py b/scripts/stress-test/cleanup.py new file mode 100755 index 0000000000..5fb1b96ddc --- /dev/null +++ b/scripts/stress-test/cleanup.py @@ -0,0 +1,90 @@ +#!/usr/bin/env python3 + +import shutil +import sys +from pathlib import Path + +from common import Logger + + +def cleanup() -> None: + """Clean up all configuration files and reports created during setup and stress testing.""" + + log = Logger("Cleanup") + log.header("Stress Test Cleanup") + + config_dir = Path(__file__).parent / "setup" / "config" + reports_dir = Path(__file__).parent / "reports" + + dirs_to_clean = [] + if config_dir.exists(): + dirs_to_clean.append(config_dir) + if reports_dir.exists(): + dirs_to_clean.append(reports_dir) + + if not dirs_to_clean: + log.success("No directories to clean. Everything is already clean.") + return + + log.info("Cleaning up stress test data...") + log.info("This will remove:") + for dir_path in dirs_to_clean: + log.list_item(str(dir_path)) + + # List files that will be deleted + log.separator() + if config_dir.exists(): + config_files = list(config_dir.glob("*.json")) + if config_files: + log.info("Config files to be removed:") + for file in config_files: + log.list_item(file.name) + + if reports_dir.exists(): + report_files = list(reports_dir.glob("*")) + if report_files: + log.info("Report files to be removed:") + for file in report_files: + log.list_item(file.name) + + # Ask for confirmation if running interactively + if sys.stdin.isatty(): + log.separator() + log.warning("This action cannot be undone!") + confirmation = input( + "Are you sure you want to remove all config and report files? (yes/no): " + ) + + if confirmation.lower() not in ["yes", "y"]: + log.error("Cleanup cancelled.") + return + + try: + # Remove directories and all their contents + for dir_path in dirs_to_clean: + shutil.rmtree(dir_path) + log.success(f"{dir_path.name} directory removed successfully!") + + log.separator() + log.info("To run the setup again, execute:") + log.list_item("python setup_all.py") + log.info("Or run scripts individually in this order:") + log.list_item("python setup/mock_openai_server.py (in a separate terminal)") + log.list_item("python setup/setup_admin.py") + log.list_item("python setup/login_admin.py") + log.list_item("python setup/install_openai_plugin.py") + log.list_item("python setup/configure_openai_plugin.py") + log.list_item("python setup/import_workflow_app.py") + log.list_item("python setup/create_api_key.py") + log.list_item("python setup/publish_workflow.py") + log.list_item("python setup/run_workflow.py") + + except PermissionError as e: + log.error(f"Permission denied: {e}") + log.info("Try running with appropriate permissions.") + except Exception as e: + log.error(f"An error occurred during cleanup: {e}") + + +if __name__ == "__main__": + cleanup() diff --git a/scripts/stress-test/common/__init__.py b/scripts/stress-test/common/__init__.py new file mode 100644 index 0000000000..920c31482e --- /dev/null +++ b/scripts/stress-test/common/__init__.py @@ -0,0 +1,6 @@ +"""Common utilities for Dify benchmark suite.""" + +from .config_helper import config_helper +from .logger_helper import Logger, ProgressLogger + +__all__ = ["config_helper", "Logger", "ProgressLogger"] diff --git a/scripts/stress-test/common/config_helper.py b/scripts/stress-test/common/config_helper.py new file mode 100644 index 0000000000..9a3581fb54 --- /dev/null +++ b/scripts/stress-test/common/config_helper.py @@ -0,0 +1,240 @@ +#!/usr/bin/env python3 + +import json +from pathlib import Path +from typing import Any + + +class ConfigHelper: + """Helper class for reading and writing configuration files.""" + + def __init__(self, base_dir: Path | None = None): + """Initialize ConfigHelper with base directory. + + Args: + base_dir: Base directory for config files. If None, uses setup/config + """ + if base_dir is None: + # Default to config directory in setup folder + base_dir = Path(__file__).parent.parent / "setup" / "config" + self.base_dir = base_dir + self.state_file = "stress_test_state.json" + + def ensure_config_dir(self) -> None: + """Ensure the config directory exists.""" + self.base_dir.mkdir(exist_ok=True, parents=True) + + def get_config_path(self, filename: str) -> Path: + """Get the full path for a config file. + + Args: + filename: Name of the config file (e.g., 'admin_config.json') + + Returns: + Full path to the config file + """ + if not filename.endswith(".json"): + filename += ".json" + return self.base_dir / filename + + def read_config(self, filename: str) -> dict[str, Any] | None: + """Read a configuration file. + + DEPRECATED: Use read_state() or get_state_section() for new code. + This method provides backward compatibility. + + Args: + filename: Name of the config file to read + + Returns: + Dictionary containing config data, or None if file doesn't exist + """ + # Provide backward compatibility for old config names + if filename in ["admin_config", "token_config", "app_config", "api_key_config"]: + section_map = { + "admin_config": "admin", + "token_config": "auth", + "app_config": "app", + "api_key_config": "api_key", + } + return self.get_state_section(section_map[filename]) + + config_path = self.get_config_path(filename) + + if not config_path.exists(): + return None + + try: + with open(config_path, "r") as f: + return json.load(f) + except (json.JSONDecodeError, IOError) as e: + print(f"❌ Error reading {filename}: {e}") + return None + + def write_config(self, filename: str, data: dict[str, Any]) -> bool: + """Write data to a configuration file. + + DEPRECATED: Use write_state() or update_state_section() for new code. + This method provides backward compatibility. + + Args: + filename: Name of the config file to write + data: Dictionary containing data to save + + Returns: + True if successful, False otherwise + """ + # Provide backward compatibility for old config names + if filename in ["admin_config", "token_config", "app_config", "api_key_config"]: + section_map = { + "admin_config": "admin", + "token_config": "auth", + "app_config": "app", + "api_key_config": "api_key", + } + return self.update_state_section(section_map[filename], data) + + self.ensure_config_dir() + config_path = self.get_config_path(filename) + + try: + with open(config_path, "w") as f: + json.dump(data, f, indent=2) + return True + except IOError as e: + print(f"❌ Error writing {filename}: {e}") + return False + + def config_exists(self, filename: str) -> bool: + """Check if a config file exists. + + Args: + filename: Name of the config file to check + + Returns: + True if file exists, False otherwise + """ + return self.get_config_path(filename).exists() + + def delete_config(self, filename: str) -> bool: + """Delete a configuration file. + + Args: + filename: Name of the config file to delete + + Returns: + True if successful, False otherwise + """ + config_path = self.get_config_path(filename) + + if not config_path.exists(): + return True # Already doesn't exist + + try: + config_path.unlink() + return True + except IOError as e: + print(f"❌ Error deleting {filename}: {e}") + return False + + def read_state(self) -> dict[str, Any] | None: + """Read the entire stress test state. + + Returns: + Dictionary containing all state data, or None if file doesn't exist + """ + state_path = self.get_config_path(self.state_file) + if not state_path.exists(): + return None + + try: + with open(state_path, "r") as f: + return json.load(f) + except (json.JSONDecodeError, IOError) as e: + print(f"❌ Error reading {self.state_file}: {e}") + return None + + def write_state(self, data: dict[str, Any]) -> bool: + """Write the entire stress test state. + + Args: + data: Dictionary containing all state data to save + + Returns: + True if successful, False otherwise + """ + self.ensure_config_dir() + state_path = self.get_config_path(self.state_file) + + try: + with open(state_path, "w") as f: + json.dump(data, f, indent=2) + return True + except IOError as e: + print(f"❌ Error writing {self.state_file}: {e}") + return False + + def update_state_section(self, section: str, data: dict[str, Any]) -> bool: + """Update a specific section of the stress test state. + + Args: + section: Name of the section to update (e.g., 'admin', 'auth', 'app', 'api_key') + data: Dictionary containing section data to save + + Returns: + True if successful, False otherwise + """ + state = self.read_state() or {} + state[section] = data + return self.write_state(state) + + def get_state_section(self, section: str) -> dict[str, Any] | None: + """Get a specific section from the stress test state. + + Args: + section: Name of the section to get (e.g., 'admin', 'auth', 'app', 'api_key') + + Returns: + Dictionary containing section data, or None if not found + """ + state = self.read_state() + if state: + return state.get(section) + return None + + def get_token(self) -> str | None: + """Get the access token from auth section. + + Returns: + Access token string or None if not found + """ + auth = self.get_state_section("auth") + if auth: + return auth.get("access_token") + return None + + def get_app_id(self) -> str | None: + """Get the app ID from app section. + + Returns: + App ID string or None if not found + """ + app = self.get_state_section("app") + if app: + return app.get("app_id") + return None + + def get_api_key(self) -> str | None: + """Get the API key token from api_key section. + + Returns: + API key token string or None if not found + """ + api_key = self.get_state_section("api_key") + if api_key: + return api_key.get("token") + return None + + +# Create a default instance for convenience +config_helper = ConfigHelper() diff --git a/scripts/stress-test/common/logger_helper.py b/scripts/stress-test/common/logger_helper.py new file mode 100644 index 0000000000..03436eed6c --- /dev/null +++ b/scripts/stress-test/common/logger_helper.py @@ -0,0 +1,220 @@ +#!/usr/bin/env python3 + +import sys +import time +from enum import Enum + + +class LogLevel(Enum): + """Log levels with associated colors and symbols.""" + + DEBUG = ("🔍", "\033[90m") # Gray + INFO = ("ℹ️ ", "\033[94m") # Blue + SUCCESS = ("✅", "\033[92m") # Green + WARNING = ("⚠️ ", "\033[93m") # Yellow + ERROR = ("❌", "\033[91m") # Red + STEP = ("🚀", "\033[96m") # Cyan + PROGRESS = ("📋", "\033[95m") # Magenta + + +class Logger: + """Logger class for formatted console output.""" + + def __init__(self, name: str | None = None, use_colors: bool = True): + """Initialize logger. + + Args: + name: Optional name for the logger (e.g., script name) + use_colors: Whether to use ANSI color codes + """ + self.name = name + self.use_colors = use_colors and sys.stdout.isatty() + self._reset_color = "\033[0m" if self.use_colors else "" + + def _format_message(self, level: LogLevel, message: str, indent: int = 0) -> str: + """Format a log message with level, color, and indentation. + + Args: + level: Log level + message: Message to log + indent: Number of spaces to indent + + Returns: + Formatted message string + """ + symbol, color = level.value + color = color if self.use_colors else "" + reset = self._reset_color + + prefix = " " * indent + + if self.name and level in [LogLevel.STEP, LogLevel.ERROR]: + return f"{prefix}{color}{symbol} [{self.name}] {message}{reset}" + else: + return f"{prefix}{color}{symbol} {message}{reset}" + + def debug(self, message: str, indent: int = 0) -> None: + """Log debug message.""" + print(self._format_message(LogLevel.DEBUG, message, indent)) + + def info(self, message: str, indent: int = 0) -> None: + """Log info message.""" + print(self._format_message(LogLevel.INFO, message, indent)) + + def success(self, message: str, indent: int = 0) -> None: + """Log success message.""" + print(self._format_message(LogLevel.SUCCESS, message, indent)) + + def warning(self, message: str, indent: int = 0) -> None: + """Log warning message.""" + print(self._format_message(LogLevel.WARNING, message, indent)) + + def error(self, message: str, indent: int = 0) -> None: + """Log error message.""" + print(self._format_message(LogLevel.ERROR, message, indent), file=sys.stderr) + + def step(self, message: str, indent: int = 0) -> None: + """Log a step in a process.""" + print(self._format_message(LogLevel.STEP, message, indent)) + + def progress(self, message: str, indent: int = 0) -> None: + """Log progress information.""" + print(self._format_message(LogLevel.PROGRESS, message, indent)) + + def separator(self, char: str = "-", length: int = 60) -> None: + """Print a separator line.""" + print(char * length) + + def header(self, title: str, width: int = 60) -> None: + """Print a formatted header.""" + if self.use_colors: + print(f"\n\033[1m{'=' * width}\033[0m") # Bold + print(f"\033[1m{title.center(width)}\033[0m") + print(f"\033[1m{'=' * width}\033[0m\n") + else: + print(f"\n{'=' * width}") + print(title.center(width)) + print(f"{'=' * width}\n") + + def box(self, title: str, width: int = 60) -> None: + """Print a title in a box.""" + border = "═" * (width - 2) + if self.use_colors: + print(f"\033[1m╔{border}╗\033[0m") + print(f"\033[1m║{title.center(width - 2)}║\033[0m") + print(f"\033[1m╚{border}╝\033[0m") + else: + print(f"╔{border}╗") + print(f"║{title.center(width - 2)}║") + print(f"╚{border}╝") + + def list_item(self, item: str, indent: int = 2) -> None: + """Print a list item.""" + prefix = " " * indent + print(f"{prefix}• {item}") + + def key_value(self, key: str, value: str, indent: int = 2) -> None: + """Print a key-value pair.""" + prefix = " " * indent + if self.use_colors: + print(f"{prefix}\033[1m{key}:\033[0m {value}") + else: + print(f"{prefix}{key}: {value}") + + def spinner_start(self, message: str) -> None: + """Start a spinner (simple implementation).""" + sys.stdout.write(f"\r{message}... ") + sys.stdout.flush() + + def spinner_stop(self, success: bool = True, message: str | None = None) -> None: + """Stop the spinner and show result.""" + if success: + symbol = "✅" if message else "Done" + sys.stdout.write(f"\r{symbol} {message or ''}\n") + else: + symbol = "❌" if message else "Failed" + sys.stdout.write(f"\r{symbol} {message or ''}\n") + sys.stdout.flush() + + +class ProgressLogger: + """Logger for tracking progress through multiple steps.""" + + def __init__(self, total_steps: int, logger: Logger | None = None): + """Initialize progress logger. + + Args: + total_steps: Total number of steps + logger: Logger instance to use (creates new if None) + """ + self.total_steps = total_steps + self.current_step = 0 + self.logger = logger or Logger() + self.start_time = time.time() + + def next_step(self, description: str) -> None: + """Move to next step and log it.""" + self.current_step += 1 + elapsed = time.time() - self.start_time + + if self.logger.use_colors: + progress_bar = self._create_progress_bar() + print( + f"\n\033[1m[Step {self.current_step}/{self.total_steps}]\033[0m {progress_bar}" + ) + self.logger.step(f"{description} (Elapsed: {elapsed:.1f}s)") + else: + print(f"\n[Step {self.current_step}/{self.total_steps}]") + self.logger.step(f"{description} (Elapsed: {elapsed:.1f}s)") + + def _create_progress_bar(self, width: int = 20) -> str: + """Create a simple progress bar.""" + filled = int(width * self.current_step / self.total_steps) + bar = "█" * filled + "░" * (width - filled) + percentage = int(100 * self.current_step / self.total_steps) + return f"[{bar}] {percentage}%" + + def complete(self) -> None: + """Mark progress as complete.""" + elapsed = time.time() - self.start_time + self.logger.success(f"All steps completed! Total time: {elapsed:.1f}s") + + +# Create default logger instance +logger = Logger() + + +# Convenience functions using default logger +def debug(message: str, indent: int = 0) -> None: + """Log debug message using default logger.""" + logger.debug(message, indent) + + +def info(message: str, indent: int = 0) -> None: + """Log info message using default logger.""" + logger.info(message, indent) + + +def success(message: str, indent: int = 0) -> None: + """Log success message using default logger.""" + logger.success(message, indent) + + +def warning(message: str, indent: int = 0) -> None: + """Log warning message using default logger.""" + logger.warning(message, indent) + + +def error(message: str, indent: int = 0) -> None: + """Log error message using default logger.""" + logger.error(message, indent) + + +def step(message: str, indent: int = 0) -> None: + """Log step using default logger.""" + logger.step(message, indent) + + +def progress(message: str, indent: int = 0) -> None: + """Log progress using default logger.""" + logger.progress(message, indent) diff --git a/scripts/stress-test/locust.conf b/scripts/stress-test/locust.conf new file mode 100644 index 0000000000..87bd8c2870 --- /dev/null +++ b/scripts/stress-test/locust.conf @@ -0,0 +1,37 @@ +# Locust configuration file for Dify SSE benchmark + +# Target host +host = http://localhost:5001 + +# Number of users to simulate +users = 10 + +# Spawn rate (users per second) +spawn-rate = 2 + +# Run time (use format like 30s, 5m, 1h) +run-time = 1m + +# Locustfile to use +locustfile = scripts/stress-test/sse_benchmark.py + +# Headless mode (no web UI) +headless = true + +# Print stats in the console +print-stats = true + +# Only print summary stats +only-summary = false + +# Reset statistics after ramp-up +reset-stats = false + +# Log level +loglevel = INFO + +# CSV output (uncomment to enable) +# csv = reports/locust_results + +# HTML report (uncomment to enable) +# html = reports/locust_report.html \ No newline at end of file diff --git a/scripts/stress-test/run_locust_stress_test.sh b/scripts/stress-test/run_locust_stress_test.sh new file mode 100755 index 0000000000..665cb68754 --- /dev/null +++ b/scripts/stress-test/run_locust_stress_test.sh @@ -0,0 +1,202 @@ +#!/bin/bash + +# Run Dify SSE Stress Test using Locust + +set -e + +# Get the directory where this script is located +SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" +# Go to project root first, then to script dir +PROJECT_ROOT="$( cd "${SCRIPT_DIR}/../.." && pwd )" +cd "${PROJECT_ROOT}" +STRESS_TEST_DIR="scripts/stress-test" + +# Colors for output +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +BLUE='\033[0;34m' +CYAN='\033[0;36m' +NC='\033[0m' # No Color + +# Configuration +TIMESTAMP=$(date +"%Y%m%d_%H%M%S") +REPORT_DIR="${STRESS_TEST_DIR}/reports" +CSV_PREFIX="${REPORT_DIR}/locust_${TIMESTAMP}" +HTML_REPORT="${REPORT_DIR}/locust_report_${TIMESTAMP}.html" +SUMMARY_REPORT="${REPORT_DIR}/locust_summary_${TIMESTAMP}.txt" + +# Create reports directory if it doesn't exist +mkdir -p "${REPORT_DIR}" + +echo -e "${BLUE}╔════════════════════════════════════════════════════════════════╗${NC}" +echo -e "${BLUE}║ DIFY SSE WORKFLOW STRESS TEST (LOCUST) ║${NC}" +echo -e "${BLUE}╚════════════════════════════════════════════════════════════════╝${NC}" +echo + +# Check if services are running +echo -e "${YELLOW}Checking services...${NC}" + +# Check Dify API +if curl -s -f http://localhost:5001/health > /dev/null 2>&1; then + echo -e "${GREEN}✓ Dify API is running${NC}" + + # Warn if running in debug mode (check for werkzeug in process) + if ps aux | grep -v grep | grep -q "werkzeug.*5001\|flask.*run.*5001"; then + echo -e "${YELLOW}⚠ WARNING: API appears to be running in debug mode (Flask development server)${NC}" + echo -e "${YELLOW} This will give inaccurate benchmark results!${NC}" + echo -e "${YELLOW} For accurate benchmarking, restart with Gunicorn:${NC}" + echo -e "${CYAN} cd api && uv run gunicorn --bind 0.0.0.0:5001 --workers 4 --worker-class gevent app:app${NC}" + echo + echo -n "Continue anyway? (not recommended) [y/N]: " + read -t 10 continue_debug || continue_debug="n" + if [ "$continue_debug" != "y" ] && [ "$continue_debug" != "Y" ]; then + echo -e "${RED}Benchmark cancelled. Please restart API with Gunicorn.${NC}" + exit 1 + fi + fi +else + echo -e "${RED}✗ Dify API is not running on port 5001${NC}" + echo -e "${YELLOW} Start it with Gunicorn for accurate benchmarking:${NC}" + echo -e "${CYAN} cd api && uv run gunicorn --bind 0.0.0.0:5001 --workers 4 --worker-class gevent app:app${NC}" + exit 1 +fi + +# Check Mock OpenAI server +if curl -s -f http://localhost:5004/v1/models > /dev/null 2>&1; then + echo -e "${GREEN}✓ Mock OpenAI server is running${NC}" +else + echo -e "${RED}✗ Mock OpenAI server is not running on port 5004${NC}" + echo -e "${YELLOW} Start it with: python scripts/stress-test/setup/mock_openai_server.py${NC}" + exit 1 +fi + +# Check API token exists +if [ ! -f "${STRESS_TEST_DIR}/setup/config/stress_test_state.json" ]; then + echo -e "${RED}✗ Stress test configuration not found${NC}" + echo -e "${YELLOW} Run setup first: python scripts/stress-test/setup_all.py${NC}" + exit 1 +fi + +API_TOKEN=$(python3 -c "import json; state = json.load(open('${STRESS_TEST_DIR}/setup/config/stress_test_state.json')); print(state.get('api_key', {}).get('token', ''))" 2>/dev/null) +if [ -z "$API_TOKEN" ]; then + echo -e "${RED}✗ Failed to read API token from stress test state${NC}" + exit 1 +fi +echo -e "${GREEN}✓ API token found: ${API_TOKEN:0:10}...${NC}" + +echo +echo -e "${CYAN}═══════════════════════════════════════════════════════════════${NC}" +echo -e "${CYAN} STRESS TEST PARAMETERS ${NC}" +echo -e "${CYAN}═══════════════════════════════════════════════════════════════${NC}" + +# Parse configuration +USERS=$(grep "^users" ${STRESS_TEST_DIR}/locust.conf | cut -d'=' -f2 | tr -d ' ') +SPAWN_RATE=$(grep "^spawn-rate" ${STRESS_TEST_DIR}/locust.conf | cut -d'=' -f2 | tr -d ' ') +RUN_TIME=$(grep "^run-time" ${STRESS_TEST_DIR}/locust.conf | cut -d'=' -f2 | tr -d ' ') + +echo -e " ${YELLOW}Users:${NC} $USERS concurrent users" +echo -e " ${YELLOW}Spawn Rate:${NC} $SPAWN_RATE users/second" +echo -e " ${YELLOW}Duration:${NC} $RUN_TIME" +echo -e " ${YELLOW}Mode:${NC} SSE Streaming" +echo + +# Ask user for run mode +echo -e "${YELLOW}Select run mode:${NC}" +echo " 1) Headless (CLI only) - Default" +echo " 2) Web UI (http://localhost:8089)" +echo -n "Choice [1]: " +read -t 10 choice || choice="1" +echo + +# Use SSE stress test script +LOCUST_SCRIPT="${STRESS_TEST_DIR}/sse_benchmark.py" + +# Prepare Locust command +if [ "$choice" = "2" ]; then + echo -e "${BLUE}Starting Locust with Web UI...${NC}" + echo -e "${YELLOW}Access the web interface at: ${CYAN}http://localhost:8089${NC}" + echo + + # Run with web UI + uv --project api run locust \ + -f ${LOCUST_SCRIPT} \ + --host http://localhost:5001 \ + --web-port 8089 +else + echo -e "${BLUE}Starting stress test in headless mode...${NC}" + echo + + # Run in headless mode with CSV output + uv --project api run locust \ + -f ${LOCUST_SCRIPT} \ + --host http://localhost:5001 \ + --users $USERS \ + --spawn-rate $SPAWN_RATE \ + --run-time $RUN_TIME \ + --headless \ + --print-stats \ + --csv=$CSV_PREFIX \ + --html=$HTML_REPORT \ + 2>&1 | tee $SUMMARY_REPORT + + echo + echo -e "${GREEN}═══════════════════════════════════════════════════════════════${NC}" + echo -e "${GREEN} STRESS TEST COMPLETE ${NC}" + echo -e "${GREEN}═══════════════════════════════════════════════════════════════${NC}" + echo + echo -e "${BLUE}Reports generated:${NC}" + echo -e " ${YELLOW}Summary:${NC} $SUMMARY_REPORT" + echo -e " ${YELLOW}HTML Report:${NC} $HTML_REPORT" + echo -e " ${YELLOW}CSV Stats:${NC} ${CSV_PREFIX}_stats.csv" + echo -e " ${YELLOW}CSV History:${NC} ${CSV_PREFIX}_stats_history.csv" + echo + echo -e "${CYAN}View HTML report:${NC}" + echo " open $HTML_REPORT # macOS" + echo " xdg-open $HTML_REPORT # Linux" + echo + + # Parse and display key metrics + echo -e "${CYAN}═══════════════════════════════════════════════════════════════${NC}" + echo -e "${CYAN} KEY METRICS ${NC}" + echo -e "${CYAN}═══════════════════════════════════════════════════════════════${NC}" + + if [ -f "${CSV_PREFIX}_stats.csv" ]; then + python3 - < None: + """Configure OpenAI plugin with mock server credentials.""" + + log = Logger("ConfigPlugin") + log.header("Configuring OpenAI Plugin") + + # Read token from config + access_token = config_helper.get_token() + if not access_token: + log.error("No access token found in config") + log.info("Please run login_admin.py first to get access token") + return + + log.step("Configuring OpenAI plugin with mock server...") + + # API endpoint for plugin configuration + base_url = "http://localhost:5001" + config_endpoint = f"{base_url}/console/api/workspaces/current/model-providers/langgenius/openai/openai/credentials" + + # Configuration payload with mock server + config_payload = { + "credentials": { + "openai_api_key": "apikey", + "openai_organization": None, + "openai_api_base": "http://host.docker.internal:5004", + } + } + + headers = { + "Accept": "*/*", + "Accept-Language": "en-US,en;q=0.9", + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "DNT": "1", + "Origin": "http://localhost:3000", + "Pragma": "no-cache", + "Referer": "http://localhost:3000/", + "Sec-Fetch-Dest": "empty", + "Sec-Fetch-Mode": "cors", + "Sec-Fetch-Site": "same-site", + "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/139.0.0.0 Safari/537.36", + "authorization": f"Bearer {access_token}", + "content-type": "application/json", + "sec-ch-ua": '"Not;A=Brand";v="99", "Google Chrome";v="139", "Chromium";v="139"', + "sec-ch-ua-mobile": "?0", + "sec-ch-ua-platform": '"macOS"', + } + + cookies = {"locale": "en-US"} + + try: + # Make the configuration request + with httpx.Client() as client: + response = client.post( + config_endpoint, + json=config_payload, + headers=headers, + cookies=cookies, + ) + + if response.status_code == 200: + log.success("OpenAI plugin configured successfully!") + log.key_value( + "API Base", config_payload["credentials"]["openai_api_base"] + ) + log.key_value( + "API Key", config_payload["credentials"]["openai_api_key"] + ) + + elif response.status_code == 201: + log.success("OpenAI plugin credentials created successfully!") + log.key_value( + "API Base", config_payload["credentials"]["openai_api_base"] + ) + log.key_value( + "API Key", config_payload["credentials"]["openai_api_key"] + ) + + elif response.status_code == 401: + log.error("Configuration failed: Unauthorized") + log.info("Token may have expired. Please run login_admin.py again") + else: + log.error( + f"Configuration failed with status code: {response.status_code}" + ) + log.debug(f"Response: {response.text}") + + except httpx.ConnectError: + log.error("Could not connect to Dify API at http://localhost:5001") + log.info("Make sure the API server is running with: ./dev/start-api") + except Exception as e: + log.error(f"An error occurred: {e}") + + +if __name__ == "__main__": + configure_openai_plugin() diff --git a/scripts/stress-test/setup/create_api_key.py b/scripts/stress-test/setup/create_api_key.py new file mode 100755 index 0000000000..08eaed309a --- /dev/null +++ b/scripts/stress-test/setup/create_api_key.py @@ -0,0 +1,117 @@ +#!/usr/bin/env python3 + +import sys +from pathlib import Path + +sys.path.append(str(Path(__file__).parent.parent)) + +import httpx +import json +from common import config_helper +from common import Logger + + +def create_api_key() -> None: + """Create API key for the imported app.""" + + log = Logger("CreateAPIKey") + log.header("Creating API Key") + + # Read token from config + access_token = config_helper.get_token() + if not access_token: + log.error("No access token found in config") + return + + # Read app_id from config + app_id = config_helper.get_app_id() + if not app_id: + log.error("No app_id found in config") + log.info("Please run import_workflow_app.py first to import the app") + return + + log.step(f"Creating API key for app: {app_id}") + + # API endpoint for creating API key + base_url = "http://localhost:5001" + api_key_endpoint = f"{base_url}/console/api/apps/{app_id}/api-keys" + + headers = { + "Accept": "*/*", + "Accept-Language": "en-US,en;q=0.9", + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "Content-Length": "0", + "DNT": "1", + "Origin": "http://localhost:3000", + "Pragma": "no-cache", + "Referer": "http://localhost:3000/", + "Sec-Fetch-Dest": "empty", + "Sec-Fetch-Mode": "cors", + "Sec-Fetch-Site": "same-site", + "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/139.0.0.0 Safari/537.36", + "authorization": f"Bearer {access_token}", + "content-type": "application/json", + "sec-ch-ua": '"Not;A=Brand";v="99", "Google Chrome";v="139", "Chromium";v="139"', + "sec-ch-ua-mobile": "?0", + "sec-ch-ua-platform": '"macOS"', + } + + cookies = {"locale": "en-US"} + + try: + # Make the API key creation request + with httpx.Client() as client: + response = client.post( + api_key_endpoint, + headers=headers, + cookies=cookies, + ) + + if response.status_code == 200 or response.status_code == 201: + response_data = response.json() + + api_key_id = response_data.get("id") + api_key_token = response_data.get("token") + + if api_key_token: + log.success("API key created successfully!") + log.key_value("Key ID", api_key_id) + log.key_value("Token", api_key_token) + log.key_value("Type", response_data.get("type")) + + # Save API key to config + api_key_config = { + "id": api_key_id, + "token": api_key_token, + "type": response_data.get("type"), + "app_id": app_id, + "created_at": response_data.get("created_at"), + } + + if config_helper.write_config("api_key_config", api_key_config): + log.info( + f"API key saved to: {config_helper.get_config_path('benchmark_state')}" + ) + else: + log.error("No API token received") + log.debug(f"Response: {json.dumps(response_data, indent=2)}") + + elif response.status_code == 401: + log.error("API key creation failed: Unauthorized") + log.info("Token may have expired. Please run login_admin.py again") + else: + log.error( + f"API key creation failed with status code: {response.status_code}" + ) + log.debug(f"Response: {response.text}") + + except httpx.ConnectError: + log.error("Could not connect to Dify API at http://localhost:5001") + log.info("Make sure the API server is running with: ./dev/start-api") + except Exception as e: + log.error(f"An error occurred: {e}") + + +if __name__ == "__main__": + create_api_key() diff --git a/scripts/stress-test/setup/dsl/workflow_llm.yml b/scripts/stress-test/setup/dsl/workflow_llm.yml new file mode 100644 index 0000000000..c0fd2c7d8b --- /dev/null +++ b/scripts/stress-test/setup/dsl/workflow_llm.yml @@ -0,0 +1,176 @@ +app: + description: '' + icon: 🤖 + icon_background: '#FFEAD5' + mode: workflow + name: workflow_llm + use_icon_as_answer_icon: false +dependencies: +- current_identifier: null + type: marketplace + value: + marketplace_plugin_unique_identifier: langgenius/openai:0.2.5@373362a028986aae53a7baf73a7f11991ba3c22c69eaf97d6cde048cfd4a9f98 +kind: app +version: 0.4.0 +workflow: + conversation_variables: [] + environment_variables: [] + features: + file_upload: + allowed_file_extensions: + - .JPG + - .JPEG + - .PNG + - .GIF + - .WEBP + - .SVG + allowed_file_types: + - image + allowed_file_upload_methods: + - local_file + - remote_url + enabled: false + fileUploadConfig: + audio_file_size_limit: 50 + batch_count_limit: 5 + file_size_limit: 15 + image_file_size_limit: 10 + video_file_size_limit: 100 + workflow_file_upload_limit: 10 + image: + enabled: false + number_limits: 3 + transfer_methods: + - local_file + - remote_url + number_limits: 3 + opening_statement: '' + retriever_resource: + enabled: true + sensitive_word_avoidance: + enabled: false + speech_to_text: + enabled: false + suggested_questions: [] + suggested_questions_after_answer: + enabled: false + text_to_speech: + enabled: false + language: '' + voice: '' + graph: + edges: + - data: + isInIteration: false + isInLoop: false + sourceType: start + targetType: llm + id: 1757611990947-source-1757611992921-target + source: '1757611990947' + sourceHandle: source + target: '1757611992921' + targetHandle: target + type: custom + zIndex: 0 + - data: + isInIteration: false + isInLoop: false + sourceType: llm + targetType: end + id: 1757611992921-source-1757611996447-target + source: '1757611992921' + sourceHandle: source + target: '1757611996447' + targetHandle: target + type: custom + zIndex: 0 + nodes: + - data: + desc: '' + selected: false + title: Start + type: start + variables: + - label: question + max_length: null + options: [] + required: true + type: text-input + variable: question + height: 90 + id: '1757611990947' + position: + x: 30 + y: 245 + positionAbsolute: + x: 30 + y: 245 + selected: false + sourcePosition: right + targetPosition: left + type: custom + width: 244 + - data: + context: + enabled: false + variable_selector: [] + desc: '' + model: + completion_params: + temperature: 0.7 + mode: chat + name: gpt-4o + provider: langgenius/openai/openai + prompt_template: + - id: c165fcb6-f1f0-42f2-abab-e81982434deb + role: system + text: '' + - role: user + text: '{{#1757611990947.question#}}' + selected: false + title: LLM + type: llm + variables: [] + vision: + enabled: false + height: 90 + id: '1757611992921' + position: + x: 334 + y: 245 + positionAbsolute: + x: 334 + y: 245 + selected: false + sourcePosition: right + targetPosition: left + type: custom + width: 244 + - data: + desc: '' + outputs: + - value_selector: + - '1757611992921' + - text + value_type: string + variable: answer + selected: false + title: End + type: end + height: 90 + id: '1757611996447' + position: + x: 638 + y: 245 + positionAbsolute: + x: 638 + y: 245 + selected: true + sourcePosition: right + targetPosition: left + type: custom + width: 244 + viewport: + x: 0 + y: 0 + zoom: 0.7 diff --git a/scripts/stress-test/setup/import_workflow_app.py b/scripts/stress-test/setup/import_workflow_app.py new file mode 100755 index 0000000000..1d8a9b6009 --- /dev/null +++ b/scripts/stress-test/setup/import_workflow_app.py @@ -0,0 +1,131 @@ +#!/usr/bin/env python3 + +import sys +from pathlib import Path + +sys.path.append(str(Path(__file__).parent.parent)) + +import httpx +import json +from common import config_helper, Logger + + +def import_workflow_app() -> None: + """Import workflow app from DSL file and save app_id.""" + + log = Logger("ImportApp") + log.header("Importing Workflow Application") + + # Read token from config + access_token = config_helper.get_token() + if not access_token: + log.error("No access token found in config") + log.info("Please run login_admin.py first to get access token") + return + + # Read workflow DSL file + dsl_path = Path(__file__).parent / "dsl" / "workflow_llm.yml" + + if not dsl_path.exists(): + log.error(f"DSL file not found: {dsl_path}") + return + + with open(dsl_path, "r") as f: + yaml_content = f.read() + + log.step("Importing workflow app from DSL...") + log.key_value("DSL file", dsl_path.name) + + # API endpoint for app import + base_url = "http://localhost:5001" + import_endpoint = f"{base_url}/console/api/apps/imports" + + # Import payload + import_payload = {"mode": "yaml-content", "yaml_content": yaml_content} + + headers = { + "Accept": "*/*", + "Accept-Language": "en-US,en;q=0.9", + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "DNT": "1", + "Origin": "http://localhost:3000", + "Pragma": "no-cache", + "Referer": "http://localhost:3000/", + "Sec-Fetch-Dest": "empty", + "Sec-Fetch-Mode": "cors", + "Sec-Fetch-Site": "same-site", + "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/139.0.0.0 Safari/537.36", + "authorization": f"Bearer {access_token}", + "content-type": "application/json", + "sec-ch-ua": '"Not;A=Brand";v="99", "Google Chrome";v="139", "Chromium";v="139"', + "sec-ch-ua-mobile": "?0", + "sec-ch-ua-platform": '"macOS"', + } + + cookies = {"locale": "en-US"} + + try: + # Make the import request + with httpx.Client() as client: + response = client.post( + import_endpoint, + json=import_payload, + headers=headers, + cookies=cookies, + ) + + if response.status_code == 200: + response_data = response.json() + + # Check import status + if response_data.get("status") == "completed": + app_id = response_data.get("app_id") + + if app_id: + log.success("Workflow app imported successfully!") + log.key_value("App ID", app_id) + log.key_value("App Mode", response_data.get("app_mode")) + log.key_value( + "DSL Version", response_data.get("imported_dsl_version") + ) + + # Save app_id to config + app_config = { + "app_id": app_id, + "app_mode": response_data.get("app_mode"), + "app_name": "workflow_llm", + "dsl_version": response_data.get("imported_dsl_version"), + } + + if config_helper.write_config("app_config", app_config): + log.info( + f"App config saved to: {config_helper.get_config_path('benchmark_state')}" + ) + else: + log.error("Import completed but no app_id received") + log.debug(f"Response: {json.dumps(response_data, indent=2)}") + + elif response_data.get("status") == "failed": + log.error("Import failed") + log.error(f"Error: {response_data.get('error')}") + else: + log.warning(f"Import status: {response_data.get('status')}") + log.debug(f"Response: {json.dumps(response_data, indent=2)}") + + elif response.status_code == 401: + log.error("Import failed: Unauthorized") + log.info("Token may have expired. Please run login_admin.py again") + else: + log.error(f"Import failed with status code: {response.status_code}") + log.debug(f"Response: {response.text}") + + except httpx.ConnectError: + log.error("Could not connect to Dify API at http://localhost:5001") + log.info("Make sure the API server is running with: ./dev/start-api") + except Exception as e: + log.error(f"An error occurred: {e}") + + +if __name__ == "__main__": + import_workflow_app() diff --git a/scripts/stress-test/setup/install_openai_plugin.py b/scripts/stress-test/setup/install_openai_plugin.py new file mode 100755 index 0000000000..d925126dae --- /dev/null +++ b/scripts/stress-test/setup/install_openai_plugin.py @@ -0,0 +1,165 @@ +#!/usr/bin/env python3 + +import sys +from pathlib import Path + +sys.path.append(str(Path(__file__).parent.parent)) + +import httpx +import time +from common import config_helper +from common import Logger + + +def install_openai_plugin() -> None: + """Install OpenAI plugin using saved access token.""" + + log = Logger("InstallPlugin") + log.header("Installing OpenAI Plugin") + + # Read token from config + access_token = config_helper.get_token() + if not access_token: + log.error("No access token found in config") + log.info("Please run login_admin.py first to get access token") + return + + log.step("Installing OpenAI plugin...") + + # API endpoint for plugin installation + base_url = "http://localhost:5001" + install_endpoint = ( + f"{base_url}/console/api/workspaces/current/plugin/install/marketplace" + ) + + # Plugin identifier + plugin_payload = { + "plugin_unique_identifiers": [ + "langgenius/openai:0.2.5@373362a028986aae53a7baf73a7f11991ba3c22c69eaf97d6cde048cfd4a9f98" + ] + } + + headers = { + "Accept": "*/*", + "Accept-Language": "en-US,en;q=0.9", + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "DNT": "1", + "Origin": "http://localhost:3000", + "Pragma": "no-cache", + "Referer": "http://localhost:3000/", + "Sec-Fetch-Dest": "empty", + "Sec-Fetch-Mode": "cors", + "Sec-Fetch-Site": "same-site", + "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/139.0.0.0 Safari/537.36", + "authorization": f"Bearer {access_token}", + "content-type": "application/json", + "sec-ch-ua": '"Not;A=Brand";v="99", "Google Chrome";v="139", "Chromium";v="139"', + "sec-ch-ua-mobile": "?0", + "sec-ch-ua-platform": '"macOS"', + } + + cookies = {"locale": "en-US"} + + try: + # Make the installation request + with httpx.Client() as client: + response = client.post( + install_endpoint, + json=plugin_payload, + headers=headers, + cookies=cookies, + ) + + if response.status_code == 200: + response_data = response.json() + task_id = response_data.get("task_id") + + if not task_id: + log.error("No task ID received from installation request") + return + + log.progress(f"Installation task created: {task_id}") + log.info("Polling for task completion...") + + # Poll for task completion + task_endpoint = ( + f"{base_url}/console/api/workspaces/current/plugin/tasks/{task_id}" + ) + + max_attempts = 30 # 30 attempts with 2 second delay = 60 seconds max + attempt = 0 + + log.spinner_start("Installing plugin") + + while attempt < max_attempts: + attempt += 1 + time.sleep(2) # Wait 2 seconds between polls + + task_response = client.get( + task_endpoint, + headers=headers, + cookies=cookies, + ) + + if task_response.status_code != 200: + log.spinner_stop( + success=False, + message=f"Failed to get task status: {task_response.status_code}", + ) + return + + task_data = task_response.json() + task_info = task_data.get("task", {}) + status = task_info.get("status") + + if status == "success": + log.spinner_stop(success=True, message="Plugin installed!") + log.success("OpenAI plugin installed successfully!") + + # Display plugin info + plugins = task_info.get("plugins", []) + if plugins: + plugin_info = plugins[0] + log.key_value("Plugin ID", plugin_info.get("plugin_id")) + log.key_value("Message", plugin_info.get("message")) + break + + elif status == "failed": + log.spinner_stop(success=False, message="Installation failed") + log.error("Plugin installation failed") + plugins = task_info.get("plugins", []) + if plugins: + for plugin in plugins: + log.list_item( + f"{plugin.get('plugin_id')}: {plugin.get('message')}" + ) + break + + # Continue polling if status is "pending" or other + + else: + log.spinner_stop(success=False, message="Installation timed out") + log.error("Installation timed out after 60 seconds") + + elif response.status_code == 401: + log.error("Installation failed: Unauthorized") + log.info("Token may have expired. Please run login_admin.py again") + elif response.status_code == 409: + log.warning("Plugin may already be installed") + log.debug(f"Response: {response.text}") + else: + log.error( + f"Installation failed with status code: {response.status_code}" + ) + log.debug(f"Response: {response.text}") + + except httpx.ConnectError: + log.error("Could not connect to Dify API at http://localhost:5001") + log.info("Make sure the API server is running with: ./dev/start-api") + except Exception as e: + log.error(f"An error occurred: {e}") + + +if __name__ == "__main__": + install_openai_plugin() diff --git a/scripts/stress-test/setup/login_admin.py b/scripts/stress-test/setup/login_admin.py new file mode 100755 index 0000000000..0e3da687f6 --- /dev/null +++ b/scripts/stress-test/setup/login_admin.py @@ -0,0 +1,107 @@ +#!/usr/bin/env python3 + +import sys +from pathlib import Path + +sys.path.append(str(Path(__file__).parent.parent)) + +import httpx +import json +from common import config_helper +from common import Logger + + +def login_admin() -> None: + """Login with admin account and save access token.""" + + log = Logger("Login") + log.header("Admin Login") + + # Read admin credentials from config + admin_config = config_helper.read_config("admin_config") + + if not admin_config: + log.error("Admin config not found") + log.info("Please run setup_admin.py first to create the admin account") + return + + log.info(f"Logging in with email: {admin_config['email']}") + + # API login endpoint + base_url = "http://localhost:5001" + login_endpoint = f"{base_url}/console/api/login" + + # Prepare login payload + login_payload = { + "email": admin_config["email"], + "password": admin_config["password"], + "remember_me": True, + } + + try: + # Make the login request + with httpx.Client() as client: + response = client.post( + login_endpoint, + json=login_payload, + headers={"Content-Type": "application/json"}, + ) + + if response.status_code == 200: + log.success("Login successful!") + + # Extract token from response + response_data = response.json() + + # Check if login was successful + if response_data.get("result") != "success": + log.error(f"Login failed: {response_data}") + return + + # Extract tokens from data field + token_data = response_data.get("data", {}) + access_token = token_data.get("access_token", "") + refresh_token = token_data.get("refresh_token", "") + + if not access_token: + log.error("No access token found in response") + log.debug(f"Full response: {json.dumps(response_data, indent=2)}") + return + + # Save token to config file + token_config = { + "email": admin_config["email"], + "access_token": access_token, + "refresh_token": refresh_token, + } + + # Save token config + if config_helper.write_config("token_config", token_config): + log.info( + f"Token saved to: {config_helper.get_config_path('benchmark_state')}" + ) + + # Show truncated token for verification + token_display = ( + f"{access_token[:20]}..." + if len(access_token) > 20 + else "Token saved" + ) + log.key_value("Access token", token_display) + + elif response.status_code == 401: + log.error("Login failed: Invalid credentials") + log.debug(f"Response: {response.text}") + else: + log.error(f"Login failed with status code: {response.status_code}") + log.debug(f"Response: {response.text}") + + except httpx.ConnectError: + log.error("Could not connect to Dify API at http://localhost:5001") + log.info("Make sure the API server is running with: ./dev/start-api") + except Exception as e: + log.error(f"An error occurred: {e}") + + +if __name__ == "__main__": + login_admin() diff --git a/scripts/stress-test/setup/mock_openai_server.py b/scripts/stress-test/setup/mock_openai_server.py new file mode 100755 index 0000000000..9e3e8d590a --- /dev/null +++ b/scripts/stress-test/setup/mock_openai_server.py @@ -0,0 +1,203 @@ +#!/usr/bin/env python3 + +import json +import time +import uuid +from typing import Any, Iterator +from flask import Flask, request, jsonify, Response + +app = Flask(__name__) + +# Mock models list +MODELS = [ + { + "id": "gpt-3.5-turbo", + "object": "model", + "created": 1677649963, + "owned_by": "openai", + }, + {"id": "gpt-4", "object": "model", "created": 1687882411, "owned_by": "openai"}, + { + "id": "text-embedding-ada-002", + "object": "model", + "created": 1671217299, + "owned_by": "openai-internal", + }, +] + + +@app.route("/v1/models", methods=["GET"]) +def list_models() -> Any: + """List available models.""" + return jsonify({"object": "list", "data": MODELS}) + + +@app.route("/v1/chat/completions", methods=["POST"]) +def chat_completions() -> Any: + """Handle chat completions.""" + data = request.json or {} + model = data.get("model", "gpt-3.5-turbo") + messages = data.get("messages", []) + stream = data.get("stream", False) + + # Generate mock response + response_content = "This is a mock response from the OpenAI server." + if messages: + last_message = messages[-1].get("content", "") + response_content = f"Mock response to: {last_message[:100]}..." + + if stream: + # Streaming response + def generate() -> Iterator[str]: + # Send initial chunk + chunk = { + "id": f"chatcmpl-{uuid.uuid4().hex[:8]}", + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": model, + "choices": [ + { + "index": 0, + "delta": {"role": "assistant", "content": ""}, + "finish_reason": None, + } + ], + } + yield f"data: {json.dumps(chunk)}\n\n" + + # Send content in chunks + words = response_content.split() + for word in words: + chunk = { + "id": f"chatcmpl-{uuid.uuid4().hex[:8]}", + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": model, + "choices": [ + { + "index": 0, + "delta": {"content": word + " "}, + "finish_reason": None, + } + ], + } + yield f"data: {json.dumps(chunk)}\n\n" + time.sleep(0.05) # Simulate streaming delay + + # Send final chunk + chunk = { + "id": f"chatcmpl-{uuid.uuid4().hex[:8]}", + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": model, + "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}], + } + yield f"data: {json.dumps(chunk)}\n\n" + yield "data: [DONE]\n\n" + + return Response(generate(), mimetype="text/event-stream") + else: + # Non-streaming response + return jsonify( + { + "id": f"chatcmpl-{uuid.uuid4().hex[:8]}", + "object": "chat.completion", + "created": int(time.time()), + "model": model, + "choices": [ + { + "index": 0, + "message": {"role": "assistant", "content": response_content}, + "finish_reason": "stop", + } + ], + "usage": { + "prompt_tokens": len(str(messages)), + "completion_tokens": len(response_content.split()), + "total_tokens": len(str(messages)) + len(response_content.split()), + }, + } + ) + + +@app.route("/v1/completions", methods=["POST"]) +def completions() -> Any: + """Handle text completions.""" + data = request.json or {} + model = data.get("model", "gpt-3.5-turbo-instruct") + prompt = data.get("prompt", "") + + response_text = f"Mock completion for prompt: {prompt[:100]}..." + + return jsonify( + { + "id": f"cmpl-{uuid.uuid4().hex[:8]}", + "object": "text_completion", + "created": int(time.time()), + "model": model, + "choices": [ + { + "text": response_text, + "index": 0, + "logprobs": None, + "finish_reason": "stop", + } + ], + "usage": { + "prompt_tokens": len(prompt.split()), + "completion_tokens": len(response_text.split()), + "total_tokens": len(prompt.split()) + len(response_text.split()), + }, + } + ) + + +@app.route("/v1/embeddings", methods=["POST"]) +def embeddings() -> Any: + """Handle embeddings requests.""" + data = request.json or {} + model = data.get("model", "text-embedding-ada-002") + input_text = data.get("input", "") + + # Generate mock embedding (1536 dimensions for ada-002) + mock_embedding = [0.1] * 1536 + + return jsonify( + { + "object": "list", + "data": [{"object": "embedding", "embedding": mock_embedding, "index": 0}], + "model": model, + "usage": { + "prompt_tokens": len(input_text.split()), + "total_tokens": len(input_text.split()), + }, + } + ) + + +@app.route("/v1/models/", methods=["GET"]) +def get_model(model_id: str) -> tuple[Any, int] | Any: + """Get specific model details.""" + for model in MODELS: + if model["id"] == model_id: + return jsonify(model) + + return jsonify({"error": "Model not found"}), 404 + + +@app.route("/health", methods=["GET"]) +def health() -> Any: + """Health check endpoint.""" + return jsonify({"status": "healthy"}) + + +if __name__ == "__main__": + print("🚀 Starting Mock OpenAI Server on http://localhost:5004") + print("Available endpoints:") + print(" - GET /v1/models") + print(" - POST /v1/chat/completions") + print(" - POST /v1/completions") + print(" - POST /v1/embeddings") + print(" - GET /v1/models/") + print(" - GET /health") + app.run(host="0.0.0.0", port=5004, debug=True) diff --git a/scripts/stress-test/setup/publish_workflow.py b/scripts/stress-test/setup/publish_workflow.py new file mode 100755 index 0000000000..f23db5049c --- /dev/null +++ b/scripts/stress-test/setup/publish_workflow.py @@ -0,0 +1,109 @@ +#!/usr/bin/env python3 + +import sys +from pathlib import Path + +sys.path.append(str(Path(__file__).parent.parent)) + +import httpx +import json +from common import config_helper +from common import Logger + + +def publish_workflow() -> None: + """Publish the imported workflow app.""" + + log = Logger("PublishWorkflow") + log.header("Publishing Workflow") + + # Read token from config + access_token = config_helper.get_token() + if not access_token: + log.error("No access token found in config") + return + + # Read app_id from config + app_id = config_helper.get_app_id() + if not app_id: + log.error("No app_id found in config") + return + + log.step(f"Publishing workflow for app: {app_id}") + + # API endpoint for publishing workflow + base_url = "http://localhost:5001" + publish_endpoint = f"{base_url}/console/api/apps/{app_id}/workflows/publish" + + # Publish payload + publish_payload = {"marked_name": "", "marked_comment": ""} + + headers = { + "Accept": "*/*", + "Accept-Language": "en-US,en;q=0.9", + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "DNT": "1", + "Origin": "http://localhost:3000", + "Pragma": "no-cache", + "Referer": "http://localhost:3000/", + "Sec-Fetch-Dest": "empty", + "Sec-Fetch-Mode": "cors", + "Sec-Fetch-Site": "same-site", + "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/139.0.0.0 Safari/537.36", + "authorization": f"Bearer {access_token}", + "content-type": "application/json", + "sec-ch-ua": '"Not;A=Brand";v="99", "Google Chrome";v="139", "Chromium";v="139"', + "sec-ch-ua-mobile": "?0", + "sec-ch-ua-platform": '"macOS"', + } + + cookies = {"locale": "en-US"} + + try: + # Make the publish request + with httpx.Client() as client: + response = client.post( + publish_endpoint, + json=publish_payload, + headers=headers, + cookies=cookies, + ) + + if response.status_code == 200 or response.status_code == 201: + log.success("Workflow published successfully!") + log.key_value("App ID", app_id) + + # Try to parse response if it has JSON content + if response.text: + try: + response_data = response.json() + if response_data: + log.debug( + f"Response: {json.dumps(response_data, indent=2)}" + ) + except json.JSONDecodeError: + # Response might be empty or non-JSON + pass + + elif response.status_code == 401: + log.error("Workflow publish failed: Unauthorized") + log.info("Token may have expired. Please run login_admin.py again") + elif response.status_code == 404: + log.error("Workflow publish failed: App not found") + log.info("Make sure the app was imported successfully") + else: + log.error( + f"Workflow publish failed with status code: {response.status_code}" + ) + log.debug(f"Response: {response.text}") + + except httpx.ConnectError: + log.error("Could not connect to Dify API at http://localhost:5001") + log.info("Make sure the API server is running with: ./dev/start-api") + except Exception as e: + log.error(f"An error occurred: {e}") + + +if __name__ == "__main__": + publish_workflow() diff --git a/scripts/stress-test/setup/run_workflow.py b/scripts/stress-test/setup/run_workflow.py new file mode 100755 index 0000000000..13ab1280fc --- /dev/null +++ b/scripts/stress-test/setup/run_workflow.py @@ -0,0 +1,166 @@ +#!/usr/bin/env python3 + +import sys +from pathlib import Path + +sys.path.append(str(Path(__file__).parent.parent)) + +import httpx +import json +from common import config_helper, Logger + + +def run_workflow(question: str = "fake question", streaming: bool = True) -> None: + """Run the workflow app with a question.""" + + log = Logger("RunWorkflow") + log.header("Running Workflow") + + # Read API key from config + api_token = config_helper.get_api_key() + if not api_token: + log.error("No API token found in config") + log.info("Please run create_api_key.py first to create an API key") + return + + log.key_value("Question", question) + log.key_value("Mode", "Streaming" if streaming else "Blocking") + log.separator() + + # API endpoint for running workflow + base_url = "http://localhost:5001" + run_endpoint = f"{base_url}/v1/workflows/run" + + # Run payload + run_payload = { + "inputs": {"question": question}, + "user": "default user", + "response_mode": "streaming" if streaming else "blocking", + } + + headers = { + "Authorization": f"Bearer {api_token}", + "Content-Type": "application/json", + } + + try: + # Make the run request + with httpx.Client(timeout=30.0) as client: + if streaming: + # Handle streaming response + with client.stream( + "POST", + run_endpoint, + json=run_payload, + headers=headers, + ) as response: + if response.status_code == 200: + log.success("Workflow started successfully!") + log.separator() + log.step("Streaming response:") + + for line in response.iter_lines(): + if line.startswith("data: "): + data_str = line[6:] # Remove "data: " prefix + if data_str == "[DONE]": + log.success("Workflow completed!") + break + try: + data = json.loads(data_str) + event = data.get("event") + + if event == "workflow_started": + log.progress( + f"Workflow started: {data.get('data', {}).get('id')}" + ) + elif event == "node_started": + node_data = data.get("data", {}) + log.progress( + f"Node started: {node_data.get('node_type')} - {node_data.get('title')}" + ) + elif event == "node_finished": + node_data = data.get("data", {}) + log.progress( + f"Node finished: {node_data.get('node_type')} - {node_data.get('title')}" + ) + + # Print output if it's the LLM node + outputs = node_data.get("outputs", {}) + if outputs.get("text"): + log.separator() + log.info("💬 LLM Response:") + log.info(outputs.get("text"), indent=2) + log.separator() + + elif event == "workflow_finished": + workflow_data = data.get("data", {}) + outputs = workflow_data.get("outputs", {}) + if outputs.get("answer"): + log.separator() + log.info("📤 Final Answer:") + log.info(outputs.get("answer"), indent=2) + log.separator() + log.key_value( + "Total tokens", + str(workflow_data.get("total_tokens", 0)), + ) + log.key_value( + "Total steps", + str(workflow_data.get("total_steps", 0)), + ) + + elif event == "error": + log.error(f"Error: {data.get('message')}") + + except json.JSONDecodeError: + # Some lines might not be JSON + pass + else: + log.error( + f"Workflow run failed with status code: {response.status_code}" + ) + log.debug(f"Response: {response.text}") + else: + # Handle blocking response + response = client.post( + run_endpoint, + json=run_payload, + headers=headers, + ) + + if response.status_code == 200: + log.success("Workflow completed successfully!") + response_data = response.json() + + log.separator() + log.debug(f"Full response: {json.dumps(response_data, indent=2)}") + + # Extract the answer if available + outputs = response_data.get("data", {}).get("outputs", {}) + if outputs.get("answer"): + log.separator() + log.info("📤 Final Answer:") + log.info(outputs.get("answer"), indent=2) + else: + log.error( + f"Workflow run failed with status code: {response.status_code}" + ) + log.debug(f"Response: {response.text}") + + except httpx.ConnectError: + log.error("Could not connect to Dify API at http://localhost:5001") + log.info("Make sure the API server is running with: ./dev/start-api") + except httpx.TimeoutException: + log.error("Request timed out") + except Exception as e: + log.error(f"An error occurred: {e}") + + +if __name__ == "__main__": + # Allow passing question as command line argument + if len(sys.argv) > 1: + question = " ".join(sys.argv[1:]) + else: + question = "What is the capital of France?" + + run_workflow(question=question, streaming=True) diff --git a/scripts/stress-test/setup/setup_admin.py b/scripts/stress-test/setup/setup_admin.py new file mode 100755 index 0000000000..3cc83aa773 --- /dev/null +++ b/scripts/stress-test/setup/setup_admin.py @@ -0,0 +1,75 @@ +#!/usr/bin/env python3 + +import sys +from pathlib import Path + +sys.path.append(str(Path(__file__).parent.parent)) + +import httpx +from common import config_helper, Logger + + +def setup_admin_account() -> None: + """Setup Dify API with an admin account.""" + + log = Logger("SetupAdmin") + log.header("Setting up Admin Account") + + # Admin account credentials + admin_config = { + "email": "test@dify.ai", + "username": "dify", + "password": "password123", + } + + # Save credentials to config file + if config_helper.write_config("admin_config", admin_config): + log.info( + f"Admin credentials saved to: {config_helper.get_config_path('benchmark_state')}" + ) + + # API setup endpoint + base_url = "http://localhost:5001" + setup_endpoint = f"{base_url}/console/api/setup" + + # Prepare setup payload + setup_payload = { + "email": admin_config["email"], + "name": admin_config["username"], + "password": admin_config["password"], + } + + log.step("Configuring Dify with admin account...") + + try: + # Make the setup request + with httpx.Client() as client: + response = client.post( + setup_endpoint, + json=setup_payload, + headers={"Content-Type": "application/json"}, + ) + + if response.status_code == 201: + log.success("Admin account created successfully!") + log.key_value("Email", admin_config["email"]) + log.key_value("Username", admin_config["username"]) + + elif response.status_code == 400: + log.warning( + "Setup may have already been completed or invalid data provided" + ) + log.debug(f"Response: {response.text}") + else: + log.error(f"Setup failed with status code: {response.status_code}") + log.debug(f"Response: {response.text}") + + except httpx.ConnectError: + log.error("Could not connect to Dify API at http://localhost:5001") + log.info("Make sure the API server is running with: ./dev/start-api") + except Exception as e: + log.error(f"An error occurred: {e}") + + +if __name__ == "__main__": + setup_admin_account() diff --git a/scripts/stress-test/setup_all.py b/scripts/stress-test/setup_all.py new file mode 100755 index 0000000000..7a4c5c01b2 --- /dev/null +++ b/scripts/stress-test/setup_all.py @@ -0,0 +1,164 @@ +#!/usr/bin/env python3 + +import subprocess +import sys +import time +import socket +from pathlib import Path + +from common import Logger, ProgressLogger + + +def run_script(script_name: str, description: str) -> bool: + """Run a Python script and return success status.""" + script_path = Path(__file__).parent / "setup" / script_name + + if not script_path.exists(): + print(f"❌ Script not found: {script_path}") + return False + + print(f"\n{'=' * 60}") + print(f"🚀 {description}") + print(f" Running: {script_name}") + print(f"{'=' * 60}") + + try: + result = subprocess.run( + [sys.executable, str(script_path)], + capture_output=True, + text=True, + check=False, + ) + + # Print output + if result.stdout: + print(result.stdout) + if result.stderr: + print(result.stderr, file=sys.stderr) + + if result.returncode != 0: + print(f"❌ Script failed with exit code: {result.returncode}") + return False + + print(f"✅ {script_name} completed successfully") + return True + + except Exception as e: + print(f"❌ Error running {script_name}: {e}") + return False + + +def check_port(host: str, port: int, service_name: str) -> bool: + """Check if a service is running on the specified port.""" + try: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(2) + result = sock.connect_ex((host, port)) + sock.close() + + if result == 0: + Logger().success(f"{service_name} is running on port {port}") + return True + else: + Logger().error(f"{service_name} is not accessible on port {port}") + return False + except Exception as e: + Logger().error(f"Error checking {service_name}: {e}") + return False + + +def main() -> None: + """Run all setup scripts in order.""" + + log = Logger("Setup") + log.box("Dify Stress Test Setup - Full Installation") + + # Check if required services are running + log.step("Checking required services...") + log.separator() + + dify_running = check_port("localhost", 5001, "Dify API server") + if not dify_running: + log.info("To start Dify API server:") + log.list_item("Run: ./dev/start-api") + + mock_running = check_port("localhost", 5004, "Mock OpenAI server") + if not mock_running: + log.info("To start Mock OpenAI server:") + log.list_item("Run: python scripts/stress-test/setup/mock_openai_server.py") + + if not dify_running or not mock_running: + print("\n⚠️ Both services must be running before proceeding.") + retry = input("\nWould you like to check again? (yes/no): ") + if retry.lower() in ["yes", "y"]: + return main() # Recursively call main to check again + else: + print( + "❌ Setup cancelled. Please start the required services and try again." + ) + sys.exit(1) + + log.success("All required services are running!") + input("\nPress Enter to continue with setup...") + + # Define setup steps + setup_steps = [ + ("setup_admin.py", "Creating admin account"), + ("login_admin.py", "Logging in and getting access token"), + ("install_openai_plugin.py", "Installing OpenAI plugin"), + ("configure_openai_plugin.py", "Configuring OpenAI plugin with mock server"), + ("import_workflow_app.py", "Importing workflow application"), + ("create_api_key.py", "Creating API key for the app"), + ("publish_workflow.py", "Publishing the workflow"), + ] + + # Create progress logger + progress = ProgressLogger(len(setup_steps), log) + failed_step = None + + for script, description in setup_steps: + progress.next_step(description) + success = run_script(script, description) + + if not success: + failed_step = script + break + + # Small delay between steps + time.sleep(1) + + log.separator() + + if failed_step: + log.error(f"Setup failed at: {failed_step}") + log.separator() + log.info("Troubleshooting:") + log.list_item("Check if the Dify API server is running (./dev/start-api)") + log.list_item("Check if the mock OpenAI server is running (port 5004)") + log.list_item("Review the error messages above") + log.list_item("Run cleanup.py and try again") + sys.exit(1) + else: + progress.complete() + log.separator() + log.success("Setup completed successfully!") + log.info("Next steps:") + log.list_item("Test the workflow:") + log.info( + ' python scripts/stress-test/setup/run_workflow.py "Your question here"', + indent=4, + ) + log.list_item("To clean up and start over:") + log.info(" python scripts/stress-test/cleanup.py", indent=4) + + # Optionally run a test + log.separator() + test_input = input("Would you like to run a test workflow now? (yes/no): ") + + if test_input.lower() in ["yes", "y"]: + log.step("Running test workflow...") + run_script("run_workflow.py", "Testing workflow with default question") + + +if __name__ == "__main__": + main() diff --git a/scripts/stress-test/sse_benchmark.py b/scripts/stress-test/sse_benchmark.py new file mode 100644 index 0000000000..e08fe6c33b --- /dev/null +++ b/scripts/stress-test/sse_benchmark.py @@ -0,0 +1,770 @@ +#!/usr/bin/env python3 +""" +SSE (Server-Sent Events) Stress Test for Dify Workflow API + +This script stress tests the streaming performance of Dify's workflow execution API, +measuring key metrics like connection rate, event throughput, and time to first event (TTFE). +""" + +import json +import time +import random +import sys +import threading +import os +import logging +import statistics +from pathlib import Path +from collections import deque +from datetime import datetime +from dataclasses import dataclass, asdict +from locust import HttpUser, task, between, events, constant +from typing import TypedDict, Literal, TypeAlias +import requests.exceptions + +# Add the stress-test directory to path to import common modules +sys.path.insert(0, str(Path(__file__).parent)) +from common.config_helper import ConfigHelper # type: ignore[import-not-found] + +# Configure logging +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" +) +logger = logging.getLogger(__name__) + +# Configuration from environment +WORKFLOW_PATH = os.getenv("WORKFLOW_PATH", "/v1/workflows/run") +CONNECT_TIMEOUT = float(os.getenv("CONNECT_TIMEOUT", "10")) +READ_TIMEOUT = float(os.getenv("READ_TIMEOUT", "60")) +TERMINAL_EVENTS = [e.strip() for e in os.getenv("TERMINAL_EVENTS", "workflow_finished,error").split(",") if e.strip()] +QUESTIONS_FILE = os.getenv("QUESTIONS_FILE", "") + + +# Type definitions +ErrorType: TypeAlias = Literal[ + "connection_error", + "timeout", + "invalid_json", + "http_4xx", + "http_5xx", + "early_termination", + "invalid_response", +] + + +class ErrorCounts(TypedDict): + """Error count tracking""" + connection_error: int + timeout: int + invalid_json: int + http_4xx: int + http_5xx: int + early_termination: int + invalid_response: int + + +class SSEEvent(TypedDict): + """Server-Sent Event structure""" + data: str + event: str + id: str | None + + +class WorkflowInputs(TypedDict): + """Workflow input structure""" + question: str + + +class WorkflowRequestData(TypedDict): + """Workflow request payload""" + inputs: WorkflowInputs + response_mode: Literal["streaming"] + user: str + + +class ParsedEventData(TypedDict, total=False): + """Parsed event data from SSE stream""" + event: str + task_id: str + workflow_run_id: str + data: object # For dynamic content + created_at: int + + +class LocustStats(TypedDict): + """Locust statistics structure""" + total_requests: int + total_failures: int + avg_response_time: float + min_response_time: float + max_response_time: float + + +class ReportData(TypedDict): + """JSON report structure""" + timestamp: str + duration_seconds: float + metrics: dict[str, object] # Metrics as dict for JSON serialization + locust_stats: LocustStats | None + + +@dataclass +class StreamMetrics: + """Metrics for a single stream""" + + stream_duration: float + events_count: int + bytes_received: int + ttfe: float + inter_event_times: list[float] + + +@dataclass +class MetricsSnapshot: + """Snapshot of current metrics state""" + + active_connections: int + total_connections: int + total_events: int + connection_rate: float + event_rate: float + overall_conn_rate: float + overall_event_rate: float + ttfe_avg: float + ttfe_min: float + ttfe_max: float + ttfe_p50: float + ttfe_p95: float + ttfe_samples: int + ttfe_total_samples: int # Total TTFE samples collected (not limited by window) + error_counts: ErrorCounts + stream_duration_avg: float + stream_duration_p50: float + stream_duration_p95: float + events_per_stream_avg: float + inter_event_latency_avg: float + inter_event_latency_p50: float + inter_event_latency_p95: float + + +class MetricsTracker: + def __init__(self) -> None: + self.lock = threading.Lock() + self.active_connections = 0 + self.total_connections = 0 + self.total_events = 0 + self.start_time = time.time() + + # Enhanced metrics with memory limits + self.max_samples = 10000 # Prevent unbounded growth + self.ttfe_samples: deque[float] = deque(maxlen=self.max_samples) + self.ttfe_total_count = 0 # Track total TTFE samples collected + + # For rate calculations - no maxlen to avoid artificial limits + self.connection_times: deque[float] = deque() + self.event_times: deque[float] = deque() + self.last_stats_time = time.time() + self.last_total_connections = 0 + self.last_total_events = 0 + self.stream_metrics: deque[StreamMetrics] = deque(maxlen=self.max_samples) + self.error_counts: ErrorCounts = ErrorCounts( + connection_error=0, + timeout=0, + invalid_json=0, + http_4xx=0, + http_5xx=0, + early_termination=0, + invalid_response=0, + ) + + def connection_started(self) -> None: + with self.lock: + self.active_connections += 1 + self.total_connections += 1 + self.connection_times.append(time.time()) + + def connection_ended(self) -> None: + with self.lock: + self.active_connections -= 1 + + def event_received(self) -> None: + with self.lock: + self.total_events += 1 + self.event_times.append(time.time()) + + def record_ttfe(self, ttfe_ms: float) -> None: + with self.lock: + self.ttfe_samples.append(ttfe_ms) # deque handles maxlen + self.ttfe_total_count += 1 # Increment total counter + + def record_stream_metrics(self, metrics: StreamMetrics) -> None: + with self.lock: + self.stream_metrics.append(metrics) # deque handles maxlen + + def record_error(self, error_type: ErrorType) -> None: + with self.lock: + self.error_counts[error_type] += 1 + + def get_stats(self) -> MetricsSnapshot: + with self.lock: + current_time = time.time() + time_window = 10.0 # 10 second window for rate calculation + + # Clean up old timestamps outside the window + cutoff_time = current_time - time_window + while self.connection_times and self.connection_times[0] < cutoff_time: + self.connection_times.popleft() + while self.event_times and self.event_times[0] < cutoff_time: + self.event_times.popleft() + + # Calculate rates based on actual window or elapsed time + window_duration = min(time_window, current_time - self.start_time) + if window_duration > 0: + conn_rate = len(self.connection_times) / window_duration + event_rate = len(self.event_times) / window_duration + else: + conn_rate = 0 + event_rate = 0 + + # Calculate TTFE statistics + if self.ttfe_samples: + avg_ttfe = statistics.mean(self.ttfe_samples) + min_ttfe = min(self.ttfe_samples) + max_ttfe = max(self.ttfe_samples) + p50_ttfe = statistics.median(self.ttfe_samples) + if len(self.ttfe_samples) >= 2: + quantiles = statistics.quantiles( + self.ttfe_samples, n=20, method="inclusive" + ) + p95_ttfe = quantiles[18] # 19th of 19 quantiles = 95th percentile + else: + p95_ttfe = max_ttfe + else: + avg_ttfe = min_ttfe = max_ttfe = p50_ttfe = p95_ttfe = 0 + + # Calculate stream metrics + if self.stream_metrics: + durations = [m.stream_duration for m in self.stream_metrics] + events_per_stream = [m.events_count for m in self.stream_metrics] + stream_duration_avg = statistics.mean(durations) + stream_duration_p50 = statistics.median(durations) + stream_duration_p95 = ( + statistics.quantiles(durations, n=20, method="inclusive")[18] + if len(durations) >= 2 + else max(durations) + if durations + else 0 + ) + events_per_stream_avg = ( + statistics.mean(events_per_stream) if events_per_stream else 0 + ) + + # Calculate inter-event latency statistics + all_inter_event_times = [] + for m in self.stream_metrics: + all_inter_event_times.extend(m.inter_event_times) + + if all_inter_event_times: + inter_event_latency_avg = statistics.mean(all_inter_event_times) + inter_event_latency_p50 = statistics.median(all_inter_event_times) + inter_event_latency_p95 = ( + statistics.quantiles( + all_inter_event_times, n=20, method="inclusive" + )[18] + if len(all_inter_event_times) >= 2 + else max(all_inter_event_times) + ) + else: + inter_event_latency_avg = inter_event_latency_p50 = ( + inter_event_latency_p95 + ) = 0 + else: + stream_duration_avg = stream_duration_p50 = stream_duration_p95 = ( + events_per_stream_avg + ) = 0 + inter_event_latency_avg = inter_event_latency_p50 = ( + inter_event_latency_p95 + ) = 0 + + # Also calculate overall average rates + total_elapsed = current_time - self.start_time + overall_conn_rate = ( + self.total_connections / total_elapsed if total_elapsed > 0 else 0 + ) + overall_event_rate = ( + self.total_events / total_elapsed if total_elapsed > 0 else 0 + ) + + return MetricsSnapshot( + active_connections=self.active_connections, + total_connections=self.total_connections, + total_events=self.total_events, + connection_rate=conn_rate, + event_rate=event_rate, + overall_conn_rate=overall_conn_rate, + overall_event_rate=overall_event_rate, + ttfe_avg=avg_ttfe, + ttfe_min=min_ttfe, + ttfe_max=max_ttfe, + ttfe_p50=p50_ttfe, + ttfe_p95=p95_ttfe, + ttfe_samples=len(self.ttfe_samples), + ttfe_total_samples=self.ttfe_total_count, # Return total count + error_counts=ErrorCounts(**self.error_counts), + stream_duration_avg=stream_duration_avg, + stream_duration_p50=stream_duration_p50, + stream_duration_p95=stream_duration_p95, + events_per_stream_avg=events_per_stream_avg, + inter_event_latency_avg=inter_event_latency_avg, + inter_event_latency_p50=inter_event_latency_p50, + inter_event_latency_p95=inter_event_latency_p95, + ) + + +# Global metrics instance +metrics = MetricsTracker() + + +class SSEParser: + """Parser for Server-Sent Events according to W3C spec""" + + def __init__(self) -> None: + self.data_buffer: list[str] = [] + self.event_type: str | None = None + self.event_id: str | None = None + + def parse_line(self, line: str) -> SSEEvent | None: + """Parse a single SSE line and return event if complete""" + # Empty line signals end of event + if not line: + if self.data_buffer: + event = SSEEvent( + data="\n".join(self.data_buffer), + event=self.event_type or "message", + id=self.event_id, + ) + self.data_buffer = [] + self.event_type = None + self.event_id = None + return event + return None + + # Comment line + if line.startswith(":"): + return None + + # Parse field + if ":" in line: + field, value = line.split(":", 1) + value = value.lstrip() + + if field == "data": + self.data_buffer.append(value) + elif field == "event": + self.event_type = value + elif field == "id": + self.event_id = value + + return None + + +# Note: SSEClient removed - we'll handle SSE parsing directly in the task for better Locust integration + + +class DifyWorkflowUser(HttpUser): + """Locust user for testing Dify workflow SSE endpoints""" + + # Use constant wait for streaming workloads + wait_time = constant(0) if os.getenv("WAIT_TIME", "0") == "0" else between(1, 3) + + def __init__(self, *args: object, **kwargs: object) -> None: + super().__init__(*args, **kwargs) # type: ignore[arg-type] + + # Load API configuration + config_helper = ConfigHelper() + self.api_token = config_helper.get_api_key() + + if not self.api_token: + raise ValueError("API key not found. Please run setup_all.py first.") + + # Load questions from file or use defaults + if QUESTIONS_FILE and os.path.exists(QUESTIONS_FILE): + with open(QUESTIONS_FILE, "r") as f: + self.questions = [line.strip() for line in f if line.strip()] + else: + self.questions = [ + "What is artificial intelligence?", + "Explain quantum computing", + "What is machine learning?", + "How do neural networks work?", + "What is renewable energy?", + ] + + self.user_counter = 0 + + def on_start(self) -> None: + """Called when a user starts""" + self.user_counter = 0 + + @task + def test_workflow_stream(self) -> None: + """Test workflow SSE streaming endpoint""" + + question = random.choice(self.questions) + self.user_counter += 1 + + headers = { + "Authorization": f"Bearer {self.api_token}", + "Content-Type": "application/json", + "Accept": "text/event-stream", + "Cache-Control": "no-cache", + } + + data = WorkflowRequestData( + inputs=WorkflowInputs(question=question), + response_mode="streaming", + user=f"user_{self.user_counter}", + ) + + start_time = time.time() + first_event_time = None + event_count = 0 + inter_event_times: list[float] = [] + last_event_time = None + ttfe = 0 + request_success = False + bytes_received = 0 + + metrics.connection_started() + + # Use catch_response context manager directly + with self.client.request( + method="POST", + url=WORKFLOW_PATH, + headers=headers, + json=data, + stream=True, + catch_response=True, + timeout=(CONNECT_TIMEOUT, READ_TIMEOUT), + name="/v1/workflows/run", # Name for Locust stats + ) as response: + try: + # Validate response + if response.status_code >= 400: + error_type: ErrorType = ( + "http_4xx" if response.status_code < 500 else "http_5xx" + ) + metrics.record_error(error_type) + response.failure(f"HTTP {response.status_code}") + return + + content_type = response.headers.get("Content-Type", "") + if ( + "text/event-stream" not in content_type + and "application/json" not in content_type + ): + logger.error(f"Expected text/event-stream, got: {content_type}") + metrics.record_error("invalid_response") + response.failure(f"Invalid content type: {content_type}") + return + + # Parse SSE events + parser = SSEParser() + + for line in response.iter_lines(decode_unicode=True): + # Check if runner is stopping + if getattr(self.environment.runner, 'state', '') in ('stopping', 'stopped'): + logger.debug("Runner stopping, breaking streaming loop") + break + + if line is not None: + bytes_received += len(line.encode("utf-8")) + + # Parse SSE line + event = parser.parse_line(line if line is not None else "") + if event: + event_count += 1 + current_time = time.time() + metrics.event_received() + + # Track inter-event timing + if last_event_time: + inter_event_times.append( + (current_time - last_event_time) * 1000 + ) + last_event_time = current_time + + if first_event_time is None: + first_event_time = current_time + ttfe = (first_event_time - start_time) * 1000 + metrics.record_ttfe(ttfe) + + try: + # Parse event data + event_data = event.get("data", "") + if event_data: + if event_data == "[DONE]": + logger.debug("Received [DONE] sentinel") + request_success = True + break + + try: + parsed_event: ParsedEventData = json.loads(event_data) + # Check for terminal events + if parsed_event.get("event") in TERMINAL_EVENTS: + logger.debug( + f"Received terminal event: {parsed_event.get('event')}" + ) + request_success = True + break + except json.JSONDecodeError as e: + logger.debug( + f"JSON decode error: {e} for data: {event_data[:100]}" + ) + metrics.record_error("invalid_json") + + except Exception as e: + logger.error(f"Error processing event: {e}") + + # Mark success only if terminal condition was met or events were received + if request_success: + response.success() + elif event_count > 0: + # Got events but no proper terminal condition + metrics.record_error("early_termination") + response.failure("Stream ended without terminal event") + else: + response.failure("No events received") + + except ( + requests.exceptions.ConnectTimeout, + requests.exceptions.ReadTimeout, + ) as e: + metrics.record_error("timeout") + response.failure(f"Timeout: {e}") + except ( + requests.exceptions.ConnectionError, + requests.exceptions.RequestException, + ) as e: + metrics.record_error("connection_error") + response.failure(f"Connection error: {e}") + except Exception as e: + response.failure(str(e)) + raise + finally: + metrics.connection_ended() + + # Record stream metrics + if event_count > 0: + stream_duration = (time.time() - start_time) * 1000 + stream_metrics = StreamMetrics( + stream_duration=stream_duration, + events_count=event_count, + bytes_received=bytes_received, + ttfe=ttfe, + inter_event_times=inter_event_times, + ) + metrics.record_stream_metrics(stream_metrics) + logger.debug( + f"Stream completed: {event_count} events, {stream_duration:.1f}ms, success={request_success}" + ) + else: + logger.warning("No events received in stream") + + +# Event handlers +@events.test_start.add_listener # type: ignore[misc] +def on_test_start(environment: object, **kwargs: object) -> None: + logger.info("=" * 80) + logger.info(" " * 25 + "DIFY SSE BENCHMARK - REAL-TIME METRICS") + logger.info("=" * 80) + logger.info(f"Started at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") + logger.info("=" * 80) + + # Periodic stats reporting + def report_stats() -> None: + if not hasattr(environment, 'runner'): + return + runner = environment.runner + while hasattr(runner, 'state') and runner.state not in ["stopped", "stopping"]: + time.sleep(5) # Report every 5 seconds + if hasattr(runner, 'state') and runner.state == "running": + stats = metrics.get_stats() + + # Only log on master node in distributed mode + is_master = not getattr(environment.runner, "worker_id", None) if hasattr(environment, 'runner') else True + if is_master: + # Clear previous lines and show updated stats + logger.info("\n" + "=" * 80) + logger.info( + f"{'METRIC':<25} {'CURRENT':>15} {'RATE (10s)':>15} {'AVG (overall)':>15} {'TOTAL':>12}" + ) + logger.info("-" * 80) + + # Active SSE Connections + logger.info( + f"{'Active SSE Connections':<25} {stats.active_connections:>15,d} {'-':>15} {'-':>12} {'-':>12}" + ) + + # New Connection Rate + logger.info( + f"{'New Connections':<25} {'-':>15} {stats.connection_rate:>13.2f}/s {stats.overall_conn_rate:>13.2f}/s {stats.total_connections:>12,d}" + ) + + # Event Throughput + logger.info( + f"{'Event Throughput':<25} {'-':>15} {stats.event_rate:>13.2f}/s {stats.overall_event_rate:>13.2f}/s {stats.total_events:>12,d}" + ) + + logger.info("-" * 80) + logger.info( + f"{'TIME TO FIRST EVENT':<25} {'AVG':>15} {'P50':>10} {'P95':>10} {'MIN':>10} {'MAX':>10}" + ) + logger.info( + f"{'(TTFE in ms)':<25} {stats.ttfe_avg:>15.1f} {stats.ttfe_p50:>10.1f} {stats.ttfe_p95:>10.1f} {stats.ttfe_min:>10.1f} {stats.ttfe_max:>10.1f}" + ) + logger.info(f"{'Window Samples':<25} {stats.ttfe_samples:>15,d} (last {min(10000, stats.ttfe_total_samples):,d} samples)") + logger.info(f"{'Total Samples':<25} {stats.ttfe_total_samples:>15,d}") + + # Inter-event latency + if stats.inter_event_latency_avg > 0: + logger.info("-" * 80) + logger.info( + f"{'INTER-EVENT LATENCY':<25} {'AVG':>15} {'P50':>10} {'P95':>10}" + ) + logger.info( + f"{'(ms between events)':<25} {stats.inter_event_latency_avg:>15.1f} {stats.inter_event_latency_p50:>10.1f} {stats.inter_event_latency_p95:>10.1f}" + ) + + # Error stats + if any(stats.error_counts.values()): + logger.info("-" * 80) + logger.info(f"{'ERROR TYPE':<25} {'COUNT':>15}") + for error_type, count in stats.error_counts.items(): + if isinstance(count, int) and count > 0: + logger.info(f"{error_type:<25} {count:>15,d}") + + logger.info("=" * 80) + + # Show Locust stats summary + if hasattr(environment, 'stats') and hasattr(environment.stats, 'total'): + total = environment.stats.total + if hasattr(total, 'num_requests') and total.num_requests > 0: + logger.info( + f"{'LOCUST STATS':<25} {'Requests':>12} {'Fails':>8} {'Avg (ms)':>12} {'Min':>8} {'Max':>8}" + ) + logger.info("-" * 80) + logger.info( + f"{'Aggregated':<25} {total.num_requests:>12,d} " + f"{total.num_failures:>8,d} " + f"{total.avg_response_time:>12.1f} " + f"{total.min_response_time:>8.0f} " + f"{total.max_response_time:>8.0f}" + ) + logger.info("=" * 80) + + threading.Thread(target=report_stats, daemon=True).start() + + +@events.test_stop.add_listener # type: ignore[misc] +def on_test_stop(environment: object, **kwargs: object) -> None: + stats = metrics.get_stats() + test_duration = time.time() - metrics.start_time + + # Log final results + logger.info("\n" + "=" * 80) + logger.info(" " * 30 + "FINAL BENCHMARK RESULTS") + logger.info("=" * 80) + logger.info(f"Test Duration: {test_duration:.1f} seconds") + logger.info("-" * 80) + + logger.info("") + logger.info("CONNECTIONS") + logger.info(f" {'Total Connections:':<30} {stats.total_connections:>10,d}") + logger.info(f" {'Final Active:':<30} {stats.active_connections:>10,d}") + logger.info(f" {'Average Rate:':<30} {stats.overall_conn_rate:>10.2f} conn/s") + + logger.info("") + logger.info("EVENTS") + logger.info(f" {'Total Events Received:':<30} {stats.total_events:>10,d}") + logger.info( + f" {'Average Throughput:':<30} {stats.overall_event_rate:>10.2f} events/s" + ) + logger.info( + f" {'Final Rate (10s window):':<30} {stats.event_rate:>10.2f} events/s" + ) + + logger.info("") + logger.info("STREAM METRICS") + logger.info(f" {'Avg Stream Duration:':<30} {stats.stream_duration_avg:>10.1f} ms") + logger.info(f" {'P50 Stream Duration:':<30} {stats.stream_duration_p50:>10.1f} ms") + logger.info(f" {'P95 Stream Duration:':<30} {stats.stream_duration_p95:>10.1f} ms") + logger.info( + f" {'Avg Events per Stream:':<30} {stats.events_per_stream_avg:>10.1f}" + ) + + logger.info("") + logger.info("INTER-EVENT LATENCY") + logger.info(f" {'Average:':<30} {stats.inter_event_latency_avg:>10.1f} ms") + logger.info(f" {'Median (P50):':<30} {stats.inter_event_latency_p50:>10.1f} ms") + logger.info(f" {'95th Percentile:':<30} {stats.inter_event_latency_p95:>10.1f} ms") + + logger.info("") + logger.info("TIME TO FIRST EVENT (ms)") + logger.info(f" {'Average:':<30} {stats.ttfe_avg:>10.1f} ms") + logger.info(f" {'Median (P50):':<30} {stats.ttfe_p50:>10.1f} ms") + logger.info(f" {'95th Percentile:':<30} {stats.ttfe_p95:>10.1f} ms") + logger.info(f" {'Minimum:':<30} {stats.ttfe_min:>10.1f} ms") + logger.info(f" {'Maximum:':<30} {stats.ttfe_max:>10.1f} ms") + logger.info(f" {'Window Samples:':<30} {stats.ttfe_samples:>10,d} (last {min(10000, stats.ttfe_total_samples):,d})") + logger.info(f" {'Total Samples:':<30} {stats.ttfe_total_samples:>10,d}") + + # Error summary + if any(stats.error_counts.values()): + logger.info("") + logger.info("ERRORS") + for error_type, count in stats.error_counts.items(): + if isinstance(count, int) and count > 0: + logger.info(f" {error_type:<30} {count:>10,d}") + + logger.info("=" * 80 + "\n") + + # Export machine-readable report (only on master node) + is_master = not getattr(environment.runner, 'worker_id', None) if hasattr(environment, 'runner') else True + if is_master: + export_json_report(stats, test_duration, environment) + + +def export_json_report(stats: MetricsSnapshot, duration: float, environment: object) -> None: + """Export metrics to JSON file for CI/CD analysis""" + + reports_dir = Path(__file__).parent / "reports" + reports_dir.mkdir(exist_ok=True) + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + report_file = reports_dir / f"sse_metrics_{timestamp}.json" + + # Access environment.stats.total attributes safely + locust_stats: LocustStats | None = None + if hasattr(environment, 'stats') and hasattr(environment.stats, 'total'): + total = environment.stats.total + if hasattr(total, 'num_requests') and total.num_requests > 0: + locust_stats = LocustStats( + total_requests=total.num_requests, + total_failures=total.num_failures, + avg_response_time=total.avg_response_time, + min_response_time=total.min_response_time, + max_response_time=total.max_response_time, + ) + + report_data = ReportData( + timestamp=datetime.now().isoformat(), + duration_seconds=duration, + metrics=asdict(stats), # type: ignore[arg-type] + locust_stats=locust_stats, + ) + + with open(report_file, "w") as f: + json.dump(report_data, f, indent=2) + + logger.info(f"Exported metrics to {report_file}") diff --git a/sdks/nodejs-client/index.js b/sdks/nodejs-client/index.js index 0ba7bba8bb..3025cc2ab6 100644 --- a/sdks/nodejs-client/index.js +++ b/sdks/nodejs-client/index.js @@ -95,10 +95,9 @@ export class DifyClient { headerParams = {} ) { const headers = { - ...{ + Authorization: `Bearer ${this.apiKey}`, "Content-Type": "application/json", - }, ...headerParams }; diff --git a/sdks/python-client/dify_client/client.py b/sdks/python-client/dify_client/client.py index abd0e7ae29..f113d4f0a8 100644 --- a/sdks/python-client/dify_client/client.py +++ b/sdks/python-client/dify_client/client.py @@ -1,5 +1,5 @@ import json - +from typing import Literal import requests @@ -8,7 +8,7 @@ class DifyClient: self.api_key = api_key self.base_url = base_url - def _send_request(self, method, endpoint, json=None, params=None, stream=False): + def _send_request(self, method: str, endpoint: str, json: dict | None = None, params: dict | None = None, stream: bool = False): headers = { "Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json", @@ -31,15 +31,15 @@ class DifyClient: return response - def message_feedback(self, message_id, rating, user): + def message_feedback(self, message_id: str, rating: Literal["like", "dislike"], user: str): data = {"rating": rating, "user": user} return self._send_request("POST", f"/messages/{message_id}/feedbacks", data) - def get_application_parameters(self, user): + def get_application_parameters(self, user: str): params = {"user": user} return self._send_request("GET", "/parameters", params=params) - def file_upload(self, user, files): + def file_upload(self, user: str, files: dict): data = {"user": user} return self._send_request_with_files( "POST", "/files/upload", data=data, files=files @@ -49,13 +49,13 @@ class DifyClient: data = {"text": text, "user": user, "streaming": streaming} return self._send_request("POST", "/text-to-audio", json=data) - def get_meta(self, user): + def get_meta(self, user: str): params = {"user": user} return self._send_request("GET", "/meta", params=params) class CompletionClient(DifyClient): - def create_completion_message(self, inputs, response_mode, user, files=None): + def create_completion_message(self, inputs: dict, response_mode: Literal["blocking", "streaming"], user: str, files: dict | None = None): data = { "inputs": inputs, "response_mode": response_mode, @@ -76,7 +76,7 @@ class ChatClient(DifyClient): inputs: dict, query: str, user: str, - response_mode: str = "blocking", + response_mode: Literal["blocking", "streaming"] = "blocking", conversation_id: str | None = None, files: dict | None = None, ): @@ -156,7 +156,7 @@ class ChatClient(DifyClient): class WorkflowClient(DifyClient): def run( - self, inputs: dict, response_mode: str = "streaming", user: str = "abc-123" + self, inputs: dict, response_mode: Literal["blocking", "streaming"] = "streaming", user: str = "abc-123" ): data = {"inputs": inputs, "response_mode": response_mode, "user": user} return self._send_request("POST", "/workflows/run", data) @@ -172,7 +172,7 @@ class WorkflowClient(DifyClient): class KnowledgeBaseClient(DifyClient): def __init__( self, - api_key, + api_key: str, base_url: str = "https://api.dify.ai/v1", dataset_id: str | None = None, ): @@ -241,7 +241,7 @@ class KnowledgeBaseClient(DifyClient): return self._send_request("POST", url, json=data, **kwargs) def update_document_by_text( - self, document_id, name, text, extra_params: dict | None = None, **kwargs + self, document_id: str, name: str, text: str, extra_params: dict | None = None, **kwargs ): """ Update a document by text. @@ -278,7 +278,7 @@ class KnowledgeBaseClient(DifyClient): return self._send_request("POST", url, json=data, **kwargs) def create_document_by_file( - self, file_path, original_document_id=None, extra_params: dict | None = None + self, file_path: str, original_document_id: str | None = None, extra_params: dict | None = None ): """ Create a document by file. @@ -320,7 +320,7 @@ class KnowledgeBaseClient(DifyClient): ) def update_document_by_file( - self, document_id, file_path, extra_params: dict | None = None + self, document_id: str, file_path: str, extra_params: dict | None = None ): """ Update a document by file. @@ -377,7 +377,7 @@ class KnowledgeBaseClient(DifyClient): url = f"/datasets/{self._get_dataset_id()}" return self._send_request("DELETE", url) - def delete_document(self, document_id): + def delete_document(self, document_id: str): """ Delete a document. @@ -409,7 +409,7 @@ class KnowledgeBaseClient(DifyClient): url = f"/datasets/{self._get_dataset_id()}/documents" return self._send_request("GET", url, params=params, **kwargs) - def add_segments(self, document_id, segments, **kwargs): + def add_segments(self, document_id: str, segments: list[dict], **kwargs): """ Add segments to a document. @@ -423,7 +423,7 @@ class KnowledgeBaseClient(DifyClient): def query_segments( self, - document_id, + document_id: str, keyword: str | None = None, status: str | None = None, **kwargs, @@ -445,7 +445,7 @@ class KnowledgeBaseClient(DifyClient): params.update(kwargs["params"]) return self._send_request("GET", url, params=params, **kwargs) - def delete_document_segment(self, document_id, segment_id): + def delete_document_segment(self, document_id: str, segment_id: str): """ Delete a segment from a document. @@ -456,7 +456,7 @@ class KnowledgeBaseClient(DifyClient): url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/segments/{segment_id}" return self._send_request("DELETE", url) - def update_document_segment(self, document_id, segment_id, segment_data, **kwargs): + def update_document_segment(self, document_id: str, segment_id: str, segment_data: dict, **kwargs): """ Update a segment in a document. diff --git a/web/.oxlintrc.json b/web/.oxlintrc.json index 1bfcca58f5..57eddd34fb 100644 --- a/web/.oxlintrc.json +++ b/web/.oxlintrc.json @@ -45,7 +45,7 @@ "no-unassigned-vars": "warn", "no-unsafe-finally": "warn", "no-unsafe-negation": "warn", - "no-unsafe-optional-chaining": "warn", + "no-unsafe-optional-chaining": "error", "no-unused-labels": "warn", "no-unused-private-class-members": "warn", "no-unused-vars": "warn", diff --git a/web/__tests__/goto-anything/match-action.test.ts b/web/__tests__/goto-anything/match-action.test.ts new file mode 100644 index 0000000000..3df9c0d533 --- /dev/null +++ b/web/__tests__/goto-anything/match-action.test.ts @@ -0,0 +1,235 @@ +import type { ActionItem } from '../../app/components/goto-anything/actions/types' + +// Mock the entire actions module to avoid import issues +jest.mock('../../app/components/goto-anything/actions', () => ({ + matchAction: jest.fn(), +})) + +jest.mock('../../app/components/goto-anything/actions/commands/registry') + +// Import after mocking to get mocked version +import { matchAction } from '../../app/components/goto-anything/actions' +import { slashCommandRegistry } from '../../app/components/goto-anything/actions/commands/registry' + +// Implement the actual matchAction logic for testing +const actualMatchAction = (query: string, actions: Record) => { + const result = Object.values(actions).find((action) => { + // Special handling for slash commands + if (action.key === '/') { + // Get all registered commands from the registry + const allCommands = slashCommandRegistry.getAllCommands() + + // Check if query matches any registered command + return allCommands.some((cmd) => { + const cmdPattern = `/${cmd.name}` + + // For direct mode commands, don't match (keep in command selector) + if (cmd.mode === 'direct') + return false + + // For submenu mode commands, match when complete command is entered + return query === cmdPattern || query.startsWith(`${cmdPattern} `) + }) + } + + const reg = new RegExp(`^(${action.key}|${action.shortcut})(?:\\s|$)`) + return reg.test(query) + }) + return result +} + +// Replace mock with actual implementation +;(matchAction as jest.Mock).mockImplementation(actualMatchAction) + +describe('matchAction Logic', () => { + const mockActions: Record = { + app: { + key: '@app', + shortcut: '@a', + title: 'Search Applications', + description: 'Search apps', + search: jest.fn(), + }, + knowledge: { + key: '@knowledge', + shortcut: '@kb', + title: 'Search Knowledge', + description: 'Search knowledge bases', + search: jest.fn(), + }, + slash: { + key: '/', + shortcut: '/', + title: 'Commands', + description: 'Execute commands', + search: jest.fn(), + }, + } + + beforeEach(() => { + jest.clearAllMocks() + ;(slashCommandRegistry.getAllCommands as jest.Mock).mockReturnValue([ + { name: 'docs', mode: 'direct' }, + { name: 'community', mode: 'direct' }, + { name: 'feedback', mode: 'direct' }, + { name: 'account', mode: 'direct' }, + { name: 'theme', mode: 'submenu' }, + { name: 'language', mode: 'submenu' }, + ]) + }) + + describe('@ Actions Matching', () => { + it('should match @app with key', () => { + const result = matchAction('@app', mockActions) + expect(result).toBe(mockActions.app) + }) + + it('should match @app with shortcut', () => { + const result = matchAction('@a', mockActions) + expect(result).toBe(mockActions.app) + }) + + it('should match @knowledge with key', () => { + const result = matchAction('@knowledge', mockActions) + expect(result).toBe(mockActions.knowledge) + }) + + it('should match @knowledge with shortcut @kb', () => { + const result = matchAction('@kb', mockActions) + expect(result).toBe(mockActions.knowledge) + }) + + it('should match with text after action', () => { + const result = matchAction('@app search term', mockActions) + expect(result).toBe(mockActions.app) + }) + + it('should not match partial @ actions', () => { + const result = matchAction('@ap', mockActions) + expect(result).toBeUndefined() + }) + }) + + describe('Slash Commands Matching', () => { + describe('Direct Mode Commands', () => { + it('should not match direct mode commands', () => { + const result = matchAction('/docs', mockActions) + expect(result).toBeUndefined() + }) + + it('should not match direct mode with arguments', () => { + const result = matchAction('/docs something', mockActions) + expect(result).toBeUndefined() + }) + + it('should not match any direct mode command', () => { + expect(matchAction('/community', mockActions)).toBeUndefined() + expect(matchAction('/feedback', mockActions)).toBeUndefined() + expect(matchAction('/account', mockActions)).toBeUndefined() + }) + }) + + describe('Submenu Mode Commands', () => { + it('should match submenu mode commands exactly', () => { + const result = matchAction('/theme', mockActions) + expect(result).toBe(mockActions.slash) + }) + + it('should match submenu mode with arguments', () => { + const result = matchAction('/theme dark', mockActions) + expect(result).toBe(mockActions.slash) + }) + + it('should match all submenu commands', () => { + expect(matchAction('/language', mockActions)).toBe(mockActions.slash) + expect(matchAction('/language en', mockActions)).toBe(mockActions.slash) + }) + }) + + describe('Slash Without Command', () => { + it('should not match single slash', () => { + const result = matchAction('/', mockActions) + expect(result).toBeUndefined() + }) + + it('should not match unregistered commands', () => { + const result = matchAction('/unknown', mockActions) + expect(result).toBeUndefined() + }) + }) + }) + + describe('Edge Cases', () => { + it('should handle empty query', () => { + const result = matchAction('', mockActions) + expect(result).toBeUndefined() + }) + + it('should handle whitespace only', () => { + const result = matchAction(' ', mockActions) + expect(result).toBeUndefined() + }) + + it('should handle regular text without actions', () => { + const result = matchAction('search something', mockActions) + expect(result).toBeUndefined() + }) + + it('should handle special characters', () => { + const result = matchAction('#tag', mockActions) + expect(result).toBeUndefined() + }) + + it('should handle multiple @ or /', () => { + expect(matchAction('@@app', mockActions)).toBeUndefined() + expect(matchAction('//theme', mockActions)).toBeUndefined() + }) + }) + + describe('Mode-based Filtering', () => { + it('should filter direct mode commands from matching', () => { + ;(slashCommandRegistry.getAllCommands as jest.Mock).mockReturnValue([ + { name: 'test', mode: 'direct' }, + ]) + + const result = matchAction('/test', mockActions) + expect(result).toBeUndefined() + }) + + it('should allow submenu mode commands to match', () => { + ;(slashCommandRegistry.getAllCommands as jest.Mock).mockReturnValue([ + { name: 'test', mode: 'submenu' }, + ]) + + const result = matchAction('/test', mockActions) + expect(result).toBe(mockActions.slash) + }) + + it('should treat undefined mode as submenu', () => { + ;(slashCommandRegistry.getAllCommands as jest.Mock).mockReturnValue([ + { name: 'test' }, // No mode specified + ]) + + const result = matchAction('/test', mockActions) + expect(result).toBe(mockActions.slash) + }) + }) + + describe('Registry Integration', () => { + it('should call getAllCommands when matching slash', () => { + matchAction('/theme', mockActions) + expect(slashCommandRegistry.getAllCommands).toHaveBeenCalled() + }) + + it('should not call getAllCommands for @ actions', () => { + matchAction('@app', mockActions) + expect(slashCommandRegistry.getAllCommands).not.toHaveBeenCalled() + }) + + it('should handle empty command list', () => { + ;(slashCommandRegistry.getAllCommands as jest.Mock).mockReturnValue([]) + const result = matchAction('/anything', mockActions) + expect(result).toBeUndefined() + }) + }) +}) diff --git a/web/__tests__/goto-anything/scope-command-tags.test.tsx b/web/__tests__/goto-anything/scope-command-tags.test.tsx new file mode 100644 index 0000000000..339e259a06 --- /dev/null +++ b/web/__tests__/goto-anything/scope-command-tags.test.tsx @@ -0,0 +1,134 @@ +import React from 'react' +import { render, screen } from '@testing-library/react' +import '@testing-library/jest-dom' + +// Type alias for search mode +type SearchMode = 'scopes' | 'commands' | null + +// Mock component to test tag display logic +const TagDisplay: React.FC<{ searchMode: SearchMode }> = ({ searchMode }) => { + if (!searchMode) return null + + return ( +
+ {searchMode === 'scopes' ? 'SCOPES' : 'COMMANDS'} +
+ ) +} + +describe('Scope and Command Tags', () => { + describe('Tag Display Logic', () => { + it('should display SCOPES for @ actions', () => { + render() + expect(screen.getByText('SCOPES')).toBeInTheDocument() + expect(screen.queryByText('COMMANDS')).not.toBeInTheDocument() + }) + + it('should display COMMANDS for / actions', () => { + render() + expect(screen.getByText('COMMANDS')).toBeInTheDocument() + expect(screen.queryByText('SCOPES')).not.toBeInTheDocument() + }) + + it('should not display any tag when searchMode is null', () => { + const { container } = render() + expect(container.firstChild).toBeNull() + }) + }) + + describe('Search Mode Detection', () => { + const getSearchMode = (query: string): SearchMode => { + if (query.startsWith('@')) return 'scopes' + if (query.startsWith('/')) return 'commands' + return null + } + + it('should detect scopes mode for @ queries', () => { + expect(getSearchMode('@app')).toBe('scopes') + expect(getSearchMode('@knowledge')).toBe('scopes') + expect(getSearchMode('@plugin')).toBe('scopes') + expect(getSearchMode('@node')).toBe('scopes') + }) + + it('should detect commands mode for / queries', () => { + expect(getSearchMode('/theme')).toBe('commands') + expect(getSearchMode('/language')).toBe('commands') + expect(getSearchMode('/docs')).toBe('commands') + }) + + it('should return null for regular queries', () => { + expect(getSearchMode('')).toBe(null) + expect(getSearchMode('search term')).toBe(null) + expect(getSearchMode('app')).toBe(null) + }) + + it('should handle queries with spaces', () => { + expect(getSearchMode('@app search')).toBe('scopes') + expect(getSearchMode('/theme dark')).toBe('commands') + }) + }) + + describe('Tag Styling', () => { + it('should apply correct styling classes', () => { + const { container } = render() + const tagContainer = container.querySelector('.flex.items-center.gap-1.text-xs.text-text-tertiary') + expect(tagContainer).toBeInTheDocument() + }) + + it('should use hardcoded English text', () => { + // Verify that tags are hardcoded and not using i18n + render() + const scopesText = screen.getByText('SCOPES') + expect(scopesText.textContent).toBe('SCOPES') + + render() + const commandsText = screen.getByText('COMMANDS') + expect(commandsText.textContent).toBe('COMMANDS') + }) + }) + + describe('Integration with Search States', () => { + const SearchComponent: React.FC<{ query: string }> = ({ query }) => { + let searchMode: SearchMode = null + + if (query.startsWith('@')) searchMode = 'scopes' + else if (query.startsWith('/')) searchMode = 'commands' + + return ( +
+ + +
+ ) + } + + it('should update tag when switching between @ and /', () => { + const { rerender } = render() + expect(screen.getByText('SCOPES')).toBeInTheDocument() + + rerender() + expect(screen.queryByText('SCOPES')).not.toBeInTheDocument() + expect(screen.getByText('COMMANDS')).toBeInTheDocument() + }) + + it('should hide tag when clearing search', () => { + const { rerender } = render() + expect(screen.getByText('SCOPES')).toBeInTheDocument() + + rerender() + expect(screen.queryByText('SCOPES')).not.toBeInTheDocument() + expect(screen.queryByText('COMMANDS')).not.toBeInTheDocument() + }) + + it('should maintain correct tag during search refinement', () => { + const { rerender } = render() + expect(screen.getByText('SCOPES')).toBeInTheDocument() + + rerender() + expect(screen.getByText('SCOPES')).toBeInTheDocument() + + rerender() + expect(screen.getByText('SCOPES')).toBeInTheDocument() + }) + }) +}) diff --git a/web/__tests__/goto-anything/slash-command-modes.test.tsx b/web/__tests__/goto-anything/slash-command-modes.test.tsx new file mode 100644 index 0000000000..f8126958fc --- /dev/null +++ b/web/__tests__/goto-anything/slash-command-modes.test.tsx @@ -0,0 +1,212 @@ +import '@testing-library/jest-dom' +import { slashCommandRegistry } from '../../app/components/goto-anything/actions/commands/registry' +import type { SlashCommandHandler } from '../../app/components/goto-anything/actions/commands/types' + +// Mock the registry +jest.mock('../../app/components/goto-anything/actions/commands/registry') + +describe('Slash Command Dual-Mode System', () => { + const mockDirectCommand: SlashCommandHandler = { + name: 'docs', + description: 'Open documentation', + mode: 'direct', + execute: jest.fn(), + search: jest.fn().mockResolvedValue([ + { + id: 'docs', + title: 'Documentation', + description: 'Open documentation', + type: 'command' as const, + data: { command: 'navigation.docs', args: {} }, + }, + ]), + register: jest.fn(), + unregister: jest.fn(), + } + + const mockSubmenuCommand: SlashCommandHandler = { + name: 'theme', + description: 'Change theme', + mode: 'submenu', + search: jest.fn().mockResolvedValue([ + { + id: 'theme-light', + title: 'Light Theme', + description: 'Switch to light theme', + type: 'command' as const, + data: { command: 'theme.set', args: { theme: 'light' } }, + }, + { + id: 'theme-dark', + title: 'Dark Theme', + description: 'Switch to dark theme', + type: 'command' as const, + data: { command: 'theme.set', args: { theme: 'dark' } }, + }, + ]), + register: jest.fn(), + unregister: jest.fn(), + } + + beforeEach(() => { + jest.clearAllMocks() + ;(slashCommandRegistry as any).findCommand = jest.fn((name: string) => { + if (name === 'docs') return mockDirectCommand + if (name === 'theme') return mockSubmenuCommand + return null + }) + ;(slashCommandRegistry as any).getAllCommands = jest.fn(() => [ + mockDirectCommand, + mockSubmenuCommand, + ]) + }) + + describe('Direct Mode Commands', () => { + it('should execute immediately when selected', () => { + const mockSetShow = jest.fn() + const mockSetSearchQuery = jest.fn() + + // Simulate command selection + const handler = slashCommandRegistry.findCommand('docs') + expect(handler?.mode).toBe('direct') + + if (handler?.mode === 'direct' && handler.execute) { + handler.execute() + mockSetShow(false) + mockSetSearchQuery('') + } + + expect(mockDirectCommand.execute).toHaveBeenCalled() + expect(mockSetShow).toHaveBeenCalledWith(false) + expect(mockSetSearchQuery).toHaveBeenCalledWith('') + }) + + it('should not enter submenu for direct mode commands', () => { + const handler = slashCommandRegistry.findCommand('docs') + expect(handler?.mode).toBe('direct') + expect(handler?.execute).toBeDefined() + }) + + it('should close modal after execution', () => { + const mockModalClose = jest.fn() + + const handler = slashCommandRegistry.findCommand('docs') + if (handler?.mode === 'direct' && handler.execute) { + handler.execute() + mockModalClose() + } + + expect(mockModalClose).toHaveBeenCalled() + }) + }) + + describe('Submenu Mode Commands', () => { + it('should show options instead of executing immediately', async () => { + const handler = slashCommandRegistry.findCommand('theme') + expect(handler?.mode).toBe('submenu') + + const results = await handler?.search('', 'en') + expect(results).toHaveLength(2) + expect(results?.[0].title).toBe('Light Theme') + expect(results?.[1].title).toBe('Dark Theme') + }) + + it('should not have execute function for submenu mode', () => { + const handler = slashCommandRegistry.findCommand('theme') + expect(handler?.mode).toBe('submenu') + expect(handler?.execute).toBeUndefined() + }) + + it('should keep modal open for selection', () => { + const mockModalClose = jest.fn() + + const handler = slashCommandRegistry.findCommand('theme') + // For submenu mode, modal should not close immediately + expect(handler?.mode).toBe('submenu') + expect(mockModalClose).not.toHaveBeenCalled() + }) + }) + + describe('Mode Detection and Routing', () => { + it('should correctly identify direct mode commands', () => { + const commands = slashCommandRegistry.getAllCommands() + const directCommands = commands.filter(cmd => cmd.mode === 'direct') + const submenuCommands = commands.filter(cmd => cmd.mode === 'submenu') + + expect(directCommands).toContainEqual(expect.objectContaining({ name: 'docs' })) + expect(submenuCommands).toContainEqual(expect.objectContaining({ name: 'theme' })) + }) + + it('should handle missing mode property gracefully', () => { + const commandWithoutMode: SlashCommandHandler = { + name: 'test', + description: 'Test command', + search: jest.fn(), + register: jest.fn(), + unregister: jest.fn(), + } + + ;(slashCommandRegistry as any).findCommand = jest.fn(() => commandWithoutMode) + + const handler = slashCommandRegistry.findCommand('test') + // Default behavior should be submenu when mode is not specified + expect(handler?.mode).toBeUndefined() + expect(handler?.execute).toBeUndefined() + }) + }) + + describe('Enter Key Handling', () => { + // Helper function to simulate key handler behavior + const createKeyHandler = () => { + return (commandKey: string) => { + if (commandKey.startsWith('/')) { + const commandName = commandKey.substring(1) + const handler = slashCommandRegistry.findCommand(commandName) + if (handler?.mode === 'direct' && handler.execute) { + handler.execute() + return true // Indicates handled + } + } + return false + } + } + + it('should trigger direct execution on Enter for direct mode', () => { + const keyHandler = createKeyHandler() + const handled = keyHandler('/docs') + expect(handled).toBe(true) + expect(mockDirectCommand.execute).toHaveBeenCalled() + }) + + it('should not trigger direct execution for submenu mode', () => { + const keyHandler = createKeyHandler() + const handled = keyHandler('/theme') + expect(handled).toBe(false) + expect(mockSubmenuCommand.search).not.toHaveBeenCalled() + }) + }) + + describe('Command Registration', () => { + it('should register both direct and submenu commands', () => { + mockDirectCommand.register?.({}) + mockSubmenuCommand.register?.({ setTheme: jest.fn() }) + + expect(mockDirectCommand.register).toHaveBeenCalled() + expect(mockSubmenuCommand.register).toHaveBeenCalled() + }) + + it('should handle unregistration for both command types', () => { + // Test unregister for direct command + mockDirectCommand.unregister?.() + expect(mockDirectCommand.unregister).toHaveBeenCalled() + + // Test unregister for submenu command + mockSubmenuCommand.unregister?.() + expect(mockSubmenuCommand.unregister).toHaveBeenCalled() + + // Verify both were called independently + expect(mockDirectCommand.unregister).toHaveBeenCalledTimes(1) + expect(mockSubmenuCommand.unregister).toHaveBeenCalledTimes(1) + }) + }) +}) diff --git a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx index f8189b0c8a..6d72e957e3 100644 --- a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx +++ b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx @@ -129,7 +129,7 @@ const DatasetDetailLayout: FC = (props) => { params: { datasetId }, } = props const pathname = usePathname() - const hideSideBar = /documents\/create$/.test(pathname) + const hideSideBar = pathname.endsWith('documents/create') const { t } = useTranslation() const { isCurrentWorkspaceDatasetOperator } = useAppContext() diff --git a/web/app/(commonLayout)/datasets/template/template.en.mdx b/web/app/(commonLayout)/datasets/template/template.en.mdx index 0d41691dfd..ccbc73aef0 100644 --- a/web/app/(commonLayout)/datasets/template/template.en.mdx +++ b/web/app/(commonLayout)/datasets/template/template.en.mdx @@ -1949,57 +1949,6 @@ ___
- - - - ### Path - - - Knowledge ID - - - Document ID - - - - - - ```bash {{ title: 'cURL' }} - curl --location --request GET '${props.apiBaseUrl}/datasets/{dataset_id}/documents/{document_id}/upload-file' \ - --header 'Authorization: Bearer {api_key}' \ - --header 'Content-Type: application/json' - ``` - - - ```json {{ title: 'Response' }} - { - "id": "file_id", - "name": "file_name", - "size": 1024, - "extension": "txt", - "url": "preview_url", - "download_url": "download_url", - "mime_type": "text/plain", - "created_by": "user_id", - "created_at": 1728734540, - } - ``` - - - - -
- - - - - ### パス - - - ナレッジ ID - - - ドキュメント ID - - - - - - ```bash {{ title: 'cURL' }} - curl --location --request GET '${props.apiBaseUrl}/datasets/{dataset_id}/documents/{document_id}/upload-file' \ - --header 'Authorization: Bearer {api_key}' \ - --header 'Content-Type: application/json' - ``` - - - ```json {{ title: 'Response' }} - { - "id": "file_id", - "name": "file_name", - "size": 1024, - "extension": "txt", - "url": "preview_url", - "download_url": "download_url", - "mime_type": "text/plain", - "created_by": "user_id", - "created_at": 1728734540, - } - ``` - - - - -
-
- diff --git a/web/app/(commonLayout)/datasets/template/template.zh.mdx b/web/app/(commonLayout)/datasets/template/template.zh.mdx index b7ea889a46..1971d9ff84 100644 --- a/web/app/(commonLayout)/datasets/template/template.zh.mdx +++ b/web/app/(commonLayout)/datasets/template/template.zh.mdx @@ -1991,57 +1991,6 @@ ___
- - - - ### Path - - - 知识库 ID - - - 文档 ID - - - - - - ```bash {{ title: 'cURL' }} - curl --location --request GET '${props.apiBaseUrl}/datasets/{dataset_id}/documents/{document_id}/upload-file' \ - --header 'Authorization: Bearer {api_key}' \ - --header 'Content-Type: application/json' - ``` - - - ```json {{ title: 'Response' }} - { - "id": "file_id", - "name": "file_name", - "size": 1024, - "extension": "txt", - "url": "preview_url", - "download_url": "download_url", - "mime_type": "text/plain", - "created_by": "user_id", - "created_at": 1728734540, - } - ``` - - - - -
- setEmail(e.target.value)} />
- +
diff --git a/web/app/account/(commonLayout)/account-page/AvatarWithEdit.tsx b/web/app/account/(commonLayout)/account-page/AvatarWithEdit.tsx index 5890c2ea92..f3dbc9421c 100644 --- a/web/app/account/(commonLayout)/account-page/AvatarWithEdit.tsx +++ b/web/app/account/(commonLayout)/account-page/AvatarWithEdit.tsx @@ -43,9 +43,9 @@ const AvatarWithEdit = ({ onSave, ...props }: AvatarWithEditProps) => { const handleSaveAvatar = useCallback(async (uploadedFileId: string) => { try { await updateUserProfile({ url: 'account/avatar', body: { avatar: uploadedFileId } }) - notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') }) setIsShowAvatarPicker(false) onSave?.() + notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') }) } catch (e) { notify({ type: 'error', message: (e as Error).message }) diff --git a/web/app/components/app-sidebar/app-info.tsx b/web/app/components/app-sidebar/app-info.tsx index 2037647b99..dc13d59f2b 100644 --- a/web/app/components/app-sidebar/app-info.tsx +++ b/web/app/components/app-sidebar/app-info.tsx @@ -319,7 +319,7 @@ const AppInfo = ({ expand, onlyShowDetail = false, openState = false, onDetailEx background={appDetail.icon_background} imageUrl={appDetail.icon_url} /> -
+
{appDetail.name}
{appDetail.mode === 'advanced-chat' ? t('app.types.advanced') : appDetail.mode === 'agent-chat' ? t('app.types.agent') : appDetail.mode === 'chat' ? t('app.types.chatbot') : appDetail.mode === 'completion' ? t('app.types.completion') : t('app.types.workflow')}
diff --git a/web/app/components/app/configuration/config-vision/index.tsx b/web/app/components/app/configuration/config-vision/index.tsx index f719822bf9..f0904b3fd8 100644 --- a/web/app/components/app/configuration/config-vision/index.tsx +++ b/web/app/components/app/configuration/config-vision/index.tsx @@ -45,7 +45,7 @@ const ConfigVision: FC = () => { if (draft.file) { draft.file.enabled = (draft.file.allowed_file_types?.length ?? 0) > 0 draft.file.image = { - ...(draft.file.image || {}), + ...draft.file.image, enabled: value, } } diff --git a/web/app/components/app/configuration/config/automatic/get-automatic-res.tsx b/web/app/components/app/configuration/config/automatic/get-automatic-res.tsx index e6b6c83846..5c87eea3ca 100644 --- a/web/app/components/app/configuration/config/automatic/get-automatic-res.tsx +++ b/web/app/components/app/configuration/config/automatic/get-automatic-res.tsx @@ -50,6 +50,7 @@ export type IGetAutomaticResProps = { onFinished: (res: GenRes) => void flowId?: string nodeId?: string + editorId?: string currentPrompt?: string isBasicMode?: boolean } @@ -76,6 +77,7 @@ const GetAutomaticRes: FC = ({ onClose, flowId, nodeId, + editorId, currentPrompt, isBasicMode, onFinished, @@ -132,7 +134,8 @@ const GetAutomaticRes: FC = ({ }, ] - const [instructionFromSessionStorage, setInstruction] = useSessionStorageState(`improve-instruction-${flowId}${isBasicMode ? '' : `-${nodeId}`}`) + // eslint-disable-next-line sonarjs/no-nested-template-literals, sonarjs/no-nested-conditional + const [instructionFromSessionStorage, setInstruction] = useSessionStorageState(`improve-instruction-${flowId}${isBasicMode ? '' : `-${nodeId}${editorId ? `-${editorId}` : ''}`}`) const instruction = instructionFromSessionStorage || '' const [ideaOutput, setIdeaOutput] = useState('') @@ -166,7 +169,7 @@ const GetAutomaticRes: FC = ({ return true } const [isLoading, { setTrue: setLoadingTrue, setFalse: setLoadingFalse }] = useBoolean(false) - const storageKey = `${flowId}${isBasicMode ? '' : `-${nodeId}`}` + const storageKey = `${flowId}${isBasicMode ? '' : `-${nodeId}${editorId ? `-${editorId}` : ''}`}` const { addVersion, current, currentVersionIndex, setCurrentVersionIndex, versions } = useGenData({ storageKey, }) diff --git a/web/app/components/app/create-app-modal/index.tsx b/web/app/components/app/create-app-modal/index.tsx index 70a45a4bbe..cd73874c2c 100644 --- a/web/app/components/app/create-app-modal/index.tsx +++ b/web/app/components/app/create-app-modal/index.tsx @@ -1,6 +1,6 @@ 'use client' -import { useCallback, useRef, useState } from 'react' +import { useCallback, useEffect, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' import { useRouter } from 'next/navigation' @@ -35,14 +35,15 @@ type CreateAppProps = { onSuccess: () => void onClose: () => void onCreateFromTemplate?: () => void + defaultAppMode?: AppMode } -function CreateApp({ onClose, onSuccess, onCreateFromTemplate }: CreateAppProps) { +function CreateApp({ onClose, onSuccess, onCreateFromTemplate, defaultAppMode }: CreateAppProps) { const { t } = useTranslation() const { push } = useRouter() const { notify } = useContext(ToastContext) - const [appMode, setAppMode] = useState('advanced-chat') + const [appMode, setAppMode] = useState(defaultAppMode || 'advanced-chat') const [appIcon, setAppIcon] = useState({ type: 'emoji', icon: '🤖', background: '#FFEAD5' }) const [showAppIconPicker, setShowAppIconPicker] = useState(false) const [name, setName] = useState('') @@ -55,6 +56,11 @@ function CreateApp({ onClose, onSuccess, onCreateFromTemplate }: CreateAppProps) const isCreatingRef = useRef(false) + useEffect(() => { + if (appMode === 'chat' || appMode === 'agent-chat' || appMode === 'completion') + setIsAppTypeExpanded(true) + }, [appMode]) + const onCreate = useCallback(async () => { if (!appMode) { notify({ type: 'error', message: t('app.newApp.appTypeRequired') }) @@ -264,7 +270,7 @@ function CreateApp({ onClose, onSuccess, onCreateFromTemplate }: CreateAppProps) type CreateAppDialogProps = CreateAppProps & { show: boolean } -const CreateAppModal = ({ show, onClose, onSuccess, onCreateFromTemplate }: CreateAppDialogProps) => { +const CreateAppModal = ({ show, onClose, onSuccess, onCreateFromTemplate, defaultAppMode }: CreateAppDialogProps) => { return ( - + ) } diff --git a/web/app/components/apps/app-card.tsx b/web/app/components/apps/app-card.tsx index d0d42dc32c..e9a64d8867 100644 --- a/web/app/components/apps/app-card.tsx +++ b/web/app/components/apps/app-card.tsx @@ -279,12 +279,21 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => { )} { - (isGettingUserCanAccessApp || !userCanAccessApp?.result) ? null : <> - - - + (!systemFeatures.webapp_auth.enabled) + ? <> + + + + : !(isGettingUserCanAccessApp || !userCanAccessApp?.result) && ( + <> + + + + ) } { diff --git a/web/app/components/apps/list.tsx b/web/app/components/apps/list.tsx index aa85fb1313..4ee9a6d6d5 100644 --- a/web/app/components/apps/list.tsx +++ b/web/app/components/apps/list.tsx @@ -211,14 +211,14 @@ const List = () => { {(data && data[0].total > 0) ?
{isCurrentWorkspaceEditor - && } + && } {data.map(({ data: apps }) => apps.map(app => ( )))}
:
{isCurrentWorkspaceEditor - && } + && }
} diff --git a/web/app/components/apps/new-app-card.tsx b/web/app/components/apps/new-app-card.tsx index 451d2ae326..6ceeb47982 100644 --- a/web/app/components/apps/new-app-card.tsx +++ b/web/app/components/apps/new-app-card.tsx @@ -26,12 +26,14 @@ export type CreateAppCardProps = { className?: string onSuccess?: () => void ref: React.RefObject + selectedAppType?: string } const CreateAppCard = ({ ref, className, onSuccess, + selectedAppType, }: CreateAppCardProps) => { const { t } = useTranslation() const { onPlanInfoChanged } = useProviderContext() @@ -86,6 +88,7 @@ const CreateAppCard = ({ setShowNewAppTemplateDialog(true) setShowNewAppModal(false) }} + defaultAppMode={selectedAppType !== 'all' ? selectedAppType as any : undefined} /> )} {showNewAppTemplateDialog && ( diff --git a/web/app/components/base/agent-log-modal/tool-call.tsx b/web/app/components/base/agent-log-modal/tool-call.tsx index 499a70367c..433a20fd5d 100644 --- a/web/app/components/base/agent-log-modal/tool-call.tsx +++ b/web/app/components/base/agent-log-modal/tool-call.tsx @@ -33,7 +33,7 @@ const ToolCallItem: FC = ({ toolCall, isLLM = false, isFinal, tokens, obs if (time < 1) return `${(time * 1000).toFixed(3)} ms` if (time > 60) - return `${Number.parseInt(Math.round(time / 60).toString())} m ${(time % 60).toFixed(3)} s` + return `${Math.floor(time / 60)} m ${(time % 60).toFixed(3)} s` return `${time.toFixed(3)} s` } diff --git a/web/app/components/base/avatar/index.tsx b/web/app/components/base/avatar/index.tsx index a6e04a0755..89019a19b0 100644 --- a/web/app/components/base/avatar/index.tsx +++ b/web/app/components/base/avatar/index.tsx @@ -1,5 +1,5 @@ 'use client' -import { useState } from 'react' +import { useEffect, useState } from 'react' import cn from '@/utils/classnames' export type AvatarProps = { @@ -27,6 +27,12 @@ const Avatar = ({ onError?.(true) } + // after uploaded, api would first return error imgs url: '.../files//file-preview/...'. Then return the right url, Which caused not show the avatar + useEffect(() => { + if(avatar && imgError) + setImgError(false) + }, [avatar]) + if (avatar && !imgError) { return ( (null) const isInitial = useRef(true) - const inputValue = useRef(value ? value.tz(timezone) : undefined).current + + // Normalize the value to ensure that all subsequent uses are Day.js objects. + const normalizedValue = useMemo(() => { + if (!value) return undefined + return dayjs.isDayjs(value) ? value.tz(timezone) : dayjs(value).tz(timezone) + }, [value, timezone]) + + const inputValue = useRef(normalizedValue).current const defaultValue = useRef(getDateWithTimezone({ timezone })).current const [currentDate, setCurrentDate] = useState(inputValue || defaultValue) @@ -68,8 +75,8 @@ const DatePicker = ({ return } clearMonthMapCache() - if (value) { - const newValue = getDateWithTimezone({ date: value, timezone }) + if (normalizedValue) { + const newValue = getDateWithTimezone({ date: normalizedValue, timezone }) setCurrentDate(newValue) setSelectedDate(newValue) onChange(newValue) @@ -88,9 +95,9 @@ const DatePicker = ({ } setView(ViewType.date) setIsOpen(true) - if (value) { - setCurrentDate(value) - setSelectedDate(value) + if (normalizedValue) { + setCurrentDate(normalizedValue) + setSelectedDate(normalizedValue) } } @@ -192,7 +199,7 @@ const DatePicker = ({ } const timeFormat = needTimePicker ? t('time.dateFormats.displayWithTime') : t('time.dateFormats.display') - const displayValue = value?.format(timeFormat) || '' + const displayValue = normalizedValue?.format(timeFormat) || '' const displayTime = selectedDate?.format('hh:mm A') || '--:-- --' const placeholderDate = isOpen && selectedDate ? selectedDate.format(timeFormat) : (placeholder || t('time.defaultPlaceholder')) @@ -204,7 +211,7 @@ const DatePicker = ({ > {renderTrigger ? (renderTrigger({ - value, + value: normalizedValue, selectedDate, isOpen, handleClear, diff --git a/web/app/components/base/features/new-feature-panel/conversation-opener/modal.tsx b/web/app/components/base/features/new-feature-panel/conversation-opener/modal.tsx index 53db991e71..ec8681f37c 100644 --- a/web/app/components/base/features/new-feature-panel/conversation-opener/modal.tsx +++ b/web/app/components/base/features/new-feature-panel/conversation-opener/modal.tsx @@ -83,9 +83,7 @@ const OpeningSettingModal = ({ }, [handleSave, hideConfirmAddVar]) const autoAddVar = useCallback(() => { - onAutoAddPromptVariable?.([ - ...notIncludeKeys.map(key => getNewVar(key, 'string')), - ]) + onAutoAddPromptVariable?.(notIncludeKeys.map(key => getNewVar(key, 'string'))) hideConfirmAddVar() handleSave(true) }, [handleSave, hideConfirmAddVar, notIncludeKeys, onAutoAddPromptVariable]) diff --git a/web/app/components/base/ga/index.tsx b/web/app/components/base/ga/index.tsx index 7a95561754..81d84a85d3 100644 --- a/web/app/components/base/ga/index.tsx +++ b/web/app/components/base/ga/index.tsx @@ -24,7 +24,7 @@ const GA: FC = ({ if (IS_CE_EDITION) return null - const nonce = process.env.NODE_ENV === 'production' ? (headers() as unknown as UnsafeUnwrappedHeaders).get('x-nonce') : '' + const nonce = process.env.NODE_ENV === 'production' ? (headers() as unknown as UnsafeUnwrappedHeaders).get('x-nonce') ?? '' : '' return ( <> @@ -32,7 +32,7 @@ const GA: FC = ({ strategy="beforeInteractive" async src={`https://www.googletagmanager.com/gtag/js?id=${gaIdMaps[gaType]}`} - nonce={nonce!} + nonce={nonce ?? undefined} > {/* Cookie banner */} diff --git a/web/app/components/base/markdown-blocks/link.tsx b/web/app/components/base/markdown-blocks/link.tsx index 0274ee0141..9bf13040a7 100644 --- a/web/app/components/base/markdown-blocks/link.tsx +++ b/web/app/components/base/markdown-blocks/link.tsx @@ -17,7 +17,7 @@ const Link = ({ node, children, ...props }: any) => { } else { const href = props.href || node.properties?.href - if (href && /^#[a-zA-Z0-9_\-]+$/.test(href.toString())) { + if (href && /^#[a-zA-Z0-9_-]+$/.test(href.toString())) { const handleClick = (e: React.MouseEvent) => { e.preventDefault() // scroll to target element if exists within the answer container diff --git a/web/app/components/base/markdown-blocks/think-block.tsx b/web/app/components/base/markdown-blocks/think-block.tsx index 46f992d758..a5813266f1 100644 --- a/web/app/components/base/markdown-blocks/think-block.tsx +++ b/web/app/components/base/markdown-blocks/think-block.tsx @@ -1,5 +1,6 @@ import React, { useEffect, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' +import { useChatContext } from '../chat/chat/context' const hasEndThink = (children: any): boolean => { if (typeof children === 'string') @@ -35,6 +36,7 @@ const removeEndThink = (children: any): any => { } const useThinkTimer = (children: any) => { + const { isResponding } = useChatContext() const [startTime] = useState(Date.now()) const [elapsedTime, setElapsedTime] = useState(0) const [isComplete, setIsComplete] = useState(false) @@ -54,9 +56,9 @@ const useThinkTimer = (children: any) => { }, [startTime, isComplete]) useEffect(() => { - if (hasEndThink(children)) + if (hasEndThink(children) || !isResponding) setIsComplete(true) - }, [children]) + }, [children, isResponding]) return { elapsedTime, isComplete } } diff --git a/web/app/components/base/notion-page-selector/page-selector/index.tsx b/web/app/components/base/notion-page-selector/page-selector/index.tsx index 498955994c..a1ce3116db 100644 --- a/web/app/components/base/notion-page-selector/page-selector/index.tsx +++ b/web/app/components/base/notion-page-selector/page-selector/index.tsx @@ -229,7 +229,7 @@ const PageSelector = ({ if (current.expand) { current.expand = false - newDataList = [...dataList.filter(item => !descendantsIds.includes(item.page_id))] + newDataList = dataList.filter(item => !descendantsIds.includes(item.page_id)) } else { current.expand = true @@ -246,7 +246,7 @@ const PageSelector = ({ setDataList(newDataList) } - const copyValue = new Set([...value]) + const copyValue = new Set(value) const handleCheck = (index: number) => { const current = currentDataList[index] const pageId = current.page_id @@ -269,7 +269,7 @@ const PageSelector = ({ copyValue.add(pageId) } - onSelect(new Set([...copyValue])) + onSelect(new Set(copyValue)) } const handlePreview = (index: number) => { diff --git a/web/app/components/base/param-item/score-threshold-item.tsx b/web/app/components/base/param-item/score-threshold-item.tsx index b5557c80cf..3790a2a074 100644 --- a/web/app/components/base/param-item/score-threshold-item.tsx +++ b/web/app/components/base/param-item/score-threshold-item.tsx @@ -20,7 +20,6 @@ const VALUE_LIMIT = { max: 1, } -const key = 'score_threshold' const ScoreThresholdItem: FC = ({ className, value, @@ -39,9 +38,9 @@ const ScoreThresholdItem: FC = ({ return ( = ({ className, value, @@ -41,9 +40,9 @@ const TopKItem: FC = ({ return ( ), ) diff --git a/web/app/components/base/zendesk/index.tsx b/web/app/components/base/zendesk/index.tsx new file mode 100644 index 0000000000..b3d67eb390 --- /dev/null +++ b/web/app/components/base/zendesk/index.tsx @@ -0,0 +1,21 @@ +import { memo } from 'react' +import { type UnsafeUnwrappedHeaders, headers } from 'next/headers' +import Script from 'next/script' +import { IS_CE_EDITION, ZENDESK_WIDGET_KEY } from '@/config' + +const Zendesk = () => { + if (IS_CE_EDITION || !ZENDESK_WIDGET_KEY) + return null + + const nonce = process.env.NODE_ENV === 'production' ? (headers() as unknown as UnsafeUnwrappedHeaders).get('x-nonce') ?? '' : '' + + return ( +