diff --git a/.devcontainer/post_create_command.sh b/.devcontainer/post_create_command.sh
index 93ecac48f2..022f71bfb4 100755
--- a/.devcontainer/post_create_command.sh
+++ b/.devcontainer/post_create_command.sh
@@ -1,6 +1,6 @@
 #!/bin/bash
 
-npm add -g pnpm@10.11.1
+npm add -g pnpm@10.13.1
 cd web && pnpm install
 pipx install uv
 
@@ -12,3 +12,4 @@ echo 'alias start-containers="cd /workspaces/dify/docker && docker-compose -f do
 echo 'alias stop-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env down"' >> ~/.bashrc
 
 source /home/vscode/.bashrc
+
diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml
index b06ab9653e..a283f8d5ca 100644
--- a/.github/workflows/style.yml
+++ b/.github/workflows/style.yml
@@ -28,7 +28,7 @@ jobs:
 
       - name: Check changed files
        id: changed-files
-        uses: tj-actions/changed-files@v45
+        uses: tj-actions/changed-files@v46
        with:
          files: |
            api/**
@@ -75,7 +75,7 @@ jobs:
 
      - name: Check changed files
        id: changed-files
-        uses: tj-actions/changed-files@v45
+        uses: tj-actions/changed-files@v46
        with:
          files: web/**
 
@@ -113,7 +113,7 @@ jobs:
 
      - name: Check changed files
        id: changed-files
-        uses: tj-actions/changed-files@v45
+        uses: tj-actions/changed-files@v46
        with:
          files: |
            docker/generate_docker_compose
@@ -144,7 +144,7 @@ jobs:
 
      - name: Check changed files
        id: changed-files
-        uses: tj-actions/changed-files@v45
+        uses: tj-actions/changed-files@v46
        with:
          files: |
            **.sh
@@ -152,13 +152,15 @@ jobs:
            **.yml
            **Dockerfile
            dev/**
+            .editorconfig
 
      - name: Super-linter
-        uses: super-linter/super-linter/slim@v7
+        uses: super-linter/super-linter/slim@v8
        if: steps.changed-files.outputs.any_changed == 'true'
        env:
          BASH_SEVERITY: warning
-          DEFAULT_BRANCH: main
+          DEFAULT_BRANCH: origin/main
+          EDITORCONFIG_FILE_NAME: editorconfig-checker.json
          FILTER_REGEX_INCLUDE: pnpm-lock.yaml
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
          IGNORE_GENERATED_FILES: true
@@ -168,16 +170,6 @@ jobs:
          # FIXME: temporarily disabled until api-docker.yaml's run script is fixed for shellcheck
          # VALIDATE_GITHUB_ACTIONS: true
          VALIDATE_DOCKERFILE_HADOLINT: true
+          VALIDATE_EDITORCONFIG: true
          VALIDATE_XML: true
          VALIDATE_YAML: true
-
-      - name: EditorConfig checks
-        uses: super-linter/super-linter/slim@v7
-        env:
-          DEFAULT_BRANCH: main
-          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
-          IGNORE_GENERATED_FILES: true
-          IGNORE_GITIGNORED_FILES: true
-          # EditorConfig validation
-          VALIDATE_EDITORCONFIG: true
-          EDITORCONFIG_FILE_NAME: editorconfig-checker.json
diff --git a/.github/workflows/web-tests.yml b/.github/workflows/web-tests.yml
index 37cfdc5c1e..c3f8fdbaf6 100644
--- a/.github/workflows/web-tests.yml
+++ b/.github/workflows/web-tests.yml
@@ -27,7 +27,7 @@ jobs:
 
      - name: Check changed files
        id: changed-files
-        uses: tj-actions/changed-files@v45
+        uses: tj-actions/changed-files@v46
        with:
          files: web/**
diff --git a/api/.env.example b/api/.env.example
index 6ead14e9b0..daa0df535b 100644
--- a/api/.env.example
+++ b/api/.env.example
@@ -142,8 +142,10 @@ WEB_API_CORS_ALLOW_ORIGINS=http://localhost:3000,*
 CONSOLE_CORS_ALLOW_ORIGINS=http://localhost:3000,*
 
 # Vector database configuration
-# support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, couchbase, vikingdb, upstash, lindorm, oceanbase, opengauss, tablestore, matrixone
+# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`, `opengauss`, `tablestore`, `vastbase`, `tidb`, `tidb_on_qdrant`, `baidu`, `lindorm`, `huawei_cloud`, `upstash`, `matrixone`.
 VECTOR_STORE=weaviate
+# Prefix used to create collection name in vector database
+VECTOR_INDEX_NAME_PREFIX=Vector_index
 
 # Weaviate configuration
 WEAVIATE_ENDPOINT=http://localhost:8080
diff --git a/api/configs/middleware/__init__.py b/api/configs/middleware/__init__.py
index 3c349060ca..587ea55ca7 100644
--- a/api/configs/middleware/__init__.py
+++ b/api/configs/middleware/__init__.py
@@ -85,6 +85,11 @@ class VectorStoreConfig(BaseSettings):
         default=False,
     )
 
+    VECTOR_INDEX_NAME_PREFIX: Optional[str] = Field(
+        description="Prefix used to create collection name in vector database",
+        default="Vector_index",
+    )
+
 
 class KeywordStoreConfig(BaseSettings):
     KEYWORD_STORE: str = Field(
diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py
index 70d6216497..4eef9fed43 100644
--- a/api/controllers/console/app/conversation.py
+++ b/api/controllers/console/app/conversation.py
@@ -1,4 +1,4 @@
-from datetime import UTC, datetime
+from datetime import datetime
 
 import pytz  # pip install pytz
 from flask_login import current_user
@@ -19,6 +19,7 @@ from fields.conversation_fields import (
     conversation_pagination_fields,
     conversation_with_summary_pagination_fields,
 )
+from libs.datetime_utils import naive_utc_now
 from libs.helper import DatetimeString
 from libs.login import login_required
 from models import Conversation, EndUser, Message, MessageAnnotation
@@ -315,7 +316,7 @@ def _get_conversation(app_model, conversation_id):
         raise NotFound("Conversation Not Exists.")
 
     if not conversation.read_at:
-        conversation.read_at = datetime.now(UTC).replace(tzinfo=None)
+        conversation.read_at = naive_utc_now()
         conversation.read_account_id = current_user.id
         db.session.commit()
diff --git a/api/controllers/console/app/site.py b/api/controllers/console/app/site.py
index 3c3a359eeb..358a5e8cdb 100644
--- a/api/controllers/console/app/site.py
+++ b/api/controllers/console/app/site.py
@@ -1,5 +1,3 @@
-from datetime import UTC, datetime
-
 from flask_login import current_user
 from flask_restful import Resource, marshal_with, reqparse
 from werkzeug.exceptions import Forbidden, NotFound
@@ -10,6 +8,7 @@ from controllers.console.app.wraps import get_app_model
 from controllers.console.wraps import account_initialization_required, setup_required
 from extensions.ext_database import db
 from fields.app_fields import app_site_fields
+from libs.datetime_utils import naive_utc_now
 from libs.login import login_required
 from models import Site
 
@@ -77,7 +76,7 @@ class AppSite(Resource):
                 setattr(site, attr_name, value)
 
         site.updated_by = current_user.id
-        site.updated_at = datetime.now(UTC).replace(tzinfo=None)
+        site.updated_at = naive_utc_now()
         db.session.commit()
 
         return site
@@ -101,7 +100,7 @@ class AppSiteAccessTokenReset(Resource):
         site.code = Site.generate_code(16)
         site.updated_by = current_user.id
-        site.updated_at = datetime.now(UTC).replace(tzinfo=None)
+        site.updated_at = naive_utc_now()
         db.session.commit()
 
         return site
diff --git a/api/controllers/console/auth/activate.py b/api/controllers/console/auth/activate.py
index 1795563ff7..2562fb5eb8 100644
--- a/api/controllers/console/auth/activate.py
+++ b/api/controllers/console/auth/activate.py
@@ -1,5 +1,3 @@
-import datetime
-
 from flask import request
 from flask_restful import Resource, reqparse
 
@@ -7,6 +5,7 @@
 from constants.languages import supported_language
 from controllers.console import api
 from controllers.console.error import AlreadyActivateError
 from extensions.ext_database import db
+from libs.datetime_utils import naive_utc_now
 from libs.helper import StrLen, email, extract_remote_ip, timezone
 from models.account import AccountStatus
 from services.account_service import AccountService, RegisterService
@@ -65,7 +64,7 @@ class ActivateApi(Resource):
         account.timezone = args["timezone"]
         account.interface_theme = "light"
         account.status = AccountStatus.ACTIVE.value
-        account.initialized_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+        account.initialized_at = naive_utc_now()
         db.session.commit()
 
         token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py
index 395367c9e2..d0a4f3ff6d 100644
--- a/api/controllers/console/auth/oauth.py
+++ b/api/controllers/console/auth/oauth.py
@@ -1,5 +1,4 @@
 import logging
-from datetime import UTC, datetime
 from typing import Optional
 
 import requests
@@ -13,6 +12,7 @@ from configs import dify_config
 from constants.languages import languages
 from events.tenant_event import tenant_was_created
 from extensions.ext_database import db
+from libs.datetime_utils import naive_utc_now
 from libs.helper import extract_remote_ip
 from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
 from models import Account
@@ -110,7 +110,7 @@ class OAuthCallback(Resource):
 
         if account.status == AccountStatus.PENDING.value:
             account.status = AccountStatus.ACTIVE.value
-            account.initialized_at = datetime.now(UTC).replace(tzinfo=None)
+            account.initialized_at = naive_utc_now()
             db.session.commit()
 
         try:
diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py
index 7b0d9373cf..b49f8affc8 100644
--- a/api/controllers/console/datasets/data_source.py
+++ b/api/controllers/console/datasets/data_source.py
@@ -1,4 +1,3 @@
-import datetime
 import json
 
 from flask import request
@@ -15,6 +14,7 @@ from core.rag.extractor.entity.extract_setting import ExtractSetting
 from core.rag.extractor.notion_extractor import NotionExtractor
 from extensions.ext_database import db
 from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields
+from libs.datetime_utils import naive_utc_now
 from libs.login import login_required
 from models import DataSourceOauthBinding, Document
 from services.dataset_service import DatasetService, DocumentService
@@ -88,7 +88,7 @@ class DataSourceApi(Resource):
             if action == "enable":
                 if data_source_binding.disabled:
                     data_source_binding.disabled = False
-                    data_source_binding.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+                    data_source_binding.updated_at = naive_utc_now()
                     db.session.add(data_source_binding)
                     db.session.commit()
                 else:
@@ -97,7 +97,7 @@ class DataSourceApi(Resource):
             if action == "disable":
                 if not data_source_binding.disabled:
                     data_source_binding.disabled = True
-                    data_source_binding.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+                    data_source_binding.updated_at = naive_utc_now()
                     db.session.add(data_source_binding)
                     db.session.commit()
                 else:
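The recurring replacement target throughout this diff is the expression `datetime.now(UTC).replace(tzinfo=None)`. The new helper itself is not part of the visible hunks; a minimal sketch of what `libs/datetime_utils.naive_utc_now` presumably looks like, inferred from the call sites it replaces (and from the `_naive_utc_datetime` helper later in this diff):

```python
from datetime import UTC, datetime


def naive_utc_now() -> datetime:
    # Current UTC time as a naive datetime (tzinfo stripped), matching the
    # datetime.now(UTC).replace(tzinfo=None) expression used at every call site.
    return datetime.now(UTC).replace(tzinfo=None)
```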
diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py
index 35d912bfcc..6e039d735b 100644
--- a/api/controllers/console/datasets/datasets_document.py
+++ b/api/controllers/console/datasets/datasets_document.py
@@ -1,7 +1,6 @@
 import json
 import logging
 from argparse import ArgumentTypeError
-from datetime import UTC, datetime
 from typing import cast
 
 from flask import request
@@ -50,6 +49,7 @@ from fields.document_fields import (
     document_status_fields,
     document_with_segments_fields,
 )
+from libs.datetime_utils import naive_utc_now
 from libs.login import login_required
 from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile
 from models.dataset import DocumentPipelineExecutionLog
@@ -752,7 +752,7 @@ class DocumentProcessingApi(DocumentResource):
             raise InvalidActionError("Document not in indexing state.")
 
         document.paused_by = current_user.id
-        document.paused_at = datetime.now(UTC).replace(tzinfo=None)
+        document.paused_at = naive_utc_now()
         document.is_paused = True
         db.session.commit()
 
@@ -832,7 +832,7 @@ class DocumentMetadataApi(DocumentResource):
                     document.doc_metadata[key] = value
 
         document.doc_type = doc_type
-        document.updated_at = datetime.now(UTC).replace(tzinfo=None)
+        document.updated_at = naive_utc_now()
         db.session.commit()
 
         return {"result": "success", "message": "Document metadata updated."}, 200
diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py
index 4367da1162..4842fefc57 100644
--- a/api/controllers/console/explore/completion.py
+++ b/api/controllers/console/explore/completion.py
@@ -1,5 +1,4 @@
 import logging
-from datetime import UTC, datetime
 
 from flask_login import current_user
 from flask_restful import reqparse
@@ -27,6 +26,7 @@ from core.errors.error import (
 from core.model_runtime.errors.invoke import InvokeError
 from extensions.ext_database import db
 from libs import helper
+from libs.datetime_utils import naive_utc_now
 from libs.helper import uuid_value
 from models.model import AppMode
 from services.app_generate_service import AppGenerateService
@@ -51,7 +51,7 @@ class CompletionApi(InstalledAppResource):
         streaming = args["response_mode"] == "streaming"
         args["auto_generate_name"] = False
 
-        installed_app.last_used_at = datetime.now(UTC).replace(tzinfo=None)
+        installed_app.last_used_at = naive_utc_now()
         db.session.commit()
 
         try:
@@ -111,7 +111,7 @@ class ChatApi(InstalledAppResource):
 
         args["auto_generate_name"] = False
 
-        installed_app.last_used_at = datetime.now(UTC).replace(tzinfo=None)
+        installed_app.last_used_at = naive_utc_now()
         db.session.commit()
 
         try:
diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py
index 9d0c08564e..29111fb865 100644
--- a/api/controllers/console/explore/installed_app.py
+++ b/api/controllers/console/explore/installed_app.py
@@ -1,5 +1,4 @@
 import logging
-from datetime import UTC, datetime
 from typing import Any
 
 from flask import request
@@ -13,6 +12,7 @@ from controllers.console.explore.wraps import InstalledAppResource
 from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
 from extensions.ext_database import db
 from fields.installed_app_fields import installed_app_list_fields
+from libs.datetime_utils import naive_utc_now
 from libs.login import login_required
 from models import App, InstalledApp, RecommendedApp
 from services.account_service import TenantService
@@ -122,7 +122,7 @@ class InstalledAppsListApi(Resource):
                 tenant_id=current_tenant_id,
                 app_owner_tenant_id=app.tenant_id,
                 is_pinned=False,
-                last_used_at=datetime.now(UTC).replace(tzinfo=None),
+                last_used_at=naive_utc_now(),
             )
             db.session.add(new_installed_app)
             db.session.commit()
diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py
index 1f22e3fd01..7f7e64a59c 100644
--- a/api/controllers/console/workspace/account.py
+++ b/api/controllers/console/workspace/account.py
@@ -1,5 +1,3 @@
-import datetime
-
 import pytz
 from flask import request
 from flask_login import current_user
@@ -35,6 +33,7 @@ from controllers.console.wraps import (
 )
 from extensions.ext_database import db
 from fields.member_fields import account_fields
+from libs.datetime_utils import naive_utc_now
 from libs.helper import TimestampField, email, extract_remote_ip, timezone
 from libs.login import login_required
 from models import AccountIntegrate, InvitationCode
@@ -80,7 +79,7 @@ class AccountInitApi(Resource):
                 raise InvalidInvitationCodeError()
 
             invitation_code.status = "used"
-            invitation_code.used_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+            invitation_code.used_at = naive_utc_now()
             invitation_code.used_by_tenant_id = account.current_tenant_id
             invitation_code.used_by_account_id = account.id
 
@@ -88,7 +87,7 @@ class AccountInitApi(Resource):
         account.timezone = args["timezone"]
         account.interface_theme = "light"
         account.status = "active"
-        account.initialized_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+        account.initialized_at = naive_utc_now()
         db.session.commit()
 
         return {"result": "success"}
diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py
index e41375e52b..c70bf84d2a 100644
--- a/api/controllers/console/workspace/tool_providers.py
+++ b/api/controllers/console/workspace/tool_providers.py
@@ -29,7 +29,7 @@ from libs.login import login_required
 from services.plugin.oauth_service import OAuthProxyService
 from services.tools.api_tools_manage_service import ApiToolManageService
 from services.tools.builtin_tools_manage_service import BuiltinToolManageService
-from services.tools.mcp_tools_mange_service import MCPToolManageService
+from services.tools.mcp_tools_manage_service import MCPToolManageService
 from services.tools.tool_labels_service import ToolLabelsService
 from services.tools.tools_manage_service import ToolCommonService
 from services.tools.tools_transform_service import ToolTransformService
diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py
index 5b919a68d4..eeed321430 100644
--- a/api/controllers/service_api/wraps.py
+++ b/api/controllers/service_api/wraps.py
@@ -1,6 +1,6 @@
 import time
 from collections.abc import Callable
-from datetime import UTC, datetime, timedelta
+from datetime import timedelta
 from enum import Enum
 from functools import wraps
 from typing import Optional
@@ -15,6 +15,7 @@ from werkzeug.exceptions import Forbidden, NotFound, Unauthorized
 
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
+from libs.datetime_utils import naive_utc_now
 from libs.login import _get_user
 from models.account import Account, Tenant, TenantAccountJoin, TenantStatus
 from models.dataset import Dataset, RateLimitLog
@@ -256,7 +257,7 @@ def validate_and_get_api_token(scope: str | None = None):
     if auth_scheme != "bearer":
         raise Unauthorized("Authorization scheme must be 'Bearer'")
 
-    current_time = datetime.now(UTC).replace(tzinfo=None)
+    current_time = naive_utc_now()
     cutoff_time = current_time - timedelta(minutes=1)
     with Session(db.engine, expire_on_commit=False) as session:
        update_stmt = (
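The `wraps.py` hunk ends just as `update_stmt` begins, so the statement itself is out of view. A sketch of the throttled "bump `last_used_at` at most once per minute" pattern the one-minute cutoff suggests; the `ApiToken` model location and column names here are assumptions, not taken from this diff:

```python
from datetime import timedelta

from sqlalchemy import update
from sqlalchemy.orm import Session

from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.model import ApiToken  # assumed location of the token model


def touch_api_token(token: str) -> None:
    """Bump last_used_at at most once per minute to avoid write amplification."""
    current_time = naive_utc_now()
    cutoff_time = current_time - timedelta(minutes=1)
    with Session(db.engine, expire_on_commit=False) as session:
        stmt = (
            update(ApiToken)
            # Rows touched within the last minute are skipped entirely.
            .where(ApiToken.token == token, ApiToken.last_used_at < cutoff_time)
            .values(last_used_at=current_time)
        )
        session.execute(stmt)
        session.commit()
```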
diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py
index 85fafe6980..d50cf1c941 100644
--- a/api/core/app/apps/message_based_app_generator.py
+++ b/api/core/app/apps/message_based_app_generator.py
@@ -1,7 +1,6 @@
 import json
 import logging
 from collections.abc import Generator
-from datetime import UTC, datetime
 from typing import Optional, Union, cast
 
 from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppModelConfigFrom
@@ -25,6 +24,7 @@ from core.app.entities.task_entities import (
 from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline
 from core.prompt.utils.prompt_template_parser import PromptTemplateParser
 from extensions.ext_database import db
+from libs.datetime_utils import naive_utc_now
 from models import Account
 from models.enums import CreatorUserRole
 from models.model import App, AppMode, AppModelConfig, Conversation, EndUser, Message, MessageFile
@@ -184,7 +184,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
             db.session.commit()
             db.session.refresh(conversation)
         else:
-            conversation.updated_at = datetime.now(UTC).replace(tzinfo=None)
+            conversation.updated_at = naive_utc_now()
             db.session.commit()
 
         message = Message(
diff --git a/api/core/file/file_manager.py b/api/core/file/file_manager.py
index c2bc1ffbe3..80ff5f693c 100644
--- a/api/core/file/file_manager.py
+++ b/api/core/file/file_manager.py
@@ -7,6 +7,7 @@ from core.model_runtime.entities import (
     AudioPromptMessageContent,
     DocumentPromptMessageContent,
     ImagePromptMessageContent,
+    TextPromptMessageContent,
     VideoPromptMessageContent,
 )
 from core.model_runtime.entities.message_entities import PromptMessageContentUnionTypes
@@ -44,11 +45,44 @@ def to_prompt_message_content(
     *,
     image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
 ) -> PromptMessageContentUnionTypes:
+    """
+    Convert a file to prompt message content.
+
+    This function converts files to their appropriate prompt message content types.
+    For supported file types (IMAGE, AUDIO, VIDEO, DOCUMENT), it creates the
+    corresponding message content with proper encoding/URL.
+
+    For unsupported file types, instead of raising an error, it returns a
+    TextPromptMessageContent with a descriptive message about the file.
+
+    Args:
+        f: The file to convert
+        image_detail_config: Optional detail configuration for image files
+
+    Returns:
+        PromptMessageContentUnionTypes: The appropriate message content type
+
+    Raises:
+        ValueError: If file extension or mime_type is missing
+    """
     if f.extension is None:
         raise ValueError("Missing file extension")
     if f.mime_type is None:
         raise ValueError("Missing file mime_type")
 
+    prompt_class_map: Mapping[FileType, type[PromptMessageContentUnionTypes]] = {
+        FileType.IMAGE: ImagePromptMessageContent,
+        FileType.AUDIO: AudioPromptMessageContent,
+        FileType.VIDEO: VideoPromptMessageContent,
+        FileType.DOCUMENT: DocumentPromptMessageContent,
+    }
+
+    # Check if file type is supported
+    if f.type not in prompt_class_map:
+        # For unsupported file types, return a text description
+        return TextPromptMessageContent(data=f"[Unsupported file type: {f.filename} ({f.type.value})]")
+
+    # Process supported file types
     params = {
         "base64_data": _get_encoded_string(f) if dify_config.MULTIMODAL_SEND_FORMAT == "base64" else "",
         "url": _to_url(f) if dify_config.MULTIMODAL_SEND_FORMAT == "url" else "",
@@ -58,17 +92,7 @@ def to_prompt_message_content(
     if f.type == FileType.IMAGE:
         params["detail"] = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
 
-    prompt_class_map: Mapping[FileType, type[PromptMessageContentUnionTypes]] = {
-        FileType.IMAGE: ImagePromptMessageContent,
-        FileType.AUDIO: AudioPromptMessageContent,
-        FileType.VIDEO: VideoPromptMessageContent,
-        FileType.DOCUMENT: DocumentPromptMessageContent,
-    }
-
-    try:
-        return prompt_class_map[f.type].model_validate(params)
-    except KeyError:
-        raise ValueError(f"file type {f.type} is not supported")
+    return prompt_class_map[f.type].model_validate(params)
 
 
 def download(f: File, /):
diff --git a/api/core/mcp/auth/auth_provider.py b/api/core/mcp/auth/auth_provider.py
index cd55dbf64f..00d5a25956 100644
--- a/api/core/mcp/auth/auth_provider.py
+++ b/api/core/mcp/auth/auth_provider.py
@@ -8,7 +8,7 @@ from core.mcp.types import (
     OAuthTokens,
 )
 from models.tools import MCPToolProvider
-from services.tools.mcp_tools_mange_service import MCPToolManageService
+from services.tools.mcp_tools_manage_service import MCPToolManageService
 
 
 LATEST_PROTOCOL_VERSION = "1.0"
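With the new fallback, callers no longer need to guard `to_prompt_message_content` against a `ValueError` for unsupported types. A usage sketch (the import path and the `image_file`/`custom_file` instances are illustrative, not from this diff):

```python
from core.file import file_manager  # assumed import path

# image_file / custom_file are File instances built elsewhere (illustrative).
content = file_manager.to_prompt_message_content(image_file)
# -> ImagePromptMessageContent (AUDIO/VIDEO/DOCUMENT map the same way)

content = file_manager.to_prompt_message_content(custom_file)
# -> TextPromptMessageContent(data="[Unsupported file type: report.xyz (custom)]")
#    instead of the previous ValueError("file type ... is not supported")
```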
diff --git a/api/core/mcp/mcp_client.py b/api/core/mcp/mcp_client.py
index e9036de8c6..f7aa7bbd7b 100644
--- a/api/core/mcp/mcp_client.py
+++ b/api/core/mcp/mcp_client.py
@@ -68,15 +68,17 @@ class MCPClient:
         }
 
         parsed_url = urlparse(self.server_url)
-        path = parsed_url.path
-        method_name = path.rstrip("/").split("/")[-1] if path else ""
-        try:
+        path = parsed_url.path or ""
+        method_name = path.removesuffix("/").split("/")[-1].lower()
+        if method_name in connection_methods:
             client_factory = connection_methods[method_name]
             self.connect_server(client_factory, method_name)
-        except KeyError:
+        else:
             try:
+                logger.debug(f"Unsupported method {method_name!r} in URL path, trying 'sse' first, then 'mcp'.")
                 self.connect_server(sse_client, "sse")
             except MCPConnectionError:
+                logger.debug("MCP connection failed with 'sse', falling back to 'mcp' method.")
                 self.connect_server(streamablehttp_client, "mcp")
 
     def connect_server(
@@ -91,7 +93,7 @@ class MCPClient:
             else {}
         )
         self._streams_context = client_factory(url=self.server_url, headers=headers)
-        if self._streams_context is None:
+        if not self._streams_context:
             raise MCPConnectionError("Failed to create connection context")
 
         # Use exit_stack to manage context managers properly
@@ -141,10 +143,11 @@ class MCPClient:
         try:
             # ExitStack will handle proper cleanup of all managed context managers
             self.exit_stack.close()
+        except Exception as e:
+            logging.exception("Error during cleanup")
+            raise ValueError(f"Error during cleanup: {e}")
+        finally:
             self._session = None
             self._session_context = None
             self._streams_context = None
             self._initialized = False
-        except Exception as e:
-            logging.exception("Error during cleanup")
-            raise ValueError(f"Error during cleanup: {e}")
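How the rewritten dispatch behaves for typical server URLs, assuming `connection_methods` maps `"sse"` and `"mcp"` to their client factories (a walk-through, not part of the patch):

```python
from urllib.parse import urlparse

# https://host/sse   -> path "/sse"  -> method_name "sse" -> sse_client
# https://host/mcp/  -> path "/mcp/" -> method_name "mcp" -> streamablehttp_client
# https://host/rpc   -> method_name "rpc" (unknown) -> try "sse", fall back to "mcp"
def pick_transport(server_url: str) -> str:
    path = urlparse(server_url).path or ""
    # Last path segment, trailing slash stripped, case-insensitive.
    method_name = path.removesuffix("/").split("/")[-1].lower()
    return method_name if method_name in {"sse", "mcp"} else "auto"
```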
diff --git a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py
index ffda0885d4..8b3ce0c448 100644
--- a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py
+++ b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py
@@ -3,7 +3,7 @@ import json
 import logging
 import os
 from datetime import datetime, timedelta
-from typing import Optional, Union, cast
+from typing import Any, Optional, Union, cast
 
 from openinference.semconv.trace import OpenInferenceSpanKindValues, SpanAttributes
 from opentelemetry import trace
@@ -142,11 +142,8 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
             raise
 
     def workflow_trace(self, trace_info: WorkflowTraceInfo):
-        if trace_info.message_data is None:
-            return
-
         workflow_metadata = {
-            "workflow_id": trace_info.workflow_run_id or "",
+            "workflow_run_id": trace_info.workflow_run_id or "",
             "message_id": trace_info.message_id or "",
             "workflow_app_log_id": trace_info.workflow_app_log_id or "",
             "status": trace_info.workflow_run_status or "",
@@ -156,7 +153,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
         }
         workflow_metadata.update(trace_info.metadata)
 
-        trace_id = uuid_to_trace_id(trace_info.message_id)
+        trace_id = uuid_to_trace_id(trace_info.workflow_run_id)
         span_id = RandomIdGenerator().generate_span_id()
         context = SpanContext(
             trace_id=trace_id,
@@ -213,7 +210,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
                 if model:
                     node_metadata["ls_model_name"] = model
 
-                outputs = json.loads(node_execution.outputs).get("usage", {})
+                outputs = json.loads(node_execution.outputs).get("usage", {}) if "outputs" in node_execution else {}
                 usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
                 if usage_data:
                     node_metadata["total_tokens"] = usage_data.get("total_tokens", 0)
@@ -236,31 +233,34 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
                     SpanAttributes.SESSION_ID: trace_info.conversation_id or "",
                 },
                 start_time=datetime_to_nanos(created_at),
+                context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
             )
 
             try:
                 if node_execution.node_type == "llm":
+                    llm_attributes: dict[str, Any] = {
+                        SpanAttributes.INPUT_VALUE: json.dumps(process_data.get("prompts", []), ensure_ascii=False),
+                    }
                     provider = process_data.get("model_provider")
                     model = process_data.get("model_name")
                     if provider:
-                        node_span.set_attribute(SpanAttributes.LLM_PROVIDER, provider)
+                        llm_attributes[SpanAttributes.LLM_PROVIDER] = provider
                     if model:
-                        node_span.set_attribute(SpanAttributes.LLM_MODEL_NAME, model)
-
-                    outputs = json.loads(node_execution.outputs).get("usage", {})
+                        llm_attributes[SpanAttributes.LLM_MODEL_NAME] = model
+                    outputs = (
+                        json.loads(node_execution.outputs).get("usage", {}) if "outputs" in node_execution else {}
+                    )
                     usage_data = (
                         process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
                     )
                     if usage_data:
-                        node_span.set_attribute(
-                            SpanAttributes.LLM_TOKEN_COUNT_TOTAL, usage_data.get("total_tokens", 0)
-                        )
-                        node_span.set_attribute(
-                            SpanAttributes.LLM_TOKEN_COUNT_PROMPT, usage_data.get("prompt_tokens", 0)
-                        )
-                        node_span.set_attribute(
-                            SpanAttributes.LLM_TOKEN_COUNT_COMPLETION, usage_data.get("completion_tokens", 0)
+                        llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_TOTAL] = usage_data.get("total_tokens", 0)
+                        llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_PROMPT] = usage_data.get("prompt_tokens", 0)
+                        llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_COMPLETION] = usage_data.get(
+                            "completion_tokens", 0
                         )
+                    llm_attributes.update(self._construct_llm_attributes(process_data.get("prompts", [])))
+                    node_span.set_attributes(llm_attributes)
             finally:
                 node_span.end(end_time=datetime_to_nanos(finished_at))
         finally:
@@ -352,25 +352,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
             SpanAttributes.METADATA: json.dumps(message_metadata, ensure_ascii=False),
             SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id,
         }
-
-        if isinstance(trace_info.inputs, list):
-            for i, msg in enumerate(trace_info.inputs):
-                if isinstance(msg, dict):
-                    llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.{i}.message.content"] = msg.get("text", "")
-                    llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.{i}.message.role"] = msg.get(
-                        "role", "user"
-                    )
-                    # todo: handle assistant and tool role messages, as they don't always
-                    # have a text field, but may have a tool_calls field instead
-                    # e.g. 'tool_calls': [{'id': '98af3a29-b066-45a5-b4b1-46c74ddafc58',
-                    # 'type': 'function', 'function': {'name': 'current_time', 'arguments': '{}'}}]}
-        elif isinstance(trace_info.inputs, dict):
-            llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.content"] = json.dumps(trace_info.inputs)
-            llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.role"] = "user"
-        elif isinstance(trace_info.inputs, str):
-            llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.content"] = trace_info.inputs
-            llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.role"] = "user"
-
+        llm_attributes.update(self._construct_llm_attributes(trace_info.inputs))
         if trace_info.total_tokens is not None and trace_info.total_tokens > 0:
             llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_TOTAL] = trace_info.total_tokens
         if trace_info.message_tokens is not None and trace_info.message_tokens > 0:
@@ -724,3 +706,24 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
             .all()
         )
         return workflow_nodes
+
+    def _construct_llm_attributes(self, prompts: dict | list | str | None) -> dict[str, str]:
+        """Helper method to construct LLM attributes with passed prompts."""
+        attributes = {}
+        if isinstance(prompts, list):
+            for i, msg in enumerate(prompts):
+                if isinstance(msg, dict):
+                    attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.{i}.message.content"] = msg.get("text", "")
+                    attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.{i}.message.role"] = msg.get("role", "user")
+                    # todo: handle assistant and tool role messages, as they don't always
+                    # have a text field, but may have a tool_calls field instead
+                    # e.g. 'tool_calls': [{'id': '98af3a29-b066-45a5-b4b1-46c74ddafc58',
+                    # 'type': 'function', 'function': {'name': 'current_time', 'arguments': '{}'}}]
+        elif isinstance(prompts, dict):
+            attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.content"] = json.dumps(prompts)
+            attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.role"] = "user"
+        elif isinstance(prompts, str):
+            attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.content"] = prompts
+            attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.role"] = "user"
+
+        return attributes
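What the extracted `_construct_llm_attributes` helper produces for a typical prompt list; the literal attribute keys shown assume openinference's `SpanAttributes.LLM_INPUT_MESSAGES` expands to `llm.input_messages`:

```python
prompts = [
    {"role": "system", "text": "You are a helpful assistant."},
    {"role": "user", "text": "What time is it?"},
]
# _construct_llm_attributes(prompts) yields flattened span attributes, roughly:
# {
#     "llm.input_messages.0.message.content": "You are a helpful assistant.",
#     "llm.input_messages.0.message.role": "system",
#     "llm.input_messages.1.message.content": "What time is it?",
#     "llm.input_messages.1.message.role": "user",
# }
```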
diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py
index 095752ea8e..6f3e15d166 100644
--- a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py
+++ b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py
@@ -233,6 +233,12 @@ class AnalyticdbVectorOpenAPI:
     def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
         from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
 
+        document_ids_filter = kwargs.get("document_ids_filter")
+        where_clause = ""
+        if document_ids_filter:
+            document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
+            where_clause += f"metadata_->>'document_id' IN ({document_ids})"
+
         score_threshold = kwargs.get("score_threshold") or 0.0
         request = gpdb_20160503_models.QueryCollectionDataRequest(
             dbinstance_id=self.config.instance_id,
@@ -245,7 +251,7 @@ class AnalyticdbVectorOpenAPI:
             vector=query_vector,
             content=None,
             top_k=kwargs.get("top_k", 4),
-            filter=None,
+            filter=where_clause,
         )
         response = self._client.query_collection_data(request)
         documents = []
@@ -265,6 +271,11 @@ class AnalyticdbVectorOpenAPI:
     def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
         from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
 
+        document_ids_filter = kwargs.get("document_ids_filter")
+        where_clause = ""
+        if document_ids_filter:
+            document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
+            where_clause += f"metadata_->>'document_id' IN ({document_ids})"
         score_threshold = float(kwargs.get("score_threshold") or 0.0)
         request = gpdb_20160503_models.QueryCollectionDataRequest(
             dbinstance_id=self.config.instance_id,
@@ -277,7 +288,7 @@ class AnalyticdbVectorOpenAPI:
             vector=None,
             content=query,
             top_k=kwargs.get("top_k", 4),
-            filter=None,
+            filter=where_clause,
         )
         response = self._client.query_collection_data(request)
         documents = []
diff --git a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py
index 44cc5d3e98..ad39717183 100644
--- a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py
+++ b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py
@@ -147,10 +147,17 @@ class ElasticSearchVector(BaseVector):
         return docs
 
     def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
-        query_str = {"match": {Field.CONTENT_KEY.value: query}}
+        query_str: dict[str, Any] = {"match": {Field.CONTENT_KEY.value: query}}
         document_ids_filter = kwargs.get("document_ids_filter")
+
         if document_ids_filter:
-            query_str["filter"] = {"terms": {"metadata.document_id": document_ids_filter}}  # type: ignore
+            query_str = {
+                "bool": {
+                    "must": {"match": {Field.CONTENT_KEY.value: query}},
+                    "filter": {"terms": {"metadata.document_id": document_ids_filter}},
+                }
+            }
+
         results = self._client.search(index=self._collection_name, query=query_str, size=kwargs.get("top_k", 4))
         docs = []
         for hit in results["hits"]["hits"]:
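The Elasticsearch change wraps the match in a proper `bool` query, since a top-level `filter` key next to `match` is not valid query DSL. The resulting request body for a filtered full-text search, assuming `Field.CONTENT_KEY.value` is `page_content`:

```python
# With document_ids_filter = ["doc-1", "doc-2"]:
query_str = {
    "bool": {
        "must": {"match": {"page_content": "refund policy"}},
        "filter": {"terms": {"metadata.document_id": ["doc-1", "doc-2"]}},
    }
}
# Filters are non-scoring: the terms clause narrows candidates while the
# match clause alone determines relevance ranking.
```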
diff --git a/api/core/rag/splitter/fixed_text_splitter.py b/api/core/rag/splitter/fixed_text_splitter.py
index 0fb1bcb2e0..bcaf299892 100644
--- a/api/core/rag/splitter/fixed_text_splitter.py
+++ b/api/core/rag/splitter/fixed_text_splitter.py
@@ -102,6 +102,7 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
                 splits = text.split()
             else:
                 splits = text.split(separator)
+                splits = [item + separator if i < len(splits) - 1 else item for i, item in enumerate(splits)]
         else:
             splits = list(text)
         splits = [s for s in splits if (s not in {"", "\n"})]
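The added comprehension re-attaches the separator to every chunk except the last (`i < len(splits) - 1`), so the separator is no longer lost when the text is split. A quick illustration:

```python
text = "alpha\n\nbeta\n\ngamma"
separator = "\n\n"
splits = text.split(separator)  # ["alpha", "beta", "gamma"]
splits = [item + separator if i < len(splits) - 1 else item for i, item in enumerate(splits)]
# ["alpha\n\n", "beta\n\n", "gamma"] — "".join(splits) reproduces the input
```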
diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py
index d61856a8f5..7822bc389c 100644
--- a/api/core/tools/tool_manager.py
+++ b/api/core/tools/tool_manager.py
@@ -21,7 +21,7 @@ from core.tools.plugin_tool.tool import PluginTool
 from core.tools.utils.uuid_utils import is_valid_uuid
 from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
 from core.workflow.entities.variable_pool import VariablePool
-from services.tools.mcp_tools_mange_service import MCPToolManageService
+from services.tools.mcp_tools_manage_service import MCPToolManageService
 
 if TYPE_CHECKING:
     from core.workflow.nodes.tool.entities import ToolEntity
diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py
index a4616eda69..704eb6a3ac 100644
--- a/api/core/workflow/nodes/agent/agent_node.py
+++ b/api/core/workflow/nodes/agent/agent_node.py
@@ -270,7 +270,14 @@ class AgentNode(BaseNode):
             )
 
         extra = tool.get("extra", {})
-        runtime_variable_pool = variable_pool if self._node_data.version != "1" else None
+
+        # This version check caused problems in the past.
+        # Logically, the node_data.version field should not be used for this
+        # decision, but the check is kept here for backward compatibility
+        # with historical data.
+        runtime_variable_pool: VariablePool | None = None
+        if node_data.version != "1" or node_data.tool_node_version != "1":
+            runtime_variable_pool = variable_pool
         tool_runtime = ToolManager.get_agent_tool_runtime(
             self.tenant_id, self.app_id, entity, self.invoke_from, runtime_variable_pool
         )
diff --git a/api/core/workflow/nodes/agent/entities.py b/api/core/workflow/nodes/agent/entities.py
index 075a41fb2f..11b11068e7 100644
--- a/api/core/workflow/nodes/agent/entities.py
+++ b/api/core/workflow/nodes/agent/entities.py
@@ -13,6 +13,10 @@ class AgentNodeData(BaseNodeData):
     agent_strategy_name: str
     agent_strategy_label: str  # redundancy
     memory: MemoryConfig | None = None
+    # The version of the tool parameters.
+    # If this value is None, the data comes from a previous version and must
+    # be parsed with the legacy parameter parsing rules.
+    tool_node_version: str | None = None
 
 class AgentInput(BaseModel):
     value: Union[list[str], list[ToolSelector], Any]
diff --git a/api/core/workflow/nodes/knowledge_retrieval/entities.py b/api/core/workflow/nodes/knowledge_retrieval/entities.py
index c7cc077054..7dbac7851d 100644
--- a/api/core/workflow/nodes/knowledge_retrieval/entities.py
+++ b/api/core/workflow/nodes/knowledge_retrieval/entities.py
@@ -117,7 +117,7 @@ class KnowledgeRetrievalNodeData(BaseNodeData):
     multiple_retrieval_config: Optional[MultipleRetrievalConfig] = None
     single_retrieval_config: Optional[SingleRetrievalConfig] = None
     metadata_filtering_mode: Optional[Literal["disabled", "automatic", "manual"]] = "disabled"
-    metadata_model_config: ModelConfig
+    metadata_model_config: Optional[ModelConfig] = None
     metadata_filtering_conditions: Optional[MetadataFilteringCondition] = None
     vision: VisionConfig = Field(default_factory=VisionConfig)
diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
index 4e9a38f552..5f092dc2f1 100644
--- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
+++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
@@ -509,6 +509,8 @@ class KnowledgeRetrievalNode(BaseNode):
         # get all metadata field
         metadata_fields = db.session.query(DatasetMetadata).filter(DatasetMetadata.dataset_id.in_(dataset_ids)).all()
         all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields]
+        if node_data.metadata_model_config is None:
+            raise ValueError("metadata_model_config is required")
         # get metadata model instance and fetch model config
         model_instance, model_config = self.get_model_config(node_data.metadata_model_config)
         # fetch prompt messages
@@ -701,7 +703,7 @@ class KnowledgeRetrievalNode(BaseNode):
         )
 
     def _get_prompt_template(self, node_data: KnowledgeRetrievalNodeData, metadata_fields: list, query: str):
-        model_mode = ModelMode(node_data.metadata_model_config.mode)
+        model_mode = ModelMode(node_data.metadata_model_config.mode)  # type: ignore
         input_text = query
 
         prompt_messages: list[LLMNodeChatModelMessage] = []
diff --git a/api/core/workflow/nodes/node_mapping.py b/api/core/workflow/nodes/node_mapping.py
index 7e043253eb..5778f89ac3 100644
--- a/api/core/workflow/nodes/node_mapping.py
+++ b/api/core/workflow/nodes/node_mapping.py
@@ -75,6 +75,9 @@ NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[BaseNode]]] = {
     },
     NodeType.TOOL: {
         LATEST_VERSION: ToolNode,
+        # Mapping two version keys to the same class caused problems in the past.
+        # Logically there should be a single entry, but both keys are retained
+        # for compatibility with historical data.
         "2": ToolNode,
         "1": ToolNode,
     },
@@ -125,6 +128,9 @@ NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[BaseNode]]] = {
     },
     NodeType.AGENT: {
         LATEST_VERSION: AgentNode,
+        # Mapping two version keys to the same class caused problems in the past.
+        # Logically there should be a single entry, but both keys are retained
+        # for compatibility with historical data.
         "2": AgentNode,
         "1": AgentNode,
     },
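Taken together, the node-level `version` and the new `tool_node_version` field decide whether the runtime variable pool is passed through. A sketch that mirrors the gate exactly as written in `agent_node.py` and `tool_node.py`:

```python
def uses_runtime_variable_pool(version: str, tool_node_version: str | None) -> bool:
    # Mirrors the condition added in this diff: only data where both version
    # fields equal "1" keeps the legacy behaviour (no variable pool).
    return version != "1" or tool_node_version != "1"


uses_runtime_variable_pool("2", "2")   # True  -> pool passed through
uses_runtime_variable_pool("1", "1")   # False -> legacy parameter parsing
uses_runtime_variable_pool("1", None)  # True  (None != "1")
```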
"2": AgentNode, "1": AgentNode, }, diff --git a/api/core/workflow/nodes/tool/entities.py b/api/core/workflow/nodes/tool/entities.py index 88c5160d14..f0a44d919b 100644 --- a/api/core/workflow/nodes/tool/entities.py +++ b/api/core/workflow/nodes/tool/entities.py @@ -59,6 +59,10 @@ class ToolNodeData(BaseNodeData, ToolEntity): return typ tool_parameters: dict[str, ToolInput] + # The version of the tool parameter. + # If this value is None, it indicates this is a previous version + # and requires using the legacy parameter parsing rules. + tool_node_version: str | None = None @field_validator("tool_parameters", mode="before") @classmethod diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index c565ad15c1..140fe71f60 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -70,7 +70,13 @@ class ToolNode(BaseNode): try: from core.tools.tool_manager import ToolManager - variable_pool = self.graph_runtime_state.variable_pool if self._node_data.version != "1" else None + # This is an issue that caused problems before. + # Logically, we shouldn't use the node_data.version field for judgment + # But for backward compatibility with historical data + # this version field judgment is still preserved here. + variable_pool: VariablePool | None = None + if node_data.version != "1" or node_data.tool_node_version != "1": + variable_pool = self.graph_runtime_state.variable_pool tool_runtime = ToolManager.get_workflow_tool_runtime( self.tenant_id, self.app_id, self.node_id, self._node_data, self.invoke_from, variable_pool ) diff --git a/api/core/workflow/workflow_cycle_manager.py b/api/core/workflow/workflow_cycle_manager.py index 3e591ef885..f844aada95 100644 --- a/api/core/workflow/workflow_cycle_manager.py +++ b/api/core/workflow/workflow_cycle_manager.py @@ -1,6 +1,6 @@ from collections.abc import Mapping from dataclasses import dataclass -from datetime import UTC, datetime +from datetime import datetime from typing import Any, Optional, Union from uuid import uuid4 @@ -71,7 +71,7 @@ class WorkflowCycleManager: workflow_version=self._workflow_info.version, graph=self._workflow_info.graph_data, inputs=inputs, - started_at=datetime.now(UTC).replace(tzinfo=None), + started_at=naive_utc_now(), ) return self._save_and_cache_workflow_execution(execution) @@ -356,7 +356,7 @@ class WorkflowCycleManager: created_at: Optional[datetime] = None, ) -> WorkflowNodeExecution: """Create a node execution from an event.""" - now = datetime.now(UTC).replace(tzinfo=None) + now = naive_utc_now() created_at = created_at or now metadata = { @@ -403,7 +403,7 @@ class WorkflowCycleManager: handle_special_values: bool = False, ) -> None: """Update node execution with completion data.""" - finished_at = datetime.now(UTC).replace(tzinfo=None) + finished_at = naive_utc_now() elapsed_time = (finished_at - event.start_at).total_seconds() # Process data diff --git a/api/events/event_handlers/create_document_index.py b/api/events/event_handlers/create_document_index.py index 8a677f6b6f..cb48bd92a0 100644 --- a/api/events/event_handlers/create_document_index.py +++ b/api/events/event_handlers/create_document_index.py @@ -1,4 +1,3 @@ -import datetime import logging import time @@ -8,6 +7,7 @@ from werkzeug.exceptions import NotFound from core.indexing_runner import DocumentIsPausedError, IndexingRunner from events.event_handlers.document_index_event import document_index_created from extensions.ext_database import db +from libs.datetime_utils 
diff --git a/api/events/event_handlers/create_document_index.py b/api/events/event_handlers/create_document_index.py
index 8a677f6b6f..cb48bd92a0 100644
--- a/api/events/event_handlers/create_document_index.py
+++ b/api/events/event_handlers/create_document_index.py
@@ -1,4 +1,3 @@
-import datetime
 import logging
 import time
 
@@ -8,6 +7,7 @@ from werkzeug.exceptions import NotFound
 
 from core.indexing_runner import DocumentIsPausedError, IndexingRunner
 from events.event_handlers.document_index_event import document_index_created
 from extensions.ext_database import db
+from libs.datetime_utils import naive_utc_now
 from models.dataset import Document
 
@@ -33,7 +33,7 @@ def handle(sender, **kwargs):
             raise NotFound("Document not found")
 
         document.indexing_status = "parsing"
-        document.processing_started_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+        document.processing_started_at = naive_utc_now()
         documents.append(document)
         db.session.add(document)
     db.session.commit()
diff --git a/api/extensions/storage/azure_blob_storage.py b/api/extensions/storage/azure_blob_storage.py
index 7448fd4a6b..81eec94da4 100644
--- a/api/extensions/storage/azure_blob_storage.py
+++ b/api/extensions/storage/azure_blob_storage.py
@@ -1,5 +1,5 @@
 from collections.abc import Generator
-from datetime import UTC, datetime, timedelta
+from datetime import timedelta
 from typing import Optional
 
 from azure.identity import ChainedTokenCredential, DefaultAzureCredential
@@ -8,6 +8,7 @@ from azure.storage.blob import AccountSasPermissions, BlobServiceClient, Resourc
 from configs import dify_config
 from extensions.ext_redis import redis_client
 from extensions.storage.base_storage import BaseStorage
+from libs.datetime_utils import naive_utc_now
 
 
 class AzureBlobStorage(BaseStorage):
@@ -78,7 +79,7 @@ class AzureBlobStorage(BaseStorage):
             account_key=self.account_key or "",
             resource_types=ResourceTypes(service=True, container=True, object=True),
             permission=AccountSasPermissions(read=True, write=True, delete=True, list=True, add=True, create=True),
-            expiry=datetime.now(UTC).replace(tzinfo=None) + timedelta(hours=1),
+            expiry=naive_utc_now() + timedelta(hours=1),
         )
         redis_client.set(cache_key, sas_token, ex=3000)
         return BlobServiceClient(account_url=self.account_url or "", credential=sas_token)
Please verify the file.") - file_type = ( - FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM.value else detected_file_type - ) + file_type = FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM else detected_file_type return File( id=mapping.get("id"), diff --git a/api/libs/oauth_data_source.py b/api/libs/oauth_data_source.py index 218109522d..78f827584c 100644 --- a/api/libs/oauth_data_source.py +++ b/api/libs/oauth_data_source.py @@ -1,4 +1,3 @@ -import datetime import urllib.parse from typing import Any @@ -6,6 +5,7 @@ import requests from flask_login import current_user from extensions.ext_database import db +from libs.datetime_utils import naive_utc_now from models.source import DataSourceOauthBinding @@ -75,7 +75,7 @@ class NotionOAuth(OAuthDataSource): if data_source_binding: data_source_binding.source_info = source_info data_source_binding.disabled = False - data_source_binding.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + data_source_binding.updated_at = naive_utc_now() db.session.commit() else: new_data_source_binding = DataSourceOauthBinding( @@ -115,7 +115,7 @@ class NotionOAuth(OAuthDataSource): if data_source_binding: data_source_binding.source_info = source_info data_source_binding.disabled = False - data_source_binding.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + data_source_binding.updated_at = naive_utc_now() db.session.commit() else: new_data_source_binding = DataSourceOauthBinding( @@ -154,7 +154,7 @@ class NotionOAuth(OAuthDataSource): } data_source_binding.source_info = new_source_info data_source_binding.disabled = False - data_source_binding.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + data_source_binding.updated_at = naive_utc_now() db.session.commit() else: raise ValueError("Data source binding not found") diff --git a/api/migrations/versions/2025_07_21_0935-1a83934ad6d1_update_models.py b/api/migrations/versions/2025_07_21_0935-1a83934ad6d1_update_models.py new file mode 100644 index 0000000000..3bdbafda7c --- /dev/null +++ b/api/migrations/versions/2025_07_21_0935-1a83934ad6d1_update_models.py @@ -0,0 +1,51 @@ +"""update models + +Revision ID: 1a83934ad6d1 +Revises: 71f5020c6470 +Create Date: 2025-07-21 09:35:48.774794 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '1a83934ad6d1' +down_revision = '71f5020c6470' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tool_mcp_providers', schema=None) as batch_op: + batch_op.alter_column('server_identifier', + existing_type=sa.VARCHAR(length=24), + type_=sa.String(length=64), + existing_nullable=False) + + with op.batch_alter_table('tool_model_invokes', schema=None) as batch_op: + batch_op.alter_column('tool_name', + existing_type=sa.VARCHAR(length=40), + type_=sa.String(length=128), + existing_nullable=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
diff --git a/api/migrations/versions/2025_07_21_0935-1a83934ad6d1_update_models.py b/api/migrations/versions/2025_07_21_0935-1a83934ad6d1_update_models.py
new file mode 100644
index 0000000000..3bdbafda7c
--- /dev/null
+++ b/api/migrations/versions/2025_07_21_0935-1a83934ad6d1_update_models.py
@@ -0,0 +1,51 @@
+"""update models
+
+Revision ID: 1a83934ad6d1
+Revises: 71f5020c6470
+Create Date: 2025-07-21 09:35:48.774794
+
+"""
+from alembic import op
+import models as models
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision = '1a83934ad6d1'
+down_revision = '71f5020c6470'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('tool_mcp_providers', schema=None) as batch_op:
+        batch_op.alter_column('server_identifier',
+               existing_type=sa.VARCHAR(length=24),
+               type_=sa.String(length=64),
+               existing_nullable=False)
+
+    with op.batch_alter_table('tool_model_invokes', schema=None) as batch_op:
+        batch_op.alter_column('tool_name',
+               existing_type=sa.VARCHAR(length=40),
+               type_=sa.String(length=128),
+               existing_nullable=False)
+
+    # ### end Alembic commands ###
+
+
+def downgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('tool_model_invokes', schema=None) as batch_op:
+        batch_op.alter_column('tool_name',
+               existing_type=sa.String(length=128),
+               type_=sa.VARCHAR(length=40),
+               existing_nullable=False)
+
+    with op.batch_alter_table('tool_mcp_providers', schema=None) as batch_op:
+        batch_op.alter_column('server_identifier',
+               existing_type=sa.String(length=64),
+               type_=sa.VARCHAR(length=24),
+               existing_nullable=False)
+
+    # ### end Alembic commands ###
diff --git a/api/models/dataset.py b/api/models/dataset.py
index 5017472e89..a26788df0d 100644
--- a/api/models/dataset.py
+++ b/api/models/dataset.py
@@ -287,7 +287,7 @@ class Dataset(Base):
     @staticmethod
     def gen_collection_name_by_id(dataset_id: str) -> str:
         normalized_dataset_id = dataset_id.replace("-", "_")
-        return f"Vector_index_{normalized_dataset_id}_Node"
+        return f"{dify_config.VECTOR_INDEX_NAME_PREFIX}_{normalized_dataset_id}_Node"
 
 
 class DatasetProcessRule(Base):
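This is where the new `VECTOR_INDEX_NAME_PREFIX` setting from `api/.env.example` and `api/configs/middleware/__init__.py` takes effect. With the default value the generated name is unchanged from the old hard-coded format; the dataset id below is illustrative:

```python
dataset_id = "1b54b1a4-29bb-4a7b-8d8f-5f6a3f0c2f10"  # illustrative id
normalized = dataset_id.replace("-", "_")

# VECTOR_INDEX_NAME_PREFIX=Vector_index (default)
# -> "Vector_index_1b54b1a4_29bb_4a7b_8d8f_5f6a3f0c2f10_Node"  (old behaviour)
# VECTOR_INDEX_NAME_PREFIX=my_tenant
# -> "my_tenant_1b54b1a4_29bb_4a7b_8d8f_5f6a3f0c2f10_Node"
print(f"Vector_index_{normalized}_Node")
```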
diff --git a/api/models/task.py b/api/models/task.py
index d853c1dd9a..1a4b606ff5 100644
--- a/api/models/task.py
+++ b/api/models/task.py
@@ -1,7 +1,6 @@
-from datetime import UTC, datetime
-
 from celery import states  # type: ignore
 
+from libs.datetime_utils import naive_utc_now
 from models.base import Base
 
 from .engine import db
@@ -18,8 +17,8 @@ class CeleryTask(Base):
     result = db.Column(db.PickleType, nullable=True)
     date_done = db.Column(
         db.DateTime,
-        default=lambda: datetime.now(UTC).replace(tzinfo=None),
-        onupdate=lambda: datetime.now(UTC).replace(tzinfo=None),
+        default=lambda: naive_utc_now(),
+        onupdate=lambda: naive_utc_now(),
         nullable=True,
     )
     traceback = db.Column(db.Text, nullable=True)
@@ -39,4 +38,4 @@ class CeleryTaskSet(Base):
     id = db.Column(db.Integer, db.Sequence("taskset_id_sequence"), autoincrement=True, primary_key=True)
     taskset_id = db.Column(db.String(155), unique=True)
     result = db.Column(db.PickleType, nullable=True)
-    date_done = db.Column(db.DateTime, default=lambda: datetime.now(UTC).replace(tzinfo=None), nullable=True)
+    date_done = db.Column(db.DateTime, default=lambda: naive_utc_now(), nullable=True)
diff --git a/api/models/tools.py b/api/models/tools.py
index 1ed080ea23..2f94b4bb87 100644
--- a/api/models/tools.py
+++ b/api/models/tools.py
@@ -253,7 +253,7 @@ class MCPToolProvider(Base):
     # name of the mcp provider
     name: Mapped[str] = mapped_column(db.String(40), nullable=False)
     # server identifier of the mcp provider
-    server_identifier: Mapped[str] = mapped_column(db.String(24), nullable=False)
+    server_identifier: Mapped[str] = mapped_column(db.String(64), nullable=False)
     # encrypted url of the mcp provider
     server_url: Mapped[str] = mapped_column(db.Text, nullable=False)
     # hash of server_url for uniqueness check
@@ -357,7 +357,7 @@ class ToolModelInvoke(Base):
     # type
     tool_type = db.Column(db.String(40), nullable=False)
     # tool name
-    tool_name = db.Column(db.String(40), nullable=False)
+    tool_name = db.Column(db.String(128), nullable=False)
     # invoke parameters
     model_parameters = db.Column(db.Text, nullable=False)
     # prompt messages
diff --git a/api/models/workflow.py b/api/models/workflow.py
index e36b0f4ecf..dd123478f8 100644
--- a/api/models/workflow.py
+++ b/api/models/workflow.py
@@ -1,7 +1,7 @@
 import json
 import logging
 from collections.abc import Mapping, Sequence
-from datetime import UTC, datetime
+from datetime import datetime
 from enum import Enum, StrEnum
 from typing import TYPE_CHECKING, Any, Optional, Union
 from uuid import uuid4
@@ -16,6 +16,7 @@ from core.variables.variables import FloatVariable, IntegerVariable, StringVaria
 from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
 from core.workflow.nodes.enums import NodeType
 from factories.variable_factory import TypeMismatchError, build_segment_with_type
+from libs.datetime_utils import naive_utc_now
 from libs.helper import extract_tenant_id
 
 from ._workflow_exc import NodeNotFoundError, WorkflowDataError
@@ -139,7 +140,7 @@ class Workflow(Base):
     updated_at: Mapped[datetime] = mapped_column(
         db.DateTime,
         nullable=False,
-        default=datetime.now(UTC).replace(tzinfo=None),
+        default=naive_utc_now(),
         server_onupdate=func.current_timestamp(),
     )
     _environment_variables: Mapped[str] = mapped_column(
@@ -185,7 +186,7 @@ class Workflow(Base):
         workflow.rag_pipeline_variables = rag_pipeline_variables or []
         workflow.marked_name = marked_name
         workflow.marked_comment = marked_comment
-        workflow.created_at = datetime.now(UTC).replace(tzinfo=None)
+        workflow.created_at = naive_utc_now()
         workflow.updated_at = workflow.created_at
 
         return workflow
@@ -938,7 +939,7 @@ _EDITABLE_SYSTEM_VARIABLE = frozenset(["query", "files"])
 
 
 def _naive_utc_datetime():
-    return datetime.now(UTC).replace(tzinfo=None)
+    return naive_utc_now()
 
 
 class WorkflowDraftVariable(Base):
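Note the two different shapes the helper takes in these models, each preserving the evaluation semantics of the expression it replaces: `default=lambda: naive_utc_now()` (as in `models/task.py`) is called per INSERT, while `default=naive_utc_now()` (as in `Workflow.updated_at`, like the expression before it) is evaluated once at class-definition time. A minimal illustration of the SQLAlchemy behavior, with hypothetical column names:

```python
from sqlalchemy import Column, DateTime

from libs.datetime_utils import naive_utc_now

# Evaluated once, when the class body runs: every row gets the import-time timestamp.
created_eagerly = Column(DateTime, default=naive_utc_now())

# Evaluated per INSERT: SQLAlchemy calls the function for each new row.
created_lazily = Column(DateTime, default=naive_utc_now)
```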
operator, "remove"): - raise NoPermissionError("No permission to dissolve tenant.") - db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id).delete() - db.session.delete(tenant) - db.session.commit() - @staticmethod def get_custom_config(tenant_id: str) -> dict: tenant = db.get_or_404(Tenant, tenant_id) @@ -1117,7 +1109,7 @@ class RegisterService: ) account.last_login_ip = ip_address - account.initialized_at = datetime.now(UTC).replace(tzinfo=None) + account.initialized_at = naive_utc_now() TenantService.create_owner_tenant_if_not_exist(account=account, is_setup=True) @@ -1158,7 +1150,7 @@ class RegisterService: is_setup=is_setup, ) account.status = AccountStatus.ACTIVE.value if not status else status.value - account.initialized_at = datetime.now(UTC).replace(tzinfo=None) + account.initialized_at = naive_utc_now() if open_id is not None and provider is not None: AccountService.link_account_integrate(provider, open_id, account) diff --git a/api/services/app_service.py b/api/services/app_service.py index 0a08f345df..3494b2796b 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -1,6 +1,5 @@ import json import logging -from datetime import UTC, datetime from typing import Optional, cast from flask_login import current_user @@ -17,6 +16,7 @@ from core.tools.tool_manager import ToolManager from core.tools.utils.configuration import ToolParameterConfigurationManager from events.app_event import app_was_created from extensions.ext_database import db +from libs.datetime_utils import naive_utc_now from models.account import Account from models.model import App, AppMode, AppModelConfig, Site from models.tools import ApiToolProvider @@ -235,7 +235,7 @@ class AppService: app.use_icon_as_answer_icon = args.get("use_icon_as_answer_icon", False) app.max_active_requests = args.get("max_active_requests") app.updated_by = current_user.id - app.updated_at = datetime.now(UTC).replace(tzinfo=None) + app.updated_at = naive_utc_now() db.session.commit() return app @@ -249,7 +249,7 @@ class AppService: """ app.name = name app.updated_by = current_user.id - app.updated_at = datetime.now(UTC).replace(tzinfo=None) + app.updated_at = naive_utc_now() db.session.commit() return app @@ -265,7 +265,7 @@ class AppService: app.icon = icon app.icon_background = icon_background app.updated_by = current_user.id - app.updated_at = datetime.now(UTC).replace(tzinfo=None) + app.updated_at = naive_utc_now() db.session.commit() return app @@ -282,7 +282,7 @@ class AppService: app.enable_site = enable_site app.updated_by = current_user.id - app.updated_at = datetime.now(UTC).replace(tzinfo=None) + app.updated_at = naive_utc_now() db.session.commit() return app @@ -299,7 +299,7 @@ class AppService: app.enable_api = enable_api app.updated_by = current_user.id - app.updated_at = datetime.now(UTC).replace(tzinfo=None) + app.updated_at = naive_utc_now() db.session.commit() return app diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py index afdaa49465..40097d5ed5 100644 --- a/api/services/conversation_service.py +++ b/api/services/conversation_service.py @@ -1,5 +1,4 @@ from collections.abc import Callable, Sequence -from datetime import UTC, datetime from typing import Optional, Union from sqlalchemy import asc, desc, func, or_, select @@ -8,6 +7,7 @@ from sqlalchemy.orm import Session from core.app.entities.app_invoke_entities import InvokeFrom from core.llm_generator.llm_generator import LLMGenerator from extensions.ext_database import db +from 
libs.datetime_utils import naive_utc_now from libs.infinite_scroll_pagination import InfiniteScrollPagination from models import ConversationVariable from models.account import Account @@ -113,7 +113,7 @@ class ConversationService: return cls.auto_generate_name(app_model, conversation) else: conversation.name = name - conversation.updated_at = datetime.now(UTC).replace(tzinfo=None) + conversation.updated_at = naive_utc_now() db.session.commit() return conversation @@ -169,7 +169,7 @@ class ConversationService: conversation = cls.get_conversation(app_model, conversation_id, user) conversation.is_deleted = True - conversation.updated_at = datetime.now(UTC).replace(tzinfo=None) + conversation.updated_at = naive_utc_now() db.session.commit() @classmethod diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 09dced8dba..924006e601 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -26,6 +26,7 @@ from events.document_event import document_was_deleted from extensions.ext_database import db from extensions.ext_redis import redis_client from libs import helper +from libs.datetime_utils import naive_utc_now from models.account import Account, TenantAccountRole from models.dataset import ( AppDatasetJoin, @@ -484,7 +485,7 @@ class DatasetService: # Add metadata fields filtered_data["updated_by"] = user.id - filtered_data["updated_at"] = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + filtered_data["updated_at"] = naive_utc_now() # update Retrieval model filtered_data["retrieval_model"] = data["retrieval_model"] # update icon info @@ -1175,7 +1176,7 @@ class DocumentService: # update document to be paused document.is_paused = True document.paused_by = current_user.id - document.paused_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + document.paused_at = naive_utc_now() db.session.add(document) db.session.commit() diff --git a/api/services/external_knowledge_service.py b/api/services/external_knowledge_service.py index eb50d79494..06a4c22117 100644 --- a/api/services/external_knowledge_service.py +++ b/api/services/external_knowledge_service.py @@ -1,6 +1,5 @@ import json from copy import deepcopy -from datetime import UTC, datetime from typing import Any, Optional, Union, cast from urllib.parse import urlparse @@ -11,6 +10,7 @@ from constants import HIDDEN_VALUE from core.helper import ssrf_proxy from core.rag.entities.metadata_entities import MetadataCondition from extensions.ext_database import db +from libs.datetime_utils import naive_utc_now from models.dataset import ( Dataset, ExternalKnowledgeApis, @@ -120,7 +120,7 @@ class ExternalDatasetService: external_knowledge_api.description = args.get("description", "") external_knowledge_api.settings = json.dumps(args.get("settings"), ensure_ascii=False) external_knowledge_api.updated_by = user_id - external_knowledge_api.updated_at = datetime.now(UTC).replace(tzinfo=None) + external_knowledge_api.updated_at = naive_utc_now() db.session.commit() return external_knowledge_api diff --git a/api/services/tools/mcp_tools_mange_service.py b/api/services/tools/mcp_tools_manage_service.py similarity index 95% rename from api/services/tools/mcp_tools_mange_service.py rename to api/services/tools/mcp_tools_manage_service.py index fda6da5983..e0e256912e 100644 --- a/api/services/tools/mcp_tools_mange_service.py +++ b/api/services/tools/mcp_tools_manage_service.py @@ -70,16 +70,15 @@ class MCPToolManageService: MCPToolProvider.server_url_hash == server_url_hash, 
MCPToolProvider.server_identifier == server_identifier, ), - MCPToolProvider.tenant_id == tenant_id, ) .first() ) if existing_provider: if existing_provider.name == name: raise ValueError(f"MCP tool {name} already exists") - elif existing_provider.server_url_hash == server_url_hash: + if existing_provider.server_url_hash == server_url_hash: raise ValueError(f"MCP tool {server_url} already exists") - elif existing_provider.server_identifier == server_identifier: + if existing_provider.server_identifier == server_identifier: raise ValueError(f"MCP tool {server_identifier} already exists") encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url) mcp_tool = MCPToolProvider( @@ -111,15 +110,14 @@ class MCPToolManageService: ] @classmethod - def list_mcp_tool_from_remote_server(cls, tenant_id: str, provider_id: str): + def list_mcp_tool_from_remote_server(cls, tenant_id: str, provider_id: str) -> ToolProviderApiEntity: mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id) - try: with MCPClient( mcp_provider.decrypted_server_url, provider_id, tenant_id, authed=mcp_provider.authed, for_list=True ) as mcp_client: tools = mcp_client.list_tools() - except MCPAuthError as e: + except MCPAuthError: raise ValueError("Please auth the tool first") except MCPError as e: raise ValueError(f"Failed to connect to MCP server: {e}") @@ -184,12 +182,11 @@ class MCPToolManageService: error_msg = str(e.orig) if "unique_mcp_provider_name" in error_msg: raise ValueError(f"MCP tool {name} already exists") - elif "unique_mcp_provider_server_url" in error_msg: + if "unique_mcp_provider_server_url" in error_msg: raise ValueError(f"MCP tool {server_url} already exists") - elif "unique_mcp_provider_server_identifier" in error_msg: + if "unique_mcp_provider_server_identifier" in error_msg: raise ValueError(f"MCP tool {server_identifier} already exists") - else: - raise + raise @classmethod def update_mcp_provider_credentials( diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 0496c33925..89bb504437 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -2,7 +2,6 @@ import json import time import uuid from collections.abc import Callable, Generator, Mapping, Sequence -from datetime import UTC, datetime from typing import Any, Optional, cast from uuid import uuid4 @@ -33,6 +32,7 @@ from core.workflow.workflow_entry import WorkflowEntry from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated from extensions.ext_database import db from factories.file_factory import build_from_mapping, build_from_mappings +from libs.datetime_utils import naive_utc_now from models.account import Account from models.model import App, AppMode from models.tools import WorkflowToolProvider @@ -232,7 +232,7 @@ class WorkflowService: workflow.graph = json.dumps(graph) workflow.features = json.dumps(features) workflow.updated_by = account.id - workflow.updated_at = datetime.now(UTC).replace(tzinfo=None) + workflow.updated_at = naive_utc_now() workflow.environment_variables = environment_variables workflow.conversation_variables = conversation_variables @@ -268,7 +268,7 @@ class WorkflowService: tenant_id=app_model.tenant_id, app_id=app_model.id, type=draft_workflow.type, - version=Workflow.version_from_datetime(datetime.now(UTC).replace(tzinfo=None)), + version=Workflow.version_from_datetime(naive_utc_now()), graph=draft_workflow.graph, created_by=account.id, 
environment_variables=draft_workflow.environment_variables,
@@ -524,8 +524,8 @@ class WorkflowService:
             node_type=node.type_,
             title=node.title,
             elapsed_time=time.perf_counter() - start_at,
-            created_at=datetime.now(UTC).replace(tzinfo=None),
-            finished_at=datetime.now(UTC).replace(tzinfo=None),
+            created_at=naive_utc_now(),
+            finished_at=naive_utc_now(),
         )

         if run_succeeded and node_run_result:
@@ -622,7 +622,7 @@ class WorkflowService:
             setattr(workflow, field, value)

         workflow.updated_by = account_id
-        workflow.updated_at = datetime.now(UTC).replace(tzinfo=None)
+        workflow.updated_at = naive_utc_now()

         return workflow

diff --git a/api/tasks/document_indexing_task.py b/api/tasks/document_indexing_task.py
index 55cac6a9af..a85aab0bb7 100644
--- a/api/tasks/document_indexing_task.py
+++ b/api/tasks/document_indexing_task.py
@@ -1,4 +1,3 @@
-import datetime
 import logging
 import time

@@ -8,6 +7,7 @@ from celery import shared_task  # type: ignore
 from configs import dify_config
 from core.indexing_runner import DocumentIsPausedError, IndexingRunner
 from extensions.ext_database import db
+from libs.datetime_utils import naive_utc_now
 from models.dataset import Dataset, Document
 from services.feature_service import FeatureService

@@ -53,7 +53,7 @@ def document_indexing_task(dataset_id: str, document_ids: list):
             if document:
                 document.indexing_status = "error"
                 document.error = str(e)
-                document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+                document.stopped_at = naive_utc_now()
                 db.session.add(document)
                 db.session.commit()
                 db.session.close()
@@ -68,7 +68,7 @@ def document_indexing_task(dataset_id: str, document_ids: list):
         if document:
             document.indexing_status = "parsing"
-            document.processing_started_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+            document.processing_started_at = naive_utc_now()
             documents.append(document)
             db.session.add(document)
         db.session.commit()
diff --git a/api/tests/unit_tests/conftest.py b/api/tests/unit_tests/conftest.py
index 077ffe3408..f484fb22d3 100644
--- a/api/tests/unit_tests/conftest.py
+++ b/api/tests/unit_tests/conftest.py
@@ -26,8 +26,15 @@ redis_mock.hgetall = MagicMock(return_value={})
 redis_mock.hdel = MagicMock()
 redis_mock.incr = MagicMock(return_value=1)

+# Add the API directory to Python path to ensure proper imports
+import sys
+
+sys.path.insert(0, PROJECT_DIR)
+
 # apply the mock to the Redis client in the Flask app
-redis_patcher = patch("extensions.ext_redis.redis_client", redis_mock)
+from extensions import ext_redis
+
+redis_patcher = patch.object(ext_redis, "redis_client", redis_mock)
 redis_patcher.start()
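
Note: the conftest change above matters once PROJECT_DIR is inserted into `sys.path`: the string form of `patch` re-imports the module by name at patch time and can land on a second copy of the package, while `patch.object` pins the attribute on the module object already in use. A standalone illustration of the difference (assumed snippet, not part of the diff):

    from unittest.mock import MagicMock, patch

    from extensions import ext_redis

    redis_mock = MagicMock()

    # String target: resolves "extensions.ext_redis" by import at patch time,
    # which may be a different module object if sys.path offers duplicates.
    by_name = patch("extensions.ext_redis.redis_client", redis_mock)

    # Object target: patches the attribute on this exact module instance,
    # the same one the application code already imported.
    by_object = patch.object(ext_redis, "redis_client", redis_mock)

diff --git a/api/tests/unit_tests/services/auth/__init__.py b/api/tests/unit_tests/services/auth/__init__.py
new file mode 100644
index 0000000000..852a892730
--- /dev/null
+++ b/api/tests/unit_tests/services/auth/__init__.py
@@ -0,0 +1 @@
+# API authentication service test module
diff --git a/api/tests/unit_tests/services/auth/test_api_key_auth_base.py b/api/tests/unit_tests/services/auth/test_api_key_auth_base.py
new file mode 100644
index 0000000000..b5d91ef3fb
--- /dev/null
+++ b/api/tests/unit_tests/services/auth/test_api_key_auth_base.py
@@ -0,0 +1,49 @@
+import pytest
+
+from services.auth.api_key_auth_base import ApiKeyAuthBase
+
+
+class ConcreteApiKeyAuth(ApiKeyAuthBase):
+    """Concrete implementation for testing abstract base class"""
+
+    def validate_credentials(self):
+        return True
+
+
+class TestApiKeyAuthBase:
+    def test_should_store_credentials_on_init(self):
+        """Test that credentials are 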
properly stored during initialization""" + credentials = {"api_key": "test_key", "auth_type": "bearer"} + auth = ConcreteApiKeyAuth(credentials) + assert auth.credentials == credentials + + def test_should_not_instantiate_abstract_class(self): + """Test that ApiKeyAuthBase cannot be instantiated directly""" + credentials = {"api_key": "test_key"} + + with pytest.raises(TypeError) as exc_info: + ApiKeyAuthBase(credentials) + + assert "Can't instantiate abstract class" in str(exc_info.value) + assert "validate_credentials" in str(exc_info.value) + + def test_should_allow_subclass_implementation(self): + """Test that subclasses can properly implement the abstract method""" + credentials = {"api_key": "test_key", "auth_type": "bearer"} + auth = ConcreteApiKeyAuth(credentials) + + # Should not raise any exception + result = auth.validate_credentials() + assert result is True + + def test_should_handle_empty_credentials(self): + """Test initialization with empty credentials""" + credentials = {} + auth = ConcreteApiKeyAuth(credentials) + assert auth.credentials == {} + + def test_should_handle_none_credentials(self): + """Test initialization with None credentials""" + credentials = None + auth = ConcreteApiKeyAuth(credentials) + assert auth.credentials is None diff --git a/api/tests/unit_tests/services/auth/test_api_key_auth_factory.py b/api/tests/unit_tests/services/auth/test_api_key_auth_factory.py new file mode 100644 index 0000000000..9d9cb7c6d5 --- /dev/null +++ b/api/tests/unit_tests/services/auth/test_api_key_auth_factory.py @@ -0,0 +1,81 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from services.auth.api_key_auth_factory import ApiKeyAuthFactory +from services.auth.auth_type import AuthType + + +class TestApiKeyAuthFactory: + """Test cases for ApiKeyAuthFactory""" + + @pytest.mark.parametrize( + ("provider", "auth_class_path"), + [ + (AuthType.FIRECRAWL, "services.auth.firecrawl.firecrawl.FirecrawlAuth"), + (AuthType.WATERCRAWL, "services.auth.watercrawl.watercrawl.WatercrawlAuth"), + (AuthType.JINA, "services.auth.jina.jina.JinaAuth"), + ], + ) + def test_get_apikey_auth_factory_valid_providers(self, provider, auth_class_path): + """Test getting auth factory for all valid providers""" + with patch(auth_class_path) as mock_auth: + auth_class = ApiKeyAuthFactory.get_apikey_auth_factory(provider) + assert auth_class == mock_auth + + @pytest.mark.parametrize( + "invalid_provider", + [ + "invalid_provider", + "", + None, + 123, + "UNSUPPORTED", + ], + ) + def test_get_apikey_auth_factory_invalid_providers(self, invalid_provider): + """Test getting auth factory with various invalid providers""" + with pytest.raises(ValueError) as exc_info: + ApiKeyAuthFactory.get_apikey_auth_factory(invalid_provider) + assert str(exc_info.value) == "Invalid provider" + + @pytest.mark.parametrize( + ("credentials_return_value", "expected_result"), + [ + (True, True), + (False, False), + ], + ) + @patch("services.auth.api_key_auth_factory.ApiKeyAuthFactory.get_apikey_auth_factory") + def test_validate_credentials_delegates_to_auth_instance( + self, mock_get_factory, credentials_return_value, expected_result + ): + """Test that validate_credentials delegates to auth instance correctly""" + # Arrange + mock_auth_instance = MagicMock() + mock_auth_instance.validate_credentials.return_value = credentials_return_value + mock_auth_class = MagicMock(return_value=mock_auth_instance) + mock_get_factory.return_value = mock_auth_class + + # Act + factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, 
{"api_key": "test_key"}) + result = factory.validate_credentials() + + # Assert + assert result is expected_result + mock_auth_instance.validate_credentials.assert_called_once() + + @patch("services.auth.api_key_auth_factory.ApiKeyAuthFactory.get_apikey_auth_factory") + def test_validate_credentials_propagates_exceptions(self, mock_get_factory): + """Test that exceptions from auth instance are propagated""" + # Arrange + mock_auth_instance = MagicMock() + mock_auth_instance.validate_credentials.side_effect = Exception("Authentication error") + mock_auth_class = MagicMock(return_value=mock_auth_instance) + mock_get_factory.return_value = mock_auth_class + + # Act & Assert + factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, {"api_key": "test_key"}) + with pytest.raises(Exception) as exc_info: + factory.validate_credentials() + assert str(exc_info.value) == "Authentication error" diff --git a/api/tests/unit_tests/services/auth/test_api_key_auth_service.py b/api/tests/unit_tests/services/auth/test_api_key_auth_service.py new file mode 100644 index 0000000000..f0e425e742 --- /dev/null +++ b/api/tests/unit_tests/services/auth/test_api_key_auth_service.py @@ -0,0 +1,382 @@ +import json +from unittest.mock import Mock, patch + +import pytest + +from models.source import DataSourceApiKeyAuthBinding +from services.auth.api_key_auth_service import ApiKeyAuthService + + +class TestApiKeyAuthService: + """API key authentication service security tests""" + + def setup_method(self): + """Setup test fixtures""" + self.tenant_id = "test_tenant_123" + self.category = "search" + self.provider = "google" + self.binding_id = "binding_123" + self.mock_credentials = {"auth_type": "api_key", "config": {"api_key": "test_secret_key_123"}} + self.mock_args = {"category": self.category, "provider": self.provider, "credentials": self.mock_credentials} + + @patch("services.auth.api_key_auth_service.db.session") + def test_get_provider_auth_list_success(self, mock_session): + """Test get provider auth list - success scenario""" + # Mock database query result + mock_binding = Mock() + mock_binding.tenant_id = self.tenant_id + mock_binding.provider = self.provider + mock_binding.disabled = False + + mock_session.query.return_value.filter.return_value.all.return_value = [mock_binding] + + result = ApiKeyAuthService.get_provider_auth_list(self.tenant_id) + + assert len(result) == 1 + assert result[0].tenant_id == self.tenant_id + mock_session.query.assert_called_once_with(DataSourceApiKeyAuthBinding) + + @patch("services.auth.api_key_auth_service.db.session") + def test_get_provider_auth_list_empty(self, mock_session): + """Test get provider auth list - empty result""" + mock_session.query.return_value.filter.return_value.all.return_value = [] + + result = ApiKeyAuthService.get_provider_auth_list(self.tenant_id) + + assert result == [] + + @patch("services.auth.api_key_auth_service.db.session") + def test_get_provider_auth_list_filters_disabled(self, mock_session): + """Test get provider auth list - filters disabled items""" + mock_session.query.return_value.filter.return_value.all.return_value = [] + + ApiKeyAuthService.get_provider_auth_list(self.tenant_id) + + # Verify filter conditions include disabled.is_(False) + filter_call = mock_session.query.return_value.filter.call_args[0] + assert len(filter_call) == 2 # tenant_id and disabled filter conditions + + @patch("services.auth.api_key_auth_service.db.session") + @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory") + 
@patch("services.auth.api_key_auth_service.encrypter") + def test_create_provider_auth_success(self, mock_encrypter, mock_factory, mock_session): + """Test create provider auth - success scenario""" + # Mock successful auth validation + mock_auth_instance = Mock() + mock_auth_instance.validate_credentials.return_value = True + mock_factory.return_value = mock_auth_instance + + # Mock encryption + encrypted_key = "encrypted_test_key_123" + mock_encrypter.encrypt_token.return_value = encrypted_key + + # Mock database operations + mock_session.add = Mock() + mock_session.commit = Mock() + + ApiKeyAuthService.create_provider_auth(self.tenant_id, self.mock_args) + + # Verify factory class calls + mock_factory.assert_called_once_with(self.provider, self.mock_credentials) + mock_auth_instance.validate_credentials.assert_called_once() + + # Verify encryption calls + mock_encrypter.encrypt_token.assert_called_once_with(self.tenant_id, "test_secret_key_123") + + # Verify database operations + mock_session.add.assert_called_once() + mock_session.commit.assert_called_once() + + @patch("services.auth.api_key_auth_service.db.session") + @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory") + def test_create_provider_auth_validation_failed(self, mock_factory, mock_session): + """Test create provider auth - validation failed""" + # Mock failed auth validation + mock_auth_instance = Mock() + mock_auth_instance.validate_credentials.return_value = False + mock_factory.return_value = mock_auth_instance + + ApiKeyAuthService.create_provider_auth(self.tenant_id, self.mock_args) + + # Verify no database operations when validation fails + mock_session.add.assert_not_called() + mock_session.commit.assert_not_called() + + @patch("services.auth.api_key_auth_service.db.session") + @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory") + @patch("services.auth.api_key_auth_service.encrypter") + def test_create_provider_auth_encrypts_api_key(self, mock_encrypter, mock_factory, mock_session): + """Test create provider auth - ensures API key is encrypted""" + # Mock successful auth validation + mock_auth_instance = Mock() + mock_auth_instance.validate_credentials.return_value = True + mock_factory.return_value = mock_auth_instance + + # Mock encryption + encrypted_key = "encrypted_test_key_123" + mock_encrypter.encrypt_token.return_value = encrypted_key + + # Mock database operations + mock_session.add = Mock() + mock_session.commit = Mock() + + args_copy = self.mock_args.copy() + original_key = args_copy["credentials"]["config"]["api_key"] # type: ignore + + ApiKeyAuthService.create_provider_auth(self.tenant_id, args_copy) + + # Verify original key is replaced with encrypted key + assert args_copy["credentials"]["config"]["api_key"] == encrypted_key # type: ignore + assert args_copy["credentials"]["config"]["api_key"] != original_key # type: ignore + + # Verify encryption function is called correctly + mock_encrypter.encrypt_token.assert_called_once_with(self.tenant_id, original_key) + + @patch("services.auth.api_key_auth_service.db.session") + def test_get_auth_credentials_success(self, mock_session): + """Test get auth credentials - success scenario""" + # Mock database query result + mock_binding = Mock() + mock_binding.credentials = json.dumps(self.mock_credentials) + mock_session.query.return_value.filter.return_value.first.return_value = mock_binding + + result = ApiKeyAuthService.get_auth_credentials(self.tenant_id, self.category, self.provider) + + assert result == self.mock_credentials + 
mock_session.query.assert_called_once_with(DataSourceApiKeyAuthBinding) + + @patch("services.auth.api_key_auth_service.db.session") + def test_get_auth_credentials_not_found(self, mock_session): + """Test get auth credentials - not found""" + mock_session.query.return_value.filter.return_value.first.return_value = None + + result = ApiKeyAuthService.get_auth_credentials(self.tenant_id, self.category, self.provider) + + assert result is None + + @patch("services.auth.api_key_auth_service.db.session") + def test_get_auth_credentials_filters_correctly(self, mock_session): + """Test get auth credentials - applies correct filters""" + mock_session.query.return_value.filter.return_value.first.return_value = None + + ApiKeyAuthService.get_auth_credentials(self.tenant_id, self.category, self.provider) + + # Verify filter conditions are correct + filter_call = mock_session.query.return_value.filter.call_args[0] + assert len(filter_call) == 4 # tenant_id, category, provider, disabled + + @patch("services.auth.api_key_auth_service.db.session") + def test_get_auth_credentials_json_parsing(self, mock_session): + """Test get auth credentials - JSON parsing""" + # Mock credentials with special characters + special_credentials = {"auth_type": "api_key", "config": {"api_key": "key_with_中文_and_special_chars_!@#$%"}} + + mock_binding = Mock() + mock_binding.credentials = json.dumps(special_credentials, ensure_ascii=False) + mock_session.query.return_value.filter.return_value.first.return_value = mock_binding + + result = ApiKeyAuthService.get_auth_credentials(self.tenant_id, self.category, self.provider) + + assert result == special_credentials + assert result["config"]["api_key"] == "key_with_中文_and_special_chars_!@#$%" + + @patch("services.auth.api_key_auth_service.db.session") + def test_delete_provider_auth_success(self, mock_session): + """Test delete provider auth - success scenario""" + # Mock database query result + mock_binding = Mock() + mock_session.query.return_value.filter.return_value.first.return_value = mock_binding + + ApiKeyAuthService.delete_provider_auth(self.tenant_id, self.binding_id) + + # Verify delete operations + mock_session.delete.assert_called_once_with(mock_binding) + mock_session.commit.assert_called_once() + + @patch("services.auth.api_key_auth_service.db.session") + def test_delete_provider_auth_not_found(self, mock_session): + """Test delete provider auth - not found""" + mock_session.query.return_value.filter.return_value.first.return_value = None + + ApiKeyAuthService.delete_provider_auth(self.tenant_id, self.binding_id) + + # Verify no delete operations when not found + mock_session.delete.assert_not_called() + mock_session.commit.assert_not_called() + + @patch("services.auth.api_key_auth_service.db.session") + def test_delete_provider_auth_filters_by_tenant(self, mock_session): + """Test delete provider auth - filters by tenant""" + mock_session.query.return_value.filter.return_value.first.return_value = None + + ApiKeyAuthService.delete_provider_auth(self.tenant_id, self.binding_id) + + # Verify filter conditions include tenant_id and binding_id + filter_call = mock_session.query.return_value.filter.call_args[0] + assert len(filter_call) == 2 + + def test_validate_api_key_auth_args_success(self): + """Test API key auth args validation - success scenario""" + # Should not raise any exception + ApiKeyAuthService.validate_api_key_auth_args(self.mock_args) + + def test_validate_api_key_auth_args_missing_category(self): + """Test API key auth args validation - missing 
category""" + args = self.mock_args.copy() + del args["category"] + + with pytest.raises(ValueError, match="category is required"): + ApiKeyAuthService.validate_api_key_auth_args(args) + + def test_validate_api_key_auth_args_empty_category(self): + """Test API key auth args validation - empty category""" + args = self.mock_args.copy() + args["category"] = "" + + with pytest.raises(ValueError, match="category is required"): + ApiKeyAuthService.validate_api_key_auth_args(args) + + def test_validate_api_key_auth_args_missing_provider(self): + """Test API key auth args validation - missing provider""" + args = self.mock_args.copy() + del args["provider"] + + with pytest.raises(ValueError, match="provider is required"): + ApiKeyAuthService.validate_api_key_auth_args(args) + + def test_validate_api_key_auth_args_empty_provider(self): + """Test API key auth args validation - empty provider""" + args = self.mock_args.copy() + args["provider"] = "" + + with pytest.raises(ValueError, match="provider is required"): + ApiKeyAuthService.validate_api_key_auth_args(args) + + def test_validate_api_key_auth_args_missing_credentials(self): + """Test API key auth args validation - missing credentials""" + args = self.mock_args.copy() + del args["credentials"] + + with pytest.raises(ValueError, match="credentials is required"): + ApiKeyAuthService.validate_api_key_auth_args(args) + + def test_validate_api_key_auth_args_empty_credentials(self): + """Test API key auth args validation - empty credentials""" + args = self.mock_args.copy() + args["credentials"] = None # type: ignore + + with pytest.raises(ValueError, match="credentials is required"): + ApiKeyAuthService.validate_api_key_auth_args(args) + + def test_validate_api_key_auth_args_invalid_credentials_type(self): + """Test API key auth args validation - invalid credentials type""" + args = self.mock_args.copy() + args["credentials"] = "not_a_dict" + + with pytest.raises(ValueError, match="credentials must be a dictionary"): + ApiKeyAuthService.validate_api_key_auth_args(args) + + def test_validate_api_key_auth_args_missing_auth_type(self): + """Test API key auth args validation - missing auth_type""" + args = self.mock_args.copy() + del args["credentials"]["auth_type"] # type: ignore + + with pytest.raises(ValueError, match="auth_type is required"): + ApiKeyAuthService.validate_api_key_auth_args(args) + + def test_validate_api_key_auth_args_empty_auth_type(self): + """Test API key auth args validation - empty auth_type""" + args = self.mock_args.copy() + args["credentials"]["auth_type"] = "" # type: ignore + + with pytest.raises(ValueError, match="auth_type is required"): + ApiKeyAuthService.validate_api_key_auth_args(args) + + @pytest.mark.parametrize( + "malicious_input", + [ + "", + "'; DROP TABLE users; --", + "../../../etc/passwd", + "\\x00\\x00", # null bytes + "A" * 10000, # very long input + ], + ) + def test_validate_api_key_auth_args_malicious_input(self, malicious_input): + """Test API key auth args validation - malicious input""" + args = self.mock_args.copy() + args["category"] = malicious_input + + # Verify parameter validator doesn't crash on malicious input + # Should validate normally rather than raising security-related exceptions + ApiKeyAuthService.validate_api_key_auth_args(args) + + @patch("services.auth.api_key_auth_service.db.session") + @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory") + @patch("services.auth.api_key_auth_service.encrypter") + def test_create_provider_auth_database_error_handling(self, 
mock_encrypter, mock_factory, mock_session):
+        """Test create provider auth - database error handling"""
+        # Mock successful auth validation
+        mock_auth_instance = Mock()
+        mock_auth_instance.validate_credentials.return_value = True
+        mock_factory.return_value = mock_auth_instance
+
+        # Mock encryption
+        mock_encrypter.encrypt_token.return_value = "encrypted_key"
+
+        # Mock database error
+        mock_session.commit.side_effect = Exception("Database error")
+
+        with pytest.raises(Exception, match="Database error"):
+            ApiKeyAuthService.create_provider_auth(self.tenant_id, self.mock_args)
+
+    @patch("services.auth.api_key_auth_service.db.session")
+    def test_get_auth_credentials_invalid_json(self, mock_session):
+        """Test get auth credentials - invalid JSON"""
+        # Mock database returning invalid JSON
+        mock_binding = Mock()
+        mock_binding.credentials = "invalid json content"
+        mock_session.query.return_value.filter.return_value.first.return_value = mock_binding
+
+        with pytest.raises(json.JSONDecodeError):
+            ApiKeyAuthService.get_auth_credentials(self.tenant_id, self.category, self.provider)
+
+    @patch("services.auth.api_key_auth_service.db.session")
+    @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory")
+    def test_create_provider_auth_factory_exception(self, mock_factory, mock_session):
+        """Test create provider auth - factory exception"""
+        # Mock factory raising exception
+        mock_factory.side_effect = Exception("Factory error")
+
+        with pytest.raises(Exception, match="Factory error"):
+            ApiKeyAuthService.create_provider_auth(self.tenant_id, self.mock_args)
+
+    @patch("services.auth.api_key_auth_service.db.session")
+    @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory")
+    @patch("services.auth.api_key_auth_service.encrypter")
+    def test_create_provider_auth_encryption_exception(self, mock_encrypter, mock_factory, mock_session):
+        """Test create provider auth - encryption exception"""
+        # Mock successful auth validation
+        mock_auth_instance = Mock()
+        mock_auth_instance.validate_credentials.return_value = True
+        mock_factory.return_value = mock_auth_instance
+
+        # Mock encryption exception
+        mock_encrypter.encrypt_token.side_effect = Exception("Encryption error")
+
+        with pytest.raises(Exception, match="Encryption error"):
+            ApiKeyAuthService.create_provider_auth(self.tenant_id, self.mock_args)
+
+    def test_validate_api_key_auth_args_none_input(self):
+        """Test API key auth args validation - None input"""
+        with pytest.raises(TypeError):
+            ApiKeyAuthService.validate_api_key_auth_args(None)
+
+    def test_validate_api_key_auth_args_dict_credentials_with_list_auth_type(self):
+        """Test API key auth args validation - dict credentials with list auth_type"""
+        args = self.mock_args.copy()
+        args["credentials"]["auth_type"] = ["api_key"]  # type: ignore  # list instead of string
+
+        # The current implementation only checks that auth_type exists and is truthy;
+        # the non-empty list ["api_key"] is truthy, so no exception is raised here.
+        ApiKeyAuthService.validate_api_key_auth_args(args)
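
Note: taken together, the validation tests above pin down the expected behavior of `validate_api_key_auth_args`. A hypothetical reconstruction inferred from those expectations alone (the real service code is not shown in this diff):

    def validate_api_key_auth_args(args: dict) -> None:
        # Membership tests on None raise TypeError, matching the None-input test.
        if "category" not in args or not args["category"]:
            raise ValueError("category is required")
        if "provider" not in args or not args["provider"]:
            raise ValueError("provider is required")
        if "credentials" not in args or not args["credentials"]:
            raise ValueError("credentials is required")
        if not isinstance(args["credentials"], dict):
            raise ValueError("credentials must be a dictionary")
        # Truthiness check only: a non-empty list passes, per the last test above.
        if "auth_type" not in args["credentials"] or not args["credentials"]["auth_type"]:
            raise ValueError("auth_type is required")

diff --git a/api/tests/unit_tests/services/auth/test_firecrawl_auth.py b/api/tests/unit_tests/services/auth/test_firecrawl_auth.py
new file mode 100644
index 0000000000..ffdf5897ed
--- /dev/null
+++ b/api/tests/unit_tests/services/auth/test_firecrawl_auth.py
@@ -0,0 +1,191 @@
+from unittest.mock import MagicMock, patch
+
+import pytest
+import requests
+
+from services.auth.firecrawl.firecrawl import FirecrawlAuth
+
+
+class TestFirecrawlAuth:
+    @pytest.fixture
+    def valid_credentials(self):
+        """Fixture 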
for valid bearer credentials""" + return {"auth_type": "bearer", "config": {"api_key": "test_api_key_123"}} + + @pytest.fixture + def auth_instance(self, valid_credentials): + """Fixture for FirecrawlAuth instance with valid credentials""" + return FirecrawlAuth(valid_credentials) + + def test_should_initialize_with_valid_bearer_credentials(self, valid_credentials): + """Test successful initialization with valid bearer credentials""" + auth = FirecrawlAuth(valid_credentials) + assert auth.api_key == "test_api_key_123" + assert auth.base_url == "https://api.firecrawl.dev" + assert auth.credentials == valid_credentials + + def test_should_initialize_with_custom_base_url(self): + """Test initialization with custom base URL""" + credentials = { + "auth_type": "bearer", + "config": {"api_key": "test_api_key_123", "base_url": "https://custom.firecrawl.dev"}, + } + auth = FirecrawlAuth(credentials) + assert auth.api_key == "test_api_key_123" + assert auth.base_url == "https://custom.firecrawl.dev" + + @pytest.mark.parametrize( + ("auth_type", "expected_error"), + [ + ("basic", "Invalid auth type, Firecrawl auth type must be Bearer"), + ("x-api-key", "Invalid auth type, Firecrawl auth type must be Bearer"), + ("", "Invalid auth type, Firecrawl auth type must be Bearer"), + ], + ) + def test_should_raise_error_for_invalid_auth_type(self, auth_type, expected_error): + """Test that non-bearer auth types raise ValueError""" + credentials = {"auth_type": auth_type, "config": {"api_key": "test_api_key_123"}} + with pytest.raises(ValueError) as exc_info: + FirecrawlAuth(credentials) + assert str(exc_info.value) == expected_error + + @pytest.mark.parametrize( + ("credentials", "expected_error"), + [ + ({"auth_type": "bearer", "config": {}}, "No API key provided"), + ({"auth_type": "bearer"}, "No API key provided"), + ({"auth_type": "bearer", "config": {"api_key": ""}}, "No API key provided"), + ({"auth_type": "bearer", "config": {"api_key": None}}, "No API key provided"), + ], + ) + def test_should_raise_error_for_missing_api_key(self, credentials, expected_error): + """Test that missing or empty API key raises ValueError""" + with pytest.raises(ValueError) as exc_info: + FirecrawlAuth(credentials) + assert str(exc_info.value) == expected_error + + @patch("services.auth.firecrawl.firecrawl.requests.post") + def test_should_validate_valid_credentials_successfully(self, mock_post, auth_instance): + """Test successful credential validation""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_post.return_value = mock_response + + result = auth_instance.validate_credentials() + + assert result is True + expected_data = { + "url": "https://example.com", + "includePaths": [], + "excludePaths": [], + "limit": 1, + "scrapeOptions": {"onlyMainContent": True}, + } + mock_post.assert_called_once_with( + "https://api.firecrawl.dev/v1/crawl", + headers={"Content-Type": "application/json", "Authorization": "Bearer test_api_key_123"}, + json=expected_data, + ) + + @pytest.mark.parametrize( + ("status_code", "error_message"), + [ + (402, "Payment required"), + (409, "Conflict error"), + (500, "Internal server error"), + ], + ) + @patch("services.auth.firecrawl.firecrawl.requests.post") + def test_should_handle_http_errors(self, mock_post, status_code, error_message, auth_instance): + """Test handling of various HTTP error codes""" + mock_response = MagicMock() + mock_response.status_code = status_code + mock_response.json.return_value = {"error": error_message} + mock_post.return_value = mock_response 
+ + with pytest.raises(Exception) as exc_info: + auth_instance.validate_credentials() + assert str(exc_info.value) == f"Failed to authorize. Status code: {status_code}. Error: {error_message}" + + @pytest.mark.parametrize( + ("status_code", "response_text", "has_json_error", "expected_error_contains"), + [ + (403, '{"error": "Forbidden"}', True, "Failed to authorize. Status code: 403. Error: Forbidden"), + (404, "", True, "Unexpected error occurred while trying to authorize. Status code: 404"), + (401, "Not JSON", True, "Expecting value"), # JSON decode error + ], + ) + @patch("services.auth.firecrawl.firecrawl.requests.post") + def test_should_handle_unexpected_errors( + self, mock_post, status_code, response_text, has_json_error, expected_error_contains, auth_instance + ): + """Test handling of unexpected errors with various response formats""" + mock_response = MagicMock() + mock_response.status_code = status_code + mock_response.text = response_text + if has_json_error: + mock_response.json.side_effect = Exception("Not JSON") + mock_post.return_value = mock_response + + with pytest.raises(Exception) as exc_info: + auth_instance.validate_credentials() + assert expected_error_contains in str(exc_info.value) + + @pytest.mark.parametrize( + ("exception_type", "exception_message"), + [ + (requests.ConnectionError, "Network error"), + (requests.Timeout, "Request timeout"), + (requests.ReadTimeout, "Read timeout"), + (requests.ConnectTimeout, "Connection timeout"), + ], + ) + @patch("services.auth.firecrawl.firecrawl.requests.post") + def test_should_handle_network_errors(self, mock_post, exception_type, exception_message, auth_instance): + """Test handling of various network-related errors including timeouts""" + mock_post.side_effect = exception_type(exception_message) + + with pytest.raises(exception_type) as exc_info: + auth_instance.validate_credentials() + assert exception_message in str(exc_info.value) + + def test_should_not_expose_api_key_in_error_messages(self): + """Test that API key is not exposed in error messages""" + credentials = {"auth_type": "bearer", "config": {"api_key": "super_secret_key_12345"}} + auth = FirecrawlAuth(credentials) + + # Verify API key is stored but not in any error message + assert auth.api_key == "super_secret_key_12345" + + # Test various error scenarios don't expose the key + with pytest.raises(ValueError) as exc_info: + FirecrawlAuth({"auth_type": "basic", "config": {"api_key": "super_secret_key_12345"}}) + assert "super_secret_key_12345" not in str(exc_info.value) + + @patch("services.auth.firecrawl.firecrawl.requests.post") + def test_should_use_custom_base_url_in_validation(self, mock_post): + """Test that custom base URL is used in validation""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_post.return_value = mock_response + + credentials = { + "auth_type": "bearer", + "config": {"api_key": "test_api_key_123", "base_url": "https://custom.firecrawl.dev"}, + } + auth = FirecrawlAuth(credentials) + result = auth.validate_credentials() + + assert result is True + assert mock_post.call_args[0][0] == "https://custom.firecrawl.dev/v1/crawl" + + @patch("services.auth.firecrawl.firecrawl.requests.post") + def test_should_handle_timeout_with_retry_suggestion(self, mock_post, auth_instance): + """Test that timeout errors are handled gracefully with appropriate error message""" + mock_post.side_effect = requests.Timeout("The request timed out after 30 seconds") + + with pytest.raises(requests.Timeout) as exc_info: + 
auth_instance.validate_credentials() + + # Verify the timeout exception is raised with original message + assert "timed out" in str(exc_info.value) diff --git a/api/tests/unit_tests/services/auth/test_jina_auth.py b/api/tests/unit_tests/services/auth/test_jina_auth.py new file mode 100644 index 0000000000..ccbca5a36f --- /dev/null +++ b/api/tests/unit_tests/services/auth/test_jina_auth.py @@ -0,0 +1,155 @@ +from unittest.mock import MagicMock, patch + +import pytest +import requests + +from services.auth.jina.jina import JinaAuth + + +class TestJinaAuth: + def test_should_initialize_with_valid_bearer_credentials(self): + """Test successful initialization with valid bearer credentials""" + credentials = {"auth_type": "bearer", "config": {"api_key": "test_api_key_123"}} + auth = JinaAuth(credentials) + assert auth.api_key == "test_api_key_123" + assert auth.credentials == credentials + + def test_should_raise_error_for_invalid_auth_type(self): + """Test that non-bearer auth type raises ValueError""" + credentials = {"auth_type": "basic", "config": {"api_key": "test_api_key_123"}} + with pytest.raises(ValueError) as exc_info: + JinaAuth(credentials) + assert str(exc_info.value) == "Invalid auth type, Jina Reader auth type must be Bearer" + + def test_should_raise_error_for_missing_api_key(self): + """Test that missing API key raises ValueError""" + credentials = {"auth_type": "bearer", "config": {}} + with pytest.raises(ValueError) as exc_info: + JinaAuth(credentials) + assert str(exc_info.value) == "No API key provided" + + def test_should_raise_error_for_missing_config(self): + """Test that missing config section raises ValueError""" + credentials = {"auth_type": "bearer"} + with pytest.raises(ValueError) as exc_info: + JinaAuth(credentials) + assert str(exc_info.value) == "No API key provided" + + @patch("services.auth.jina.jina.requests.post") + def test_should_validate_valid_credentials_successfully(self, mock_post): + """Test successful credential validation""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_post.return_value = mock_response + + credentials = {"auth_type": "bearer", "config": {"api_key": "test_api_key_123"}} + auth = JinaAuth(credentials) + result = auth.validate_credentials() + + assert result is True + mock_post.assert_called_once_with( + "https://r.jina.ai", + headers={"Content-Type": "application/json", "Authorization": "Bearer test_api_key_123"}, + json={"url": "https://example.com"}, + ) + + @patch("services.auth.jina.jina.requests.post") + def test_should_handle_http_402_error(self, mock_post): + """Test handling of 402 Payment Required error""" + mock_response = MagicMock() + mock_response.status_code = 402 + mock_response.json.return_value = {"error": "Payment required"} + mock_post.return_value = mock_response + + credentials = {"auth_type": "bearer", "config": {"api_key": "test_api_key_123"}} + auth = JinaAuth(credentials) + + with pytest.raises(Exception) as exc_info: + auth.validate_credentials() + assert str(exc_info.value) == "Failed to authorize. Status code: 402. 
Error: Payment required" + + @patch("services.auth.jina.jina.requests.post") + def test_should_handle_http_409_error(self, mock_post): + """Test handling of 409 Conflict error""" + mock_response = MagicMock() + mock_response.status_code = 409 + mock_response.json.return_value = {"error": "Conflict error"} + mock_post.return_value = mock_response + + credentials = {"auth_type": "bearer", "config": {"api_key": "test_api_key_123"}} + auth = JinaAuth(credentials) + + with pytest.raises(Exception) as exc_info: + auth.validate_credentials() + assert str(exc_info.value) == "Failed to authorize. Status code: 409. Error: Conflict error" + + @patch("services.auth.jina.jina.requests.post") + def test_should_handle_http_500_error(self, mock_post): + """Test handling of 500 Internal Server Error""" + mock_response = MagicMock() + mock_response.status_code = 500 + mock_response.json.return_value = {"error": "Internal server error"} + mock_post.return_value = mock_response + + credentials = {"auth_type": "bearer", "config": {"api_key": "test_api_key_123"}} + auth = JinaAuth(credentials) + + with pytest.raises(Exception) as exc_info: + auth.validate_credentials() + assert str(exc_info.value) == "Failed to authorize. Status code: 500. Error: Internal server error" + + @patch("services.auth.jina.jina.requests.post") + def test_should_handle_unexpected_error_with_text_response(self, mock_post): + """Test handling of unexpected errors with text response""" + mock_response = MagicMock() + mock_response.status_code = 403 + mock_response.text = '{"error": "Forbidden"}' + mock_response.json.side_effect = Exception("Not JSON") + mock_post.return_value = mock_response + + credentials = {"auth_type": "bearer", "config": {"api_key": "test_api_key_123"}} + auth = JinaAuth(credentials) + + with pytest.raises(Exception) as exc_info: + auth.validate_credentials() + assert str(exc_info.value) == "Failed to authorize. Status code: 403. Error: Forbidden" + + @patch("services.auth.jina.jina.requests.post") + def test_should_handle_unexpected_error_without_text(self, mock_post): + """Test handling of unexpected errors without text response""" + mock_response = MagicMock() + mock_response.status_code = 404 + mock_response.text = "" + mock_response.json.side_effect = Exception("Not JSON") + mock_post.return_value = mock_response + + credentials = {"auth_type": "bearer", "config": {"api_key": "test_api_key_123"}} + auth = JinaAuth(credentials) + + with pytest.raises(Exception) as exc_info: + auth.validate_credentials() + assert str(exc_info.value) == "Unexpected error occurred while trying to authorize. 
Status code: 404" + + @patch("services.auth.jina.jina.requests.post") + def test_should_handle_network_errors(self, mock_post): + """Test handling of network connection errors""" + mock_post.side_effect = requests.ConnectionError("Network error") + + credentials = {"auth_type": "bearer", "config": {"api_key": "test_api_key_123"}} + auth = JinaAuth(credentials) + + with pytest.raises(requests.ConnectionError): + auth.validate_credentials() + + def test_should_not_expose_api_key_in_error_messages(self): + """Test that API key is not exposed in error messages""" + credentials = {"auth_type": "bearer", "config": {"api_key": "super_secret_key_12345"}} + auth = JinaAuth(credentials) + + # Verify API key is stored but not in any error message + assert auth.api_key == "super_secret_key_12345" + + # Test various error scenarios don't expose the key + with pytest.raises(ValueError) as exc_info: + JinaAuth({"auth_type": "basic", "config": {"api_key": "super_secret_key_12345"}}) + assert "super_secret_key_12345" not in str(exc_info.value) diff --git a/api/tests/unit_tests/services/auth/test_watercrawl_auth.py b/api/tests/unit_tests/services/auth/test_watercrawl_auth.py new file mode 100644 index 0000000000..bacf0b24ea --- /dev/null +++ b/api/tests/unit_tests/services/auth/test_watercrawl_auth.py @@ -0,0 +1,205 @@ +from unittest.mock import MagicMock, patch + +import pytest +import requests + +from services.auth.watercrawl.watercrawl import WatercrawlAuth + + +class TestWatercrawlAuth: + @pytest.fixture + def valid_credentials(self): + """Fixture for valid x-api-key credentials""" + return {"auth_type": "x-api-key", "config": {"api_key": "test_api_key_123"}} + + @pytest.fixture + def auth_instance(self, valid_credentials): + """Fixture for WatercrawlAuth instance with valid credentials""" + return WatercrawlAuth(valid_credentials) + + def test_should_initialize_with_valid_x_api_key_credentials(self, valid_credentials): + """Test successful initialization with valid x-api-key credentials""" + auth = WatercrawlAuth(valid_credentials) + assert auth.api_key == "test_api_key_123" + assert auth.base_url == "https://app.watercrawl.dev" + assert auth.credentials == valid_credentials + + def test_should_initialize_with_custom_base_url(self): + """Test initialization with custom base URL""" + credentials = { + "auth_type": "x-api-key", + "config": {"api_key": "test_api_key_123", "base_url": "https://custom.watercrawl.dev"}, + } + auth = WatercrawlAuth(credentials) + assert auth.api_key == "test_api_key_123" + assert auth.base_url == "https://custom.watercrawl.dev" + + @pytest.mark.parametrize( + ("auth_type", "expected_error"), + [ + ("bearer", "Invalid auth type, WaterCrawl auth type must be x-api-key"), + ("basic", "Invalid auth type, WaterCrawl auth type must be x-api-key"), + ("", "Invalid auth type, WaterCrawl auth type must be x-api-key"), + ], + ) + def test_should_raise_error_for_invalid_auth_type(self, auth_type, expected_error): + """Test that non-x-api-key auth types raise ValueError""" + credentials = {"auth_type": auth_type, "config": {"api_key": "test_api_key_123"}} + with pytest.raises(ValueError) as exc_info: + WatercrawlAuth(credentials) + assert str(exc_info.value) == expected_error + + @pytest.mark.parametrize( + ("credentials", "expected_error"), + [ + ({"auth_type": "x-api-key", "config": {}}, "No API key provided"), + ({"auth_type": "x-api-key"}, "No API key provided"), + ({"auth_type": "x-api-key", "config": {"api_key": ""}}, "No API key provided"), + ({"auth_type": "x-api-key", "config": 
{"api_key": None}}, "No API key provided"), + ], + ) + def test_should_raise_error_for_missing_api_key(self, credentials, expected_error): + """Test that missing or empty API key raises ValueError""" + with pytest.raises(ValueError) as exc_info: + WatercrawlAuth(credentials) + assert str(exc_info.value) == expected_error + + @patch("services.auth.watercrawl.watercrawl.requests.get") + def test_should_validate_valid_credentials_successfully(self, mock_get, auth_instance): + """Test successful credential validation""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_get.return_value = mock_response + + result = auth_instance.validate_credentials() + + assert result is True + mock_get.assert_called_once_with( + "https://app.watercrawl.dev/api/v1/core/crawl-requests/", + headers={"Content-Type": "application/json", "X-API-KEY": "test_api_key_123"}, + ) + + @pytest.mark.parametrize( + ("status_code", "error_message"), + [ + (402, "Payment required"), + (409, "Conflict error"), + (500, "Internal server error"), + ], + ) + @patch("services.auth.watercrawl.watercrawl.requests.get") + def test_should_handle_http_errors(self, mock_get, status_code, error_message, auth_instance): + """Test handling of various HTTP error codes""" + mock_response = MagicMock() + mock_response.status_code = status_code + mock_response.json.return_value = {"error": error_message} + mock_get.return_value = mock_response + + with pytest.raises(Exception) as exc_info: + auth_instance.validate_credentials() + assert str(exc_info.value) == f"Failed to authorize. Status code: {status_code}. Error: {error_message}" + + @pytest.mark.parametrize( + ("status_code", "response_text", "has_json_error", "expected_error_contains"), + [ + (403, '{"error": "Forbidden"}', True, "Failed to authorize. Status code: 403. Error: Forbidden"), + (404, "", True, "Unexpected error occurred while trying to authorize. 
Status code: 404"), + (401, "Not JSON", True, "Expecting value"), # JSON decode error + ], + ) + @patch("services.auth.watercrawl.watercrawl.requests.get") + def test_should_handle_unexpected_errors( + self, mock_get, status_code, response_text, has_json_error, expected_error_contains, auth_instance + ): + """Test handling of unexpected errors with various response formats""" + mock_response = MagicMock() + mock_response.status_code = status_code + mock_response.text = response_text + if has_json_error: + mock_response.json.side_effect = Exception("Not JSON") + mock_get.return_value = mock_response + + with pytest.raises(Exception) as exc_info: + auth_instance.validate_credentials() + assert expected_error_contains in str(exc_info.value) + + @pytest.mark.parametrize( + ("exception_type", "exception_message"), + [ + (requests.ConnectionError, "Network error"), + (requests.Timeout, "Request timeout"), + (requests.ReadTimeout, "Read timeout"), + (requests.ConnectTimeout, "Connection timeout"), + ], + ) + @patch("services.auth.watercrawl.watercrawl.requests.get") + def test_should_handle_network_errors(self, mock_get, exception_type, exception_message, auth_instance): + """Test handling of various network-related errors including timeouts""" + mock_get.side_effect = exception_type(exception_message) + + with pytest.raises(exception_type) as exc_info: + auth_instance.validate_credentials() + assert exception_message in str(exc_info.value) + + def test_should_not_expose_api_key_in_error_messages(self): + """Test that API key is not exposed in error messages""" + credentials = {"auth_type": "x-api-key", "config": {"api_key": "super_secret_key_12345"}} + auth = WatercrawlAuth(credentials) + + # Verify API key is stored but not in any error message + assert auth.api_key == "super_secret_key_12345" + + # Test various error scenarios don't expose the key + with pytest.raises(ValueError) as exc_info: + WatercrawlAuth({"auth_type": "bearer", "config": {"api_key": "super_secret_key_12345"}}) + assert "super_secret_key_12345" not in str(exc_info.value) + + @patch("services.auth.watercrawl.watercrawl.requests.get") + def test_should_use_custom_base_url_in_validation(self, mock_get): + """Test that custom base URL is used in validation""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_get.return_value = mock_response + + credentials = { + "auth_type": "x-api-key", + "config": {"api_key": "test_api_key_123", "base_url": "https://custom.watercrawl.dev"}, + } + auth = WatercrawlAuth(credentials) + result = auth.validate_credentials() + + assert result is True + assert mock_get.call_args[0][0] == "https://custom.watercrawl.dev/api/v1/core/crawl-requests/" + + @pytest.mark.parametrize( + ("base_url", "expected_url"), + [ + ("https://app.watercrawl.dev", "https://app.watercrawl.dev/api/v1/core/crawl-requests/"), + ("https://app.watercrawl.dev/", "https://app.watercrawl.dev/api/v1/core/crawl-requests/"), + ("https://app.watercrawl.dev//", "https://app.watercrawl.dev/api/v1/core/crawl-requests/"), + ], + ) + @patch("services.auth.watercrawl.watercrawl.requests.get") + def test_should_use_urljoin_for_url_construction(self, mock_get, base_url, expected_url): + """Test that urljoin is used correctly for URL construction with various base URLs""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_get.return_value = mock_response + + credentials = {"auth_type": "x-api-key", "config": {"api_key": "test_api_key_123", "base_url": base_url}} + auth = WatercrawlAuth(credentials) + 
auth.validate_credentials()
+
+        # Verify the correct URL was called
+        assert mock_get.call_args[0][0] == expected_url
+
+    @patch("services.auth.watercrawl.watercrawl.requests.get")
+    def test_should_handle_timeout_with_retry_suggestion(self, mock_get, auth_instance):
+        """Test that timeout errors are handled gracefully with appropriate error message"""
+        mock_get.side_effect = requests.Timeout("The request timed out after 30 seconds")
+
+        with pytest.raises(requests.Timeout) as exc_info:
+            auth_instance.validate_credentials()
+
+        # Verify the timeout exception is raised with original message
+        assert "timed out" in str(exc_info.value)
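
Note: the trailing-slash cases in the urljoin test above follow from standard `urllib.parse.urljoin` behavior with an absolute path; a quick standalone check (illustrative only, independent of the service code):

    from urllib.parse import urljoin

    ENDPOINT = "/api/v1/core/crawl-requests/"
    for base in (
        "https://app.watercrawl.dev",
        "https://app.watercrawl.dev/",
        "https://app.watercrawl.dev//",
    ):
        # An absolute path replaces whatever path the base URL carries,
        # so every base normalizes to the same endpoint.
        assert urljoin(base, ENDPOINT) == "https://app.watercrawl.dev/api/v1/core/crawl-requests/"

diff --git a/api/tests/unit_tests/services/test_dataset_service_update_dataset.py b/api/tests/unit_tests/services/test_dataset_service_update_dataset.py
index 87b46f213b..7c40b1e556 100644
--- a/api/tests/unit_tests/services/test_dataset_service_update_dataset.py
+++ b/api/tests/unit_tests/services/test_dataset_service_update_dataset.py
@@ -102,17 +102,16 @@ class TestDatasetServiceUpdateDataset:
             patch("services.dataset_service.DatasetService.get_dataset") as mock_get_dataset,
             patch("services.dataset_service.DatasetService.check_dataset_permission") as mock_check_perm,
             patch("extensions.ext_database.db.session") as mock_db,
-            patch("services.dataset_service.datetime") as mock_datetime,
+            patch("services.dataset_service.naive_utc_now") as mock_naive_utc_now,
         ):
             current_time = datetime.datetime(2023, 1, 1, 12, 0, 0)
-            mock_datetime.datetime.now.return_value = current_time
-            mock_datetime.UTC = datetime.UTC
+            mock_naive_utc_now.return_value = current_time

             yield {
                 "get_dataset": mock_get_dataset,
                 "check_permission": mock_check_perm,
                 "db_session": mock_db,
-                "datetime": mock_datetime,
+                "naive_utc_now": mock_naive_utc_now,
                 "current_time": current_time,
             }

@@ -292,7 +291,7 @@ class TestDatasetServiceUpdateDataset:
                 "embedding_model_provider": "openai",
                 "embedding_model": "text-embedding-ada-002",
                 "updated_by": user.id,
-                "updated_at": mock_dataset_service_dependencies["current_time"].replace(tzinfo=None),
+                "updated_at": mock_dataset_service_dependencies["current_time"],
             }

             self._assert_database_update_called(
@@ -327,7 +326,7 @@ class TestDatasetServiceUpdateDataset:
                 "indexing_technique": "high_quality",
                 "retrieval_model": "new_model",
                 "updated_by": user.id,
-                "updated_at": mock_dataset_service_dependencies["current_time"].replace(tzinfo=None),
+                "updated_at": mock_dataset_service_dependencies["current_time"],
             }

             actual_call_args = mock_dataset_service_dependencies[
@@ -365,7 +364,7 @@ class TestDatasetServiceUpdateDataset:
                 "collection_binding_id": None,
                 "retrieval_model": "new_model",
                 "updated_by": user.id,
-                "updated_at": mock_dataset_service_dependencies["current_time"].replace(tzinfo=None),
+                "updated_at": mock_dataset_service_dependencies["current_time"],
             }

             self._assert_database_update_called(
@@ -422,7 +421,7 @@ class TestDatasetServiceUpdateDataset:
                 "collection_binding_id": "binding-456",
                 "retrieval_model": "new_model",
                 "updated_by": user.id,
-                "updated_at": mock_dataset_service_dependencies["current_time"].replace(tzinfo=None),
+                "updated_at": mock_dataset_service_dependencies["current_time"],
             }

             self._assert_database_update_called(
@@ -463,7 +462,7 @@ class TestDatasetServiceUpdateDataset:
                 "collection_binding_id": "binding-123",
                 "retrieval_model": "new_model",
                 "updated_by": user.id,
-                "updated_at": mock_dataset_service_dependencies["current_time"].replace(tzinfo=None),
+                "updated_at": mock_dataset_service_dependencies["current_time"],
             }
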
         self._assert_database_update_called(
@@ -525,7 +524,7 @@ class TestDatasetServiceUpdateDataset:
             "collection_binding_id": "binding-789",
             "retrieval_model": "new_model",
             "updated_by": user.id,
-            "updated_at": mock_dataset_service_dependencies["current_time"].replace(tzinfo=None),
+            "updated_at": mock_dataset_service_dependencies["current_time"],
         }

         self._assert_database_update_called(
@@ -568,7 +567,7 @@ class TestDatasetServiceUpdateDataset:
             "collection_binding_id": "binding-123",
             "retrieval_model": "new_model",
             "updated_by": user.id,
-            "updated_at": mock_dataset_service_dependencies["current_time"].replace(tzinfo=None),
+            "updated_at": mock_dataset_service_dependencies["current_time"],
         }

         self._assert_database_update_called(
diff --git a/docker/.env.example b/docker/.env.example
index a05141569b..6149f63165 100644
--- a/docker/.env.example
+++ b/docker/.env.example
@@ -283,11 +283,12 @@ REDIS_CLUSTERS_PASSWORD=
 # Celery Configuration
 # ------------------------------

-# Use redis as the broker, and redis db 1 for celery broker.
-# Format as follows: `redis://<redis_username>:<redis_password>@<redis_host>:<redis_port>/<redis_database>`
+# Use standalone redis as the broker, and redis db 1 for celery broker. (redis_username is usually empty by default)
+# Format as follows: `redis://<redis_username>:<redis_password>@<redis_host>:<redis_port>/<redis_database>`.
 # Example: redis://:difyai123456@redis:6379/1
-# If use Redis Sentinel, format as follows: `sentinel://<sentinel_username>:<sentinel_password>@<sentinel_host>:<sentinel_port>/<redis_database>`
-# Example: sentinel://localhost:26379/1;sentinel://localhost:26380/1;sentinel://localhost:26381/1
+# If using Redis Sentinel, format as follows: `sentinel://<sentinel_username>:<sentinel_password>@<sentinel_host>:<sentinel_port>/<redis_database>`
+# For high availability, you can configure multiple Sentinel nodes, separated by semicolons, as in the example below:
+# Example: sentinel://:difyai123456@localhost:26379/1;sentinel://:difyai123456@localhost:26380/1;sentinel://:difyai123456@localhost:26381/1
 CELERY_BROKER_URL=redis://:difyai123456@redis:6379/1
 CELERY_BACKEND=redis
 BROKER_USE_SSL=false
@@ -412,6 +413,8 @@ SUPABASE_URL=your-server-url

 # The type of vector store to use.
 # Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`.
 VECTOR_STORE=weaviate
+# Prefix used to create the collection name in the vector database.
+VECTOR_INDEX_NAME_PREFIX=Vector_index

 # The Weaviate endpoint URL. Only available when VECTOR_STORE is `weaviate`.
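(Aside on the `VECTOR_INDEX_NAME_PREFIX` setting added above: how the prefix is applied depends on the configured vector store. The actual helper lives in the api models and is not shown in this diff; the presumed derivation is roughly:

    def gen_collection_name(prefix: str, dataset_id: str) -> str:
        # e.g. prefix "Vector_index" -> "Vector_index_<dataset_id>_Node",
        # with dashes in the dataset id replaced by underscores
        return f"{prefix}_{dataset_id.replace('-', '_')}_Node"

so changing the prefix changes the names of newly created collections.)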
WEAVIATE_ENDPOINT=http://weaviate:8080
diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml
index 5962adb079..1271d6d464 100644
--- a/docker/docker-compose.yaml
+++ b/docker/docker-compose.yaml
@@ -136,6 +136,7 @@ x-shared-env: &shared-api-worker-env
   SUPABASE_API_KEY: ${SUPABASE_API_KEY:-your-access-key}
   SUPABASE_URL: ${SUPABASE_URL:-your-server-url}
   VECTOR_STORE: ${VECTOR_STORE:-weaviate}
+  VECTOR_INDEX_NAME_PREFIX: ${VECTOR_INDEX_NAME_PREFIX:-Vector_index}
   WEAVIATE_ENDPOINT: ${WEAVIATE_ENDPOINT:-http://weaviate:8080}
   WEAVIATE_API_KEY: ${WEAVIATE_API_KEY:-WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih}
   QDRANT_URL: ${QDRANT_URL:-http://qdrant:6333}
diff --git a/sdks/python-client/dify_client/__init__.py b/sdks/python-client/dify_client/__init__.py
index b557a9ce95..d00c207afa 100644
--- a/sdks/python-client/dify_client/__init__.py
+++ b/sdks/python-client/dify_client/__init__.py
@@ -1 +1,7 @@
-from dify_client.client import ChatClient, CompletionClient, WorkflowClient, KnowledgeBaseClient, DifyClient
+from dify_client.client import (
+    ChatClient,
+    CompletionClient,
+    WorkflowClient,
+    KnowledgeBaseClient,
+    DifyClient,
+)
diff --git a/web/Dockerfile b/web/Dockerfile
index 93eef59815..d59039528c 100644
--- a/web/Dockerfile
+++ b/web/Dockerfile
@@ -6,7 +6,7 @@ LABEL maintainer="takatost@gmail.com"
 # RUN sed -i 's/dl-cdn.alpinelinux.org/mirrors.aliyun.com/g' /etc/apk/repositories
 RUN apk add --no-cache tzdata
-RUN npm install -g pnpm@10.11.1
+RUN npm install -g pnpm@10.13.1
 ENV PNPM_HOME="/pnpm"
 ENV PATH="$PNPM_HOME:$PATH"
diff --git a/web/app/(commonLayout)/datasets/NewDatasetCard.tsx b/web/app/(commonLayout)/datasets/NewDatasetCard.tsx
new file mode 100644
index 0000000000..62f6a34be0
--- /dev/null
+++ b/web/app/(commonLayout)/datasets/NewDatasetCard.tsx
@@ -0,0 +1,41 @@
+'use client'
+import { useTranslation } from 'react-i18next'
+import {
+  RiAddLine,
+  RiArrowRightLine,
+} from '@remixicon/react'
+import Link from 'next/link'
+
+type CreateAppCardProps = {
+  ref?: React.Ref<HTMLAnchorElement>
+}
+
+const CreateAppCard = ({ ref }: CreateAppCardProps) => {
+  const { t } = useTranslation()
+
+  // NOTE: the original markup's attributes did not survive this copy; the JSX
+  // below reconstructs only the structure and hrefs implied by the imports,
+  // the routes, and the surviving text nodes.
+  return (
+    <div>
+      <Link ref={ref} href='/datasets/create'>
+        <div>
+          <RiAddLine />
+        </div>
+        <div>{t('dataset.createDataset')}</div>
+      </Link>
+      <div>{t('dataset.createDatasetIntro')}</div>
+      <Link href='/datasets/connect'>
+        <div>{t('dataset.connectDataset')}</div>
+        <RiArrowRightLine />
+      </Link>
+    </div>
+  )
+}
+
+CreateAppCard.displayName = 'CreateAppCard'
+
+export default CreateAppCard
diff --git a/web/app/(commonLayout)/layout.tsx b/web/app/(commonLayout)/layout.tsx
index d07e2a99d9..64186a1b10 100644
--- a/web/app/(commonLayout)/layout.tsx
+++ b/web/app/(commonLayout)/layout.tsx
@@ -1,6 +1,6 @@
 import React from 'react'
 import type { ReactNode } from 'react'
-import SwrInitor from '@/app/components/swr-initor'
+import SwrInitializer from '@/app/components/swr-initializer'
 import { AppContextProvider } from '@/context/app-context'
 import GA, { GaType } from '@/app/components/base/ga'
 import HeaderWrapper from '@/app/components/header/header-wrapper'
@@ -13,7 +13,7 @@ const Layout = ({ children }: { children: ReactNode }) => {
   return (
     <>
-      <SwrInitor>
+      <SwrInitializer>
@@ -26,7 +26,7 @@ const Layout = ({ children }: { children: ReactNode }) => {
-      </SwrInitor>
+      </SwrInitializer>
   )
 }
diff --git a/web/app/account/account-page/index.tsx b/web/app/account/account-page/index.tsx
index 55fa2983dd..47b8f045d2 100644
--- a/web/app/account/account-page/index.tsx
+++ b/web/app/account/account-page/index.tsx
@@ -1,5 +1,6 @@
 'use client'
 import { useState } from 'react'
+import useSWR from 'swr'
 import { useTranslation } from 'react-i18next'
 import {
   RiGraduationCapFill,
@@ -22,6 +23,8 @@ import PremiumBadge from '@/app/components/base/premium-badge'
 import { useGlobalPublicStore } from '@/context/global-public-context'
 import EmailChangeModal from './email-change-modal'
 import { validPassword } from '@/config'
+import { fetchAppList } from '@/service/apps'
+import type { App } from '@/types/app'

 const titleClassName = `
   system-sm-semibold text-text-secondary
 `
 const descriptionClassName = `
@@ -33,7 +36,9 @@ export default function AccountPage() {
   const { t } = useTranslation()
   const { systemFeatures } = useGlobalPublicStore()
-  const { mutateUserProfile, userProfile, apps } = useAppContext()
+  const { data: appList } = useSWR({ url: '/apps', params: { page: 1, limit: 100, name: '' } }, fetchAppList)
+  const apps = appList?.data || []
+  const { mutateUserProfile, userProfile } = useAppContext()
   const { isEducationAccount } = useProviderContext()
   const { notify } = useContext(ToastContext)
   const [editNameModalVisible, setEditNameModalVisible] = useState(false)
@@ -202,7 +207,7 @@ export default function AccountPage() {
           {!!apps.length && (
             <Collapse
-              items={apps.map(app => ({ ...app, key: app.id, name: app.name }))}
+              items={apps.map((app: App) => ({ ...app, key: app.id, name: app.name }))}
               renderItem={renderAppItem}
               wrapperClassName='mt-2'
             />
diff --git a/web/app/account/layout.tsx b/web/app/account/layout.tsx
index e74716fb3b..b3225b5341 100644
--- a/web/app/account/layout.tsx
+++ b/web/app/account/layout.tsx
@@ -1,7 +1,7 @@
 import React from 'react'
 import type { ReactNode } from 'react'
 import Header from './header'
-import SwrInitor from '@/app/components/swr-initor'
+import SwrInitor from '@/app/components/swr-initializer'
 import { AppContextProvider } from '@/context/app-context'
 import GA, { GaType } from '@/app/components/base/ga'
 import HeaderWrapper from '@/app/components/header/header-wrapper'
diff --git a/web/app/components/app-sidebar/app-info.tsx b/web/app/components/app-sidebar/app-info.tsx
index 8dd2108bfd..8c629a30a2 100644
--- a/web/app/components/app-sidebar/app-info.tsx
+++ b/web/app/components/app-sidebar/app-info.tsx
@@ -1,6 +1,6 @@
 import { useTranslation } from 'react-i18next'
 import { useRouter } from 'next/navigation'
-import { useContext, useContextSelector } from 'use-context-selector'
+import { useContext } from 'use-context-selector'
 import React, {
   useCallback, useState } from 'react'
 import {
   RiDeleteBinLine,
@@ -15,7 +15,7 @@ import AppIcon from '../base/app-icon'
 import cn from '@/utils/classnames'
 import { useStore as useAppStore } from '@/app/components/app/store'
 import { ToastContext } from '@/app/components/base/toast'
-import AppsContext, { useAppContext } from '@/context/app-context'
+import { useAppContext } from '@/context/app-context'
 import { useProviderContext } from '@/context/provider-context'
 import { copyApp, deleteApp, exportAppConfig, updateAppInfo } from '@/service/apps'
 import type { DuplicateAppModalProps } from '@/app/components/app/duplicate-modal'
@@ -73,11 +73,6 @@ const AppInfo = ({ expand, onlyShowDetail = false, openState = false, onDetailEx
   const [showImportDSLModal, setShowImportDSLModal] = useState(false)
   const [secretEnvList, setSecretEnvList] = useState([])

-  const mutateApps = useContextSelector(
-    AppsContext,
-    state => state.mutateApps,
-  )
-
   const onEdit: CreateAppModalProps['onConfirm'] = useCallback(async ({
     name,
     icon_type,
@@ -106,12 +101,11 @@ const AppInfo = ({ expand, onlyShowDetail = false, openState = false, onDetailEx
         message: t('app.editDone'),
       })
       setAppDetail(app)
-      mutateApps()
     }
     catch {
       notify({ type: 'error', message: t('app.editFailed') })
     }
-  }, [appDetail, mutateApps, notify, setAppDetail, t])
+  }, [appDetail, notify, setAppDetail, t])

   const onCopy: DuplicateAppModalProps['onConfirm'] = async ({ name, icon_type, icon, icon_background }) => {
     if (!appDetail)
@@ -131,7 +125,6 @@ const AppInfo = ({ expand, onlyShowDetail = false, openState = false, onDetailEx
         message: t('app.newApp.appCreated'),
       })
       localStorage.setItem(NEED_REFRESH_APP_LIST_KEY, '1')
-      mutateApps()
       onPlanInfoChanged()
       getRedirection(true, newApp, replace)
     }
@@ -186,7 +179,6 @@ const AppInfo = ({ expand, onlyShowDetail = false, openState = false, onDetailEx
     try {
       await deleteApp(appDetail.id)
       notify({ type: 'success', message: t('app.appDeleted') })
-      mutateApps()
      onPlanInfoChanged()
       setAppDetail()
       replace('/apps')
@@ -198,7 +190,7 @@ const AppInfo = ({ expand, onlyShowDetail = false, openState = false, onDetailEx
       })
     }
     setShowConfirmDelete(false)
-  }, [appDetail, mutateApps, notify, onPlanInfoChanged, replace, setAppDetail, t])
+  }, [appDetail, notify, onPlanInfoChanged, replace, setAppDetail, t])

   const { isCurrentWorkspaceEditor } = useAppContext()
diff --git a/web/app/components/app/configuration/dataset-config/select-dataset/index.tsx b/web/app/components/app/configuration/dataset-config/select-dataset/index.tsx
index 99dc32dfa2..7ba22907dd 100644
--- a/web/app/components/app/configuration/dataset-config/select-dataset/index.tsx
+++ b/web/app/components/app/configuration/dataset-config/select-dataset/index.tsx
@@ -13,7 +13,6 @@ import Loading from '@/app/components/base/loading'
 import Badge from '@/app/components/base/badge'
 import { useKnowledge } from '@/hooks/use-knowledge'
 import cn from '@/utils/classnames'
-import { basePath } from '@/utils/var'
 import AppIcon from '@/app/components/base/app-icon'

 export type ISelectDataSetProps = {
@@ -113,7 +112,7 @@ const SelectDataSet: FC<ISelectDataSetProps> = ({
         }}
       >
         {t('appDebug.feature.dataSet.noDataSet')}
-        {t('appDebug.feature.dataSet.toCreate')}
+        {t('appDebug.feature.dataSet.toCreate')}
       )}
diff --git a/web/app/components/app/create-app-modal/index.tsx b/web/app/components/app/create-app-modal/index.tsx
index f0a0da41a5..c37f7b051a 100644
--- a/web/app/components/app/create-app-modal/index.tsx
+++ b/web/app/components/app/create-app-modal/index.tsx
@@ -4,7 +4,7 @@ import {
   useCallback, useRef, useState } from 'react'
 import { useTranslation } from 'react-i18next'
 import { useRouter } from 'next/navigation'
-import { useContext, useContextSelector } from 'use-context-selector'
+import { useContext } from 'use-context-selector'
 import { RiArrowRightLine, RiArrowRightSLine, RiCommandLine, RiCornerDownLeftLine, RiExchange2Fill } from '@remixicon/react'
 import Link from 'next/link'
 import { useDebounceFn, useKeyPress } from 'ahooks'
@@ -15,7 +15,7 @@ import Button from '@/app/components/base/button'
 import Divider from '@/app/components/base/divider'
 import cn from '@/utils/classnames'
 import { basePath } from '@/utils/var'
-import AppsContext, { useAppContext } from '@/context/app-context'
+import { useAppContext } from '@/context/app-context'
 import { useProviderContext } from '@/context/provider-context'
 import { ToastContext } from '@/app/components/base/toast'
 import type { AppMode } from '@/types/app'
@@ -41,7 +41,6 @@ function CreateApp({ onClose, onSuccess, onCreateFromTemplate }: CreateAppProps)
   const { t } = useTranslation()
   const { push } = useRouter()
   const { notify } = useContext(ToastContext)
-  const mutateApps = useContextSelector(AppsContext, state => state.mutateApps)

   const [appMode, setAppMode] = useState<AppMode>('advanced-chat')
   const [appIcon, setAppIcon] = useState({ type: 'emoji', icon: '🤖', background: '#FFEAD5' })
@@ -80,7 +79,6 @@ function CreateApp({ onClose, onSuccess, onCreateFromTemplate }: CreateAppProps)
       notify({ type: 'success', message: t('app.newApp.appCreated') })
       onSuccess()
       onClose()
-      mutateApps()
       localStorage.setItem(NEED_REFRESH_APP_LIST_KEY, '1')
       getRedirection(isCurrentWorkspaceEditor, app, push)
     }
@@ -88,7 +86,7 @@ function CreateApp({ onClose, onSuccess, onCreateFromTemplate }: CreateAppProps)
       notify({ type: 'error', message: t('app.newApp.appCreateFailed') })
     }
     isCreatingRef.current = false
-  }, [name, notify, t, appMode, appIcon, description, onSuccess, onClose, mutateApps, push, isCurrentWorkspaceEditor])
+  }, [name, notify, t, appMode, appIcon, description, onSuccess, onClose, push, isCurrentWorkspaceEditor])

   const { run: handleCreateApp } = useDebounceFn(onCreate, { wait: 300 })
   useKeyPress(['meta.enter', 'ctrl.enter'], () => {
@@ -298,7 +296,7 @@ function AppTypeCard({ icon, title, description, active, onClick }: AppTypeCardP
     >
       {icon}
       <div>{title}</div>
-      <div>{description}</div>
+      <div>{description}</div>
     </div>
   )
 }
diff --git a/web/app/components/app/overview/embedded/index.tsx b/web/app/components/app/overview/embedded/index.tsx
index b48eac5458..9d97eae38d 100644
--- a/web/app/components/app/overview/embedded/index.tsx
+++ b/web/app/components/app/overview/embedded/index.tsx
@@ -90,10 +90,10 @@ const Embedded = ({ siteInfo, isShow, onClose, appBaseUrl, accessToken, classNam
   const [option, setOption] = useState