merge main

This commit is contained in:
zxhlyh 2025-07-21 17:45:26 +08:00
commit 4d36e784b7
117 changed files with 1639 additions and 419 deletions

View File

@ -1,6 +1,6 @@
#!/bin/bash
npm add -g pnpm@10.11.1
npm add -g pnpm@10.13.1
cd web && pnpm install
pipx install uv
@ -12,3 +12,4 @@ echo 'alias start-containers="cd /workspaces/dify/docker && docker-compose -f do
echo 'alias stop-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env down"' >> ~/.bashrc
source /home/vscode/.bashrc

View File

@ -28,7 +28,7 @@ jobs:
- name: Check changed files
id: changed-files
uses: tj-actions/changed-files@v45
uses: tj-actions/changed-files@v46
with:
files: |
api/**
@ -75,7 +75,7 @@ jobs:
- name: Check changed files
id: changed-files
uses: tj-actions/changed-files@v45
uses: tj-actions/changed-files@v46
with:
files: web/**
@ -113,7 +113,7 @@ jobs:
- name: Check changed files
id: changed-files
uses: tj-actions/changed-files@v45
uses: tj-actions/changed-files@v46
with:
files: |
docker/generate_docker_compose
@ -144,7 +144,7 @@ jobs:
- name: Check changed files
id: changed-files
uses: tj-actions/changed-files@v45
uses: tj-actions/changed-files@v46
with:
files: |
**.sh
@ -152,13 +152,15 @@ jobs:
**.yml
**Dockerfile
dev/**
.editorconfig
- name: Super-linter
uses: super-linter/super-linter/slim@v7
uses: super-linter/super-linter/slim@v8
if: steps.changed-files.outputs.any_changed == 'true'
env:
BASH_SEVERITY: warning
DEFAULT_BRANCH: main
DEFAULT_BRANCH: origin/main
EDITORCONFIG_FILE_NAME: editorconfig-checker.json
FILTER_REGEX_INCLUDE: pnpm-lock.yaml
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
IGNORE_GENERATED_FILES: true
@ -168,16 +170,6 @@ jobs:
# FIXME: temporarily disabled until api-docker.yaml's run script is fixed for shellcheck
# VALIDATE_GITHUB_ACTIONS: true
VALIDATE_DOCKERFILE_HADOLINT: true
VALIDATE_EDITORCONFIG: true
VALIDATE_XML: true
VALIDATE_YAML: true
- name: EditorConfig checks
uses: super-linter/super-linter/slim@v7
env:
DEFAULT_BRANCH: main
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
IGNORE_GENERATED_FILES: true
IGNORE_GITIGNORED_FILES: true
# EditorConfig validation
VALIDATE_EDITORCONFIG: true
EDITORCONFIG_FILE_NAME: editorconfig-checker.json

View File

@ -27,7 +27,7 @@ jobs:
- name: Check changed files
id: changed-files
uses: tj-actions/changed-files@v45
uses: tj-actions/changed-files@v46
with:
files: web/**

View File

@ -142,8 +142,10 @@ WEB_API_CORS_ALLOW_ORIGINS=http://localhost:3000,*
CONSOLE_CORS_ALLOW_ORIGINS=http://localhost:3000,*
# Vector database configuration
# support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, couchbase, vikingdb, upstash, lindorm, oceanbase, opengauss, tablestore, matrixone
# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`.
VECTOR_STORE=weaviate
# Prefix used to create collection name in vector database
VECTOR_INDEX_NAME_PREFIX=Vector_index
# Weaviate configuration
WEAVIATE_ENDPOINT=http://localhost:8080

View File

@ -85,6 +85,11 @@ class VectorStoreConfig(BaseSettings):
default=False,
)
VECTOR_INDEX_NAME_PREFIX: Optional[str] = Field(
description="Prefix used to create collection name in vector database",
default="Vector_index",
)
class KeywordStoreConfig(BaseSettings):
KEYWORD_STORE: str = Field(

View File

@ -1,4 +1,4 @@
from datetime import UTC, datetime
from datetime import datetime
import pytz # pip install pytz
from flask_login import current_user
@ -19,6 +19,7 @@ from fields.conversation_fields import (
conversation_pagination_fields,
conversation_with_summary_pagination_fields,
)
from libs.datetime_utils import naive_utc_now
from libs.helper import DatetimeString
from libs.login import login_required
from models import Conversation, EndUser, Message, MessageAnnotation
@ -315,7 +316,7 @@ def _get_conversation(app_model, conversation_id):
raise NotFound("Conversation Not Exists.")
if not conversation.read_at:
conversation.read_at = datetime.now(UTC).replace(tzinfo=None)
conversation.read_at = naive_utc_now()
conversation.read_account_id = current_user.id
db.session.commit()

View File

@ -1,5 +1,3 @@
from datetime import UTC, datetime
from flask_login import current_user
from flask_restful import Resource, marshal_with, reqparse
from werkzeug.exceptions import Forbidden, NotFound
@ -10,6 +8,7 @@ from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from extensions.ext_database import db
from fields.app_fields import app_site_fields
from libs.datetime_utils import naive_utc_now
from libs.login import login_required
from models import Site
@ -77,7 +76,7 @@ class AppSite(Resource):
setattr(site, attr_name, value)
site.updated_by = current_user.id
site.updated_at = datetime.now(UTC).replace(tzinfo=None)
site.updated_at = naive_utc_now()
db.session.commit()
return site
@ -101,7 +100,7 @@ class AppSiteAccessTokenReset(Resource):
site.code = Site.generate_code(16)
site.updated_by = current_user.id
site.updated_at = datetime.now(UTC).replace(tzinfo=None)
site.updated_at = naive_utc_now()
db.session.commit()
return site

View File

@ -1,5 +1,3 @@
import datetime
from flask import request
from flask_restful import Resource, reqparse
@ -7,6 +5,7 @@ from constants.languages import supported_language
from controllers.console import api
from controllers.console.error import AlreadyActivateError
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from libs.helper import StrLen, email, extract_remote_ip, timezone
from models.account import AccountStatus
from services.account_service import AccountService, RegisterService
@ -65,7 +64,7 @@ class ActivateApi(Resource):
account.timezone = args["timezone"]
account.interface_theme = "light"
account.status = AccountStatus.ACTIVE.value
account.initialized_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
account.initialized_at = naive_utc_now()
db.session.commit()
token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))

View File

@ -1,5 +1,4 @@
import logging
from datetime import UTC, datetime
from typing import Optional
import requests
@ -13,6 +12,7 @@ from configs import dify_config
from constants.languages import languages
from events.tenant_event import tenant_was_created
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from libs.helper import extract_remote_ip
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
from models import Account
@ -110,7 +110,7 @@ class OAuthCallback(Resource):
if account.status == AccountStatus.PENDING.value:
account.status = AccountStatus.ACTIVE.value
account.initialized_at = datetime.now(UTC).replace(tzinfo=None)
account.initialized_at = naive_utc_now()
db.session.commit()
try:

View File

@ -1,4 +1,3 @@
import datetime
import json
from flask import request
@ -15,6 +14,7 @@ from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.notion_extractor import NotionExtractor
from extensions.ext_database import db
from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields
from libs.datetime_utils import naive_utc_now
from libs.login import login_required
from models import DataSourceOauthBinding, Document
from services.dataset_service import DatasetService, DocumentService
@ -88,7 +88,7 @@ class DataSourceApi(Resource):
if action == "enable":
if data_source_binding.disabled:
data_source_binding.disabled = False
data_source_binding.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
data_source_binding.updated_at = naive_utc_now()
db.session.add(data_source_binding)
db.session.commit()
else:
@ -97,7 +97,7 @@ class DataSourceApi(Resource):
if action == "disable":
if not data_source_binding.disabled:
data_source_binding.disabled = True
data_source_binding.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
data_source_binding.updated_at = naive_utc_now()
db.session.add(data_source_binding)
db.session.commit()
else:

View File

@ -1,7 +1,6 @@
import json
import logging
from argparse import ArgumentTypeError
from datetime import UTC, datetime
from typing import cast
from flask import request
@ -50,6 +49,7 @@ from fields.document_fields import (
document_status_fields,
document_with_segments_fields,
)
from libs.datetime_utils import naive_utc_now
from libs.login import login_required
from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile
from models.dataset import DocumentPipelineExecutionLog
@ -752,7 +752,7 @@ class DocumentProcessingApi(DocumentResource):
raise InvalidActionError("Document not in indexing state.")
document.paused_by = current_user.id
document.paused_at = datetime.now(UTC).replace(tzinfo=None)
document.paused_at = naive_utc_now()
document.is_paused = True
db.session.commit()
@ -832,7 +832,7 @@ class DocumentMetadataApi(DocumentResource):
document.doc_metadata[key] = value
document.doc_type = doc_type
document.updated_at = datetime.now(UTC).replace(tzinfo=None)
document.updated_at = naive_utc_now()
db.session.commit()
return {"result": "success", "message": "Document metadata updated."}, 200

View File

@ -1,5 +1,4 @@
import logging
from datetime import UTC, datetime
from flask_login import current_user
from flask_restful import reqparse
@ -27,6 +26,7 @@ from core.errors.error import (
from core.model_runtime.errors.invoke import InvokeError
from extensions.ext_database import db
from libs import helper
from libs.datetime_utils import naive_utc_now
from libs.helper import uuid_value
from models.model import AppMode
from services.app_generate_service import AppGenerateService
@ -51,7 +51,7 @@ class CompletionApi(InstalledAppResource):
streaming = args["response_mode"] == "streaming"
args["auto_generate_name"] = False
installed_app.last_used_at = datetime.now(UTC).replace(tzinfo=None)
installed_app.last_used_at = naive_utc_now()
db.session.commit()
try:
@ -111,7 +111,7 @@ class ChatApi(InstalledAppResource):
args["auto_generate_name"] = False
installed_app.last_used_at = datetime.now(UTC).replace(tzinfo=None)
installed_app.last_used_at = naive_utc_now()
db.session.commit()
try:

View File

@ -1,5 +1,4 @@
import logging
from datetime import UTC, datetime
from typing import Any
from flask import request
@ -13,6 +12,7 @@ from controllers.console.explore.wraps import InstalledAppResource
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from extensions.ext_database import db
from fields.installed_app_fields import installed_app_list_fields
from libs.datetime_utils import naive_utc_now
from libs.login import login_required
from models import App, InstalledApp, RecommendedApp
from services.account_service import TenantService
@ -122,7 +122,7 @@ class InstalledAppsListApi(Resource):
tenant_id=current_tenant_id,
app_owner_tenant_id=app.tenant_id,
is_pinned=False,
last_used_at=datetime.now(UTC).replace(tzinfo=None),
last_used_at=naive_utc_now(),
)
db.session.add(new_installed_app)
db.session.commit()

View File

@ -1,5 +1,3 @@
import datetime
import pytz
from flask import request
from flask_login import current_user
@ -35,6 +33,7 @@ from controllers.console.wraps import (
)
from extensions.ext_database import db
from fields.member_fields import account_fields
from libs.datetime_utils import naive_utc_now
from libs.helper import TimestampField, email, extract_remote_ip, timezone
from libs.login import login_required
from models import AccountIntegrate, InvitationCode
@ -80,7 +79,7 @@ class AccountInitApi(Resource):
raise InvalidInvitationCodeError()
invitation_code.status = "used"
invitation_code.used_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
invitation_code.used_at = naive_utc_now()
invitation_code.used_by_tenant_id = account.current_tenant_id
invitation_code.used_by_account_id = account.id
@ -88,7 +87,7 @@ class AccountInitApi(Resource):
account.timezone = args["timezone"]
account.interface_theme = "light"
account.status = "active"
account.initialized_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
account.initialized_at = naive_utc_now()
db.session.commit()
return {"result": "success"}

View File

@ -29,7 +29,7 @@ from libs.login import login_required
from services.plugin.oauth_service import OAuthProxyService
from services.tools.api_tools_manage_service import ApiToolManageService
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
from services.tools.mcp_tools_mange_service import MCPToolManageService
from services.tools.mcp_tools_manage_service import MCPToolManageService
from services.tools.tool_labels_service import ToolLabelsService
from services.tools.tools_manage_service import ToolCommonService
from services.tools.tools_transform_service import ToolTransformService

View File

@ -1,6 +1,6 @@
import time
from collections.abc import Callable
from datetime import UTC, datetime, timedelta
from datetime import timedelta
from enum import Enum
from functools import wraps
from typing import Optional
@ -15,6 +15,7 @@ from werkzeug.exceptions import Forbidden, NotFound, Unauthorized
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from libs.login import _get_user
from models.account import Account, Tenant, TenantAccountJoin, TenantStatus
from models.dataset import Dataset, RateLimitLog
@ -256,7 +257,7 @@ def validate_and_get_api_token(scope: str | None = None):
if auth_scheme != "bearer":
raise Unauthorized("Authorization scheme must be 'Bearer'")
current_time = datetime.now(UTC).replace(tzinfo=None)
current_time = naive_utc_now()
cutoff_time = current_time - timedelta(minutes=1)
with Session(db.engine, expire_on_commit=False) as session:
update_stmt = (

View File

@ -1,7 +1,6 @@
import json
import logging
from collections.abc import Generator
from datetime import UTC, datetime
from typing import Optional, Union, cast
from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppModelConfigFrom
@ -25,6 +24,7 @@ from core.app.entities.task_entities import (
from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models import Account
from models.enums import CreatorUserRole
from models.model import App, AppMode, AppModelConfig, Conversation, EndUser, Message, MessageFile
@ -184,7 +184,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
db.session.commit()
db.session.refresh(conversation)
else:
conversation.updated_at = datetime.now(UTC).replace(tzinfo=None)
conversation.updated_at = naive_utc_now()
db.session.commit()
message = Message(

View File

@ -7,6 +7,7 @@ from core.model_runtime.entities import (
AudioPromptMessageContent,
DocumentPromptMessageContent,
ImagePromptMessageContent,
TextPromptMessageContent,
VideoPromptMessageContent,
)
from core.model_runtime.entities.message_entities import PromptMessageContentUnionTypes
@ -44,11 +45,44 @@ def to_prompt_message_content(
*,
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
) -> PromptMessageContentUnionTypes:
"""
Convert a file to prompt message content.
This function converts files to their appropriate prompt message content types.
For supported file types (IMAGE, AUDIO, VIDEO, DOCUMENT), it creates the
corresponding message content with proper encoding/URL.
For unsupported file types, instead of raising an error, it returns a
TextPromptMessageContent with a descriptive message about the file.
Args:
f: The file to convert
image_detail_config: Optional detail configuration for image files
Returns:
PromptMessageContentUnionTypes: The appropriate message content type
Raises:
ValueError: If file extension or mime_type is missing
"""
if f.extension is None:
raise ValueError("Missing file extension")
if f.mime_type is None:
raise ValueError("Missing file mime_type")
prompt_class_map: Mapping[FileType, type[PromptMessageContentUnionTypes]] = {
FileType.IMAGE: ImagePromptMessageContent,
FileType.AUDIO: AudioPromptMessageContent,
FileType.VIDEO: VideoPromptMessageContent,
FileType.DOCUMENT: DocumentPromptMessageContent,
}
# Check if file type is supported
if f.type not in prompt_class_map:
# For unsupported file types, return a text description
return TextPromptMessageContent(data=f"[Unsupported file type: {f.filename} ({f.type.value})]")
# Process supported file types
params = {
"base64_data": _get_encoded_string(f) if dify_config.MULTIMODAL_SEND_FORMAT == "base64" else "",
"url": _to_url(f) if dify_config.MULTIMODAL_SEND_FORMAT == "url" else "",
@ -58,17 +92,7 @@ def to_prompt_message_content(
if f.type == FileType.IMAGE:
params["detail"] = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
prompt_class_map: Mapping[FileType, type[PromptMessageContentUnionTypes]] = {
FileType.IMAGE: ImagePromptMessageContent,
FileType.AUDIO: AudioPromptMessageContent,
FileType.VIDEO: VideoPromptMessageContent,
FileType.DOCUMENT: DocumentPromptMessageContent,
}
try:
return prompt_class_map[f.type].model_validate(params)
except KeyError:
raise ValueError(f"file type {f.type} is not supported")
return prompt_class_map[f.type].model_validate(params)
def download(f: File, /):

View File

@ -8,7 +8,7 @@ from core.mcp.types import (
OAuthTokens,
)
from models.tools import MCPToolProvider
from services.tools.mcp_tools_mange_service import MCPToolManageService
from services.tools.mcp_tools_manage_service import MCPToolManageService
LATEST_PROTOCOL_VERSION = "1.0"

View File

@ -68,15 +68,17 @@ class MCPClient:
}
parsed_url = urlparse(self.server_url)
path = parsed_url.path
method_name = path.rstrip("/").split("/")[-1] if path else ""
try:
path = parsed_url.path or ""
method_name = path.removesuffix("/").lower()
if method_name in connection_methods:
client_factory = connection_methods[method_name]
self.connect_server(client_factory, method_name)
except KeyError:
else:
try:
logger.debug(f"Not supported method {method_name} found in URL path, trying default 'mcp' method.")
self.connect_server(sse_client, "sse")
except MCPConnectionError:
logger.debug("MCP connection failed with 'sse', falling back to 'mcp' method.")
self.connect_server(streamablehttp_client, "mcp")
def connect_server(
@ -91,7 +93,7 @@ class MCPClient:
else {}
)
self._streams_context = client_factory(url=self.server_url, headers=headers)
if self._streams_context is None:
if not self._streams_context:
raise MCPConnectionError("Failed to create connection context")
# Use exit_stack to manage context managers properly
@ -141,10 +143,11 @@ class MCPClient:
try:
# ExitStack will handle proper cleanup of all managed context managers
self.exit_stack.close()
except Exception as e:
logging.exception("Error during cleanup")
raise ValueError(f"Error during cleanup: {e}")
finally:
self._session = None
self._session_context = None
self._streams_context = None
self._initialized = False
except Exception as e:
logging.exception("Error during cleanup")
raise ValueError(f"Error during cleanup: {e}")

View File

@ -3,7 +3,7 @@ import json
import logging
import os
from datetime import datetime, timedelta
from typing import Optional, Union, cast
from typing import Any, Optional, Union, cast
from openinference.semconv.trace import OpenInferenceSpanKindValues, SpanAttributes
from opentelemetry import trace
@ -142,11 +142,8 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
raise
def workflow_trace(self, trace_info: WorkflowTraceInfo):
if trace_info.message_data is None:
return
workflow_metadata = {
"workflow_id": trace_info.workflow_run_id or "",
"workflow_run_id": trace_info.workflow_run_id or "",
"message_id": trace_info.message_id or "",
"workflow_app_log_id": trace_info.workflow_app_log_id or "",
"status": trace_info.workflow_run_status or "",
@ -156,7 +153,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
}
workflow_metadata.update(trace_info.metadata)
trace_id = uuid_to_trace_id(trace_info.message_id)
trace_id = uuid_to_trace_id(trace_info.workflow_run_id)
span_id = RandomIdGenerator().generate_span_id()
context = SpanContext(
trace_id=trace_id,
@ -213,7 +210,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
if model:
node_metadata["ls_model_name"] = model
outputs = json.loads(node_execution.outputs).get("usage", {})
outputs = json.loads(node_execution.outputs).get("usage", {}) if "outputs" in node_execution else {}
usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
if usage_data:
node_metadata["total_tokens"] = usage_data.get("total_tokens", 0)
@ -236,31 +233,34 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
SpanAttributes.SESSION_ID: trace_info.conversation_id or "",
},
start_time=datetime_to_nanos(created_at),
context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
)
try:
if node_execution.node_type == "llm":
llm_attributes: dict[str, Any] = {
SpanAttributes.INPUT_VALUE: json.dumps(process_data.get("prompts", []), ensure_ascii=False),
}
provider = process_data.get("model_provider")
model = process_data.get("model_name")
if provider:
node_span.set_attribute(SpanAttributes.LLM_PROVIDER, provider)
llm_attributes[SpanAttributes.LLM_PROVIDER] = provider
if model:
node_span.set_attribute(SpanAttributes.LLM_MODEL_NAME, model)
outputs = json.loads(node_execution.outputs).get("usage", {})
llm_attributes[SpanAttributes.LLM_MODEL_NAME] = model
outputs = (
json.loads(node_execution.outputs).get("usage", {}) if "outputs" in node_execution else {}
)
usage_data = (
process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
)
if usage_data:
node_span.set_attribute(
SpanAttributes.LLM_TOKEN_COUNT_TOTAL, usage_data.get("total_tokens", 0)
)
node_span.set_attribute(
SpanAttributes.LLM_TOKEN_COUNT_PROMPT, usage_data.get("prompt_tokens", 0)
)
node_span.set_attribute(
SpanAttributes.LLM_TOKEN_COUNT_COMPLETION, usage_data.get("completion_tokens", 0)
llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_TOTAL] = usage_data.get("total_tokens", 0)
llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_PROMPT] = usage_data.get("prompt_tokens", 0)
llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_COMPLETION] = usage_data.get(
"completion_tokens", 0
)
llm_attributes.update(self._construct_llm_attributes(process_data.get("prompts", [])))
node_span.set_attributes(llm_attributes)
finally:
node_span.end(end_time=datetime_to_nanos(finished_at))
finally:
@ -352,25 +352,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
SpanAttributes.METADATA: json.dumps(message_metadata, ensure_ascii=False),
SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id,
}
if isinstance(trace_info.inputs, list):
for i, msg in enumerate(trace_info.inputs):
if isinstance(msg, dict):
llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.{i}.message.content"] = msg.get("text", "")
llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.{i}.message.role"] = msg.get(
"role", "user"
)
# todo: handle assistant and tool role messages, as they don't always
# have a text field, but may have a tool_calls field instead
# e.g. 'tool_calls': [{'id': '98af3a29-b066-45a5-b4b1-46c74ddafc58',
# 'type': 'function', 'function': {'name': 'current_time', 'arguments': '{}'}}]}
elif isinstance(trace_info.inputs, dict):
llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.content"] = json.dumps(trace_info.inputs)
llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.role"] = "user"
elif isinstance(trace_info.inputs, str):
llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.content"] = trace_info.inputs
llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.role"] = "user"
llm_attributes.update(self._construct_llm_attributes(trace_info.inputs))
if trace_info.total_tokens is not None and trace_info.total_tokens > 0:
llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_TOTAL] = trace_info.total_tokens
if trace_info.message_tokens is not None and trace_info.message_tokens > 0:
@ -724,3 +706,24 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
.all()
)
return workflow_nodes
def _construct_llm_attributes(self, prompts: dict | list | str | None) -> dict[str, str]:
"""Helper method to construct LLM attributes with passed prompts."""
attributes = {}
if isinstance(prompts, list):
for i, msg in enumerate(prompts):
if isinstance(msg, dict):
attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.{i}.message.content"] = msg.get("text", "")
attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.{i}.message.role"] = msg.get("role", "user")
# todo: handle assistant and tool role messages, as they don't always
# have a text field, but may have a tool_calls field instead
# e.g. 'tool_calls': [{'id': '98af3a29-b066-45a5-b4b1-46c74ddafc58',
# 'type': 'function', 'function': {'name': 'current_time', 'arguments': '{}'}}]}
elif isinstance(prompts, dict):
attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.content"] = json.dumps(prompts)
attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.role"] = "user"
elif isinstance(prompts, str):
attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.content"] = prompts
attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.role"] = "user"
return attributes

View File

@ -233,6 +233,12 @@ class AnalyticdbVectorOpenAPI:
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
document_ids_filter = kwargs.get("document_ids_filter")
where_clause = ""
if document_ids_filter:
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
where_clause += f"metadata_->>'document_id' IN ({document_ids})"
score_threshold = kwargs.get("score_threshold") or 0.0
request = gpdb_20160503_models.QueryCollectionDataRequest(
dbinstance_id=self.config.instance_id,
@ -245,7 +251,7 @@ class AnalyticdbVectorOpenAPI:
vector=query_vector,
content=None,
top_k=kwargs.get("top_k", 4),
filter=None,
filter=where_clause,
)
response = self._client.query_collection_data(request)
documents = []
@ -265,6 +271,11 @@ class AnalyticdbVectorOpenAPI:
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
document_ids_filter = kwargs.get("document_ids_filter")
where_clause = ""
if document_ids_filter:
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
where_clause += f"metadata_->>'document_id' IN ({document_ids})"
score_threshold = float(kwargs.get("score_threshold") or 0.0)
request = gpdb_20160503_models.QueryCollectionDataRequest(
dbinstance_id=self.config.instance_id,
@ -277,7 +288,7 @@ class AnalyticdbVectorOpenAPI:
vector=None,
content=query,
top_k=kwargs.get("top_k", 4),
filter=None,
filter=where_clause,
)
response = self._client.query_collection_data(request)
documents = []

View File

@ -147,10 +147,17 @@ class ElasticSearchVector(BaseVector):
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
query_str = {"match": {Field.CONTENT_KEY.value: query}}
query_str: dict[str, Any] = {"match": {Field.CONTENT_KEY.value: query}}
document_ids_filter = kwargs.get("document_ids_filter")
if document_ids_filter:
query_str["filter"] = {"terms": {"metadata.document_id": document_ids_filter}} # type: ignore
query_str = {
"bool": {
"must": {"match": {Field.CONTENT_KEY.value: query}},
"filter": {"terms": {"metadata.document_id": document_ids_filter}},
}
}
results = self._client.search(index=self._collection_name, query=query_str, size=kwargs.get("top_k", 4))
docs = []
for hit in results["hits"]["hits"]:

View File

@ -102,6 +102,7 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
splits = text.split()
else:
splits = text.split(separator)
splits = [item + separator if i < len(splits) else item for i, item in enumerate(splits)]
else:
splits = list(text)
splits = [s for s in splits if (s not in {"", "\n"})]

View File

@ -21,7 +21,7 @@ from core.tools.plugin_tool.tool import PluginTool
from core.tools.utils.uuid_utils import is_valid_uuid
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
from core.workflow.entities.variable_pool import VariablePool
from services.tools.mcp_tools_mange_service import MCPToolManageService
from services.tools.mcp_tools_manage_service import MCPToolManageService
if TYPE_CHECKING:
from core.workflow.nodes.tool.entities import ToolEntity

View File

@ -270,7 +270,14 @@ class AgentNode(BaseNode):
)
extra = tool.get("extra", {})
runtime_variable_pool = variable_pool if self._node_data.version != "1" else None
# This is an issue that caused problems before.
# Logically, we shouldn't use the node_data.version field for judgment
# But for backward compatibility with historical data
# this version field judgment is still preserved here.
runtime_variable_pool: VariablePool | None = None
if node_data.version != "1" or node_data.tool_node_version != "1":
runtime_variable_pool = variable_pool
tool_runtime = ToolManager.get_agent_tool_runtime(
self.tenant_id, self.app_id, entity, self.invoke_from, runtime_variable_pool
)

View File

@ -13,6 +13,10 @@ class AgentNodeData(BaseNodeData):
agent_strategy_name: str
agent_strategy_label: str # redundancy
memory: MemoryConfig | None = None
# The version of the tool parameter.
# If this value is None, it indicates this is a previous version
# and requires using the legacy parameter parsing rules.
tool_node_version: str | None = None
class AgentInput(BaseModel):
value: Union[list[str], list[ToolSelector], Any]

View File

@ -117,7 +117,7 @@ class KnowledgeRetrievalNodeData(BaseNodeData):
multiple_retrieval_config: Optional[MultipleRetrievalConfig] = None
single_retrieval_config: Optional[SingleRetrievalConfig] = None
metadata_filtering_mode: Optional[Literal["disabled", "automatic", "manual"]] = "disabled"
metadata_model_config: ModelConfig
metadata_model_config: Optional[ModelConfig] = None
metadata_filtering_conditions: Optional[MetadataFilteringCondition] = None
vision: VisionConfig = Field(default_factory=VisionConfig)

View File

@ -509,6 +509,8 @@ class KnowledgeRetrievalNode(BaseNode):
# get all metadata field
metadata_fields = db.session.query(DatasetMetadata).filter(DatasetMetadata.dataset_id.in_(dataset_ids)).all()
all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields]
if node_data.metadata_model_config is None:
raise ValueError("metadata_model_config is required")
# get metadata model instance and fetch model config
model_instance, model_config = self.get_model_config(node_data.metadata_model_config)
# fetch prompt messages
@ -701,7 +703,7 @@ class KnowledgeRetrievalNode(BaseNode):
)
def _get_prompt_template(self, node_data: KnowledgeRetrievalNodeData, metadata_fields: list, query: str):
model_mode = ModelMode(node_data.metadata_model_config.mode)
model_mode = ModelMode(node_data.metadata_model_config.mode) # type: ignore
input_text = query
prompt_messages: list[LLMNodeChatModelMessage] = []

View File

@ -75,6 +75,9 @@ NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[BaseNode]]] = {
},
NodeType.TOOL: {
LATEST_VERSION: ToolNode,
# This is an issue that caused problems before.
# Logically, we shouldn't use two different versions to point to the same class here,
# but in order to maintain compatibility with historical data, this approach has been retained.
"2": ToolNode,
"1": ToolNode,
},
@ -125,6 +128,9 @@ NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[BaseNode]]] = {
},
NodeType.AGENT: {
LATEST_VERSION: AgentNode,
# This is an issue that caused problems before.
# Logically, we shouldn't use two different versions to point to the same class here,
# but in order to maintain compatibility with historical data, this approach has been retained.
"2": AgentNode,
"1": AgentNode,
},

View File

@ -59,6 +59,10 @@ class ToolNodeData(BaseNodeData, ToolEntity):
return typ
tool_parameters: dict[str, ToolInput]
# The version of the tool parameter.
# If this value is None, it indicates this is a previous version
# and requires using the legacy parameter parsing rules.
tool_node_version: str | None = None
@field_validator("tool_parameters", mode="before")
@classmethod

View File

@ -70,7 +70,13 @@ class ToolNode(BaseNode):
try:
from core.tools.tool_manager import ToolManager
variable_pool = self.graph_runtime_state.variable_pool if self._node_data.version != "1" else None
# This is an issue that caused problems before.
# Logically, we shouldn't use the node_data.version field for judgment
# But for backward compatibility with historical data
# this version field judgment is still preserved here.
variable_pool: VariablePool | None = None
if node_data.version != "1" or node_data.tool_node_version != "1":
variable_pool = self.graph_runtime_state.variable_pool
tool_runtime = ToolManager.get_workflow_tool_runtime(
self.tenant_id, self.app_id, self.node_id, self._node_data, self.invoke_from, variable_pool
)

View File

@ -1,6 +1,6 @@
from collections.abc import Mapping
from dataclasses import dataclass
from datetime import UTC, datetime
from datetime import datetime
from typing import Any, Optional, Union
from uuid import uuid4
@ -71,7 +71,7 @@ class WorkflowCycleManager:
workflow_version=self._workflow_info.version,
graph=self._workflow_info.graph_data,
inputs=inputs,
started_at=datetime.now(UTC).replace(tzinfo=None),
started_at=naive_utc_now(),
)
return self._save_and_cache_workflow_execution(execution)
@ -356,7 +356,7 @@ class WorkflowCycleManager:
created_at: Optional[datetime] = None,
) -> WorkflowNodeExecution:
"""Create a node execution from an event."""
now = datetime.now(UTC).replace(tzinfo=None)
now = naive_utc_now()
created_at = created_at or now
metadata = {
@ -403,7 +403,7 @@ class WorkflowCycleManager:
handle_special_values: bool = False,
) -> None:
"""Update node execution with completion data."""
finished_at = datetime.now(UTC).replace(tzinfo=None)
finished_at = naive_utc_now()
elapsed_time = (finished_at - event.start_at).total_seconds()
# Process data

View File

@ -1,4 +1,3 @@
import datetime
import logging
import time
@ -8,6 +7,7 @@ from werkzeug.exceptions import NotFound
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from events.event_handlers.document_index_event import document_index_created
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.dataset import Document
@ -33,7 +33,7 @@ def handle(sender, **kwargs):
raise NotFound("Document not found")
document.indexing_status = "parsing"
document.processing_started_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
document.processing_started_at = naive_utc_now()
documents.append(document)
db.session.add(document)
db.session.commit()

View File

@ -1,5 +1,5 @@
from collections.abc import Generator
from datetime import UTC, datetime, timedelta
from datetime import timedelta
from typing import Optional
from azure.identity import ChainedTokenCredential, DefaultAzureCredential
@ -8,6 +8,7 @@ from azure.storage.blob import AccountSasPermissions, BlobServiceClient, Resourc
from configs import dify_config
from extensions.ext_redis import redis_client
from extensions.storage.base_storage import BaseStorage
from libs.datetime_utils import naive_utc_now
class AzureBlobStorage(BaseStorage):
@ -78,7 +79,7 @@ class AzureBlobStorage(BaseStorage):
account_key=self.account_key or "",
resource_types=ResourceTypes(service=True, container=True, object=True),
permission=AccountSasPermissions(read=True, write=True, delete=True, list=True, add=True, create=True),
expiry=datetime.now(UTC).replace(tzinfo=None) + timedelta(hours=1),
expiry=naive_utc_now() + timedelta(hours=1),
)
redis_client.set(cache_key, sas_token, ex=3000)
return BlobServiceClient(account_url=self.account_url or "", credential=sas_token)

View File

@ -149,9 +149,7 @@ def _build_from_local_file(
if strict_type_validation and detected_file_type.value != specified_type:
raise ValueError("Detected file type does not match the specified type. Please verify the file.")
file_type = (
FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM.value else detected_file_type
)
file_type = FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM else detected_file_type
return File(
id=mapping.get("id"),
@ -200,9 +198,7 @@ def _build_from_remote_url(
raise ValueError("Detected file type does not match the specified type. Please verify the file.")
file_type = (
FileType(specified_type)
if specified_type and specified_type != FileType.CUSTOM.value
else detected_file_type
FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM else detected_file_type
)
return File(
@ -287,9 +283,7 @@ def _build_from_tool_file(
if strict_type_validation and specified_type and detected_file_type.value != specified_type:
raise ValueError("Detected file type does not match the specified type. Please verify the file.")
file_type = (
FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM.value else detected_file_type
)
file_type = FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM else detected_file_type
return File(
id=mapping.get("id"),

View File

@ -1,4 +1,3 @@
import datetime
import urllib.parse
from typing import Any
@ -6,6 +5,7 @@ import requests
from flask_login import current_user
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.source import DataSourceOauthBinding
@ -75,7 +75,7 @@ class NotionOAuth(OAuthDataSource):
if data_source_binding:
data_source_binding.source_info = source_info
data_source_binding.disabled = False
data_source_binding.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
data_source_binding.updated_at = naive_utc_now()
db.session.commit()
else:
new_data_source_binding = DataSourceOauthBinding(
@ -115,7 +115,7 @@ class NotionOAuth(OAuthDataSource):
if data_source_binding:
data_source_binding.source_info = source_info
data_source_binding.disabled = False
data_source_binding.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
data_source_binding.updated_at = naive_utc_now()
db.session.commit()
else:
new_data_source_binding = DataSourceOauthBinding(
@ -154,7 +154,7 @@ class NotionOAuth(OAuthDataSource):
}
data_source_binding.source_info = new_source_info
data_source_binding.disabled = False
data_source_binding.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
data_source_binding.updated_at = naive_utc_now()
db.session.commit()
else:
raise ValueError("Data source binding not found")

View File

@ -0,0 +1,51 @@
"""update models
Revision ID: 1a83934ad6d1
Revises: 71f5020c6470
Create Date: 2025-07-21 09:35:48.774794
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '1a83934ad6d1'
down_revision = '71f5020c6470'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('tool_mcp_providers', schema=None) as batch_op:
batch_op.alter_column('server_identifier',
existing_type=sa.VARCHAR(length=24),
type_=sa.String(length=64),
existing_nullable=False)
with op.batch_alter_table('tool_model_invokes', schema=None) as batch_op:
batch_op.alter_column('tool_name',
existing_type=sa.VARCHAR(length=40),
type_=sa.String(length=128),
existing_nullable=False)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('tool_model_invokes', schema=None) as batch_op:
batch_op.alter_column('tool_name',
existing_type=sa.String(length=128),
type_=sa.VARCHAR(length=40),
existing_nullable=False)
with op.batch_alter_table('tool_mcp_providers', schema=None) as batch_op:
batch_op.alter_column('server_identifier',
existing_type=sa.String(length=64),
type_=sa.VARCHAR(length=24),
existing_nullable=False)
# ### end Alembic commands ###

View File

@ -287,7 +287,7 @@ class Dataset(Base):
@staticmethod
def gen_collection_name_by_id(dataset_id: str) -> str:
normalized_dataset_id = dataset_id.replace("-", "_")
return f"Vector_index_{normalized_dataset_id}_Node"
return f"{dify_config.VECTOR_INDEX_NAME_PREFIX}_{normalized_dataset_id}_Node"
class DatasetProcessRule(Base):

View File

@ -1,7 +1,6 @@
from datetime import UTC, datetime
from celery import states # type: ignore
from libs.datetime_utils import naive_utc_now
from models.base import Base
from .engine import db
@ -18,8 +17,8 @@ class CeleryTask(Base):
result = db.Column(db.PickleType, nullable=True)
date_done = db.Column(
db.DateTime,
default=lambda: datetime.now(UTC).replace(tzinfo=None),
onupdate=lambda: datetime.now(UTC).replace(tzinfo=None),
default=lambda: naive_utc_now(),
onupdate=lambda: naive_utc_now(),
nullable=True,
)
traceback = db.Column(db.Text, nullable=True)
@ -39,4 +38,4 @@ class CeleryTaskSet(Base):
id = db.Column(db.Integer, db.Sequence("taskset_id_sequence"), autoincrement=True, primary_key=True)
taskset_id = db.Column(db.String(155), unique=True)
result = db.Column(db.PickleType, nullable=True)
date_done = db.Column(db.DateTime, default=lambda: datetime.now(UTC).replace(tzinfo=None), nullable=True)
date_done = db.Column(db.DateTime, default=lambda: naive_utc_now(), nullable=True)

View File

@ -253,7 +253,7 @@ class MCPToolProvider(Base):
# name of the mcp provider
name: Mapped[str] = mapped_column(db.String(40), nullable=False)
# server identifier of the mcp provider
server_identifier: Mapped[str] = mapped_column(db.String(24), nullable=False)
server_identifier: Mapped[str] = mapped_column(db.String(64), nullable=False)
# encrypted url of the mcp provider
server_url: Mapped[str] = mapped_column(db.Text, nullable=False)
# hash of server_url for uniqueness check
@ -357,7 +357,7 @@ class ToolModelInvoke(Base):
# type
tool_type = db.Column(db.String(40), nullable=False)
# tool name
tool_name = db.Column(db.String(40), nullable=False)
tool_name = db.Column(db.String(128), nullable=False)
# invoke parameters
model_parameters = db.Column(db.Text, nullable=False)
# prompt messages

View File

@ -1,7 +1,7 @@
import json
import logging
from collections.abc import Mapping, Sequence
from datetime import UTC, datetime
from datetime import datetime
from enum import Enum, StrEnum
from typing import TYPE_CHECKING, Any, Optional, Union
from uuid import uuid4
@ -16,6 +16,7 @@ from core.variables.variables import FloatVariable, IntegerVariable, StringVaria
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from core.workflow.nodes.enums import NodeType
from factories.variable_factory import TypeMismatchError, build_segment_with_type
from libs.datetime_utils import naive_utc_now
from libs.helper import extract_tenant_id
from ._workflow_exc import NodeNotFoundError, WorkflowDataError
@ -139,7 +140,7 @@ class Workflow(Base):
updated_at: Mapped[datetime] = mapped_column(
db.DateTime,
nullable=False,
default=datetime.now(UTC).replace(tzinfo=None),
default=naive_utc_now(),
server_onupdate=func.current_timestamp(),
)
_environment_variables: Mapped[str] = mapped_column(
@ -185,7 +186,7 @@ class Workflow(Base):
workflow.rag_pipeline_variables = rag_pipeline_variables or []
workflow.marked_name = marked_name
workflow.marked_comment = marked_comment
workflow.created_at = datetime.now(UTC).replace(tzinfo=None)
workflow.created_at = naive_utc_now()
workflow.updated_at = workflow.created_at
return workflow
@ -938,7 +939,7 @@ _EDITABLE_SYSTEM_VARIABLE = frozenset(["query", "files"])
def _naive_utc_datetime():
return datetime.now(UTC).replace(tzinfo=None)
return naive_utc_now()
class WorkflowDraftVariable(Base):

View File

@ -17,6 +17,7 @@ from constants.languages import language_timezone_mapping, languages
from events.tenant_event import tenant_was_created
from extensions.ext_database import db
from extensions.ext_redis import redis_client, redis_fallback
from libs.datetime_utils import naive_utc_now
from libs.helper import RateLimiter, TokenManager
from libs.passport import PassportService
from libs.password import compare_password, hash_password, valid_password
@ -135,8 +136,8 @@ class AccountService:
available_ta.current = True
db.session.commit()
if datetime.now(UTC).replace(tzinfo=None) - account.last_active_at > timedelta(minutes=10):
account.last_active_at = datetime.now(UTC).replace(tzinfo=None)
if naive_utc_now() - account.last_active_at > timedelta(minutes=10):
account.last_active_at = naive_utc_now()
db.session.commit()
return cast(Account, account)
@ -180,7 +181,7 @@ class AccountService:
if account.status == AccountStatus.PENDING.value:
account.status = AccountStatus.ACTIVE.value
account.initialized_at = datetime.now(UTC).replace(tzinfo=None)
account.initialized_at = naive_utc_now()
db.session.commit()
@ -318,7 +319,7 @@ class AccountService:
# If it exists, update the record
account_integrate.open_id = open_id
account_integrate.encrypted_token = "" # todo
account_integrate.updated_at = datetime.now(UTC).replace(tzinfo=None)
account_integrate.updated_at = naive_utc_now()
else:
# If it does not exist, create a new record
account_integrate = AccountIntegrate(
@ -353,7 +354,7 @@ class AccountService:
@staticmethod
def update_login_info(account: Account, *, ip_address: str) -> None:
"""Update last login time and ip"""
account.last_login_at = datetime.now(UTC).replace(tzinfo=None)
account.last_login_at = naive_utc_now()
account.last_login_ip = ip_address
db.session.add(account)
db.session.commit()
@ -1066,15 +1067,6 @@ class TenantService:
target_member_join.role = new_role
db.session.commit()
@staticmethod
def dissolve_tenant(tenant: Tenant, operator: Account) -> None:
"""Dissolve tenant"""
if not TenantService.check_member_permission(tenant, operator, operator, "remove"):
raise NoPermissionError("No permission to dissolve tenant.")
db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id).delete()
db.session.delete(tenant)
db.session.commit()
@staticmethod
def get_custom_config(tenant_id: str) -> dict:
tenant = db.get_or_404(Tenant, tenant_id)
@ -1117,7 +1109,7 @@ class RegisterService:
)
account.last_login_ip = ip_address
account.initialized_at = datetime.now(UTC).replace(tzinfo=None)
account.initialized_at = naive_utc_now()
TenantService.create_owner_tenant_if_not_exist(account=account, is_setup=True)
@ -1158,7 +1150,7 @@ class RegisterService:
is_setup=is_setup,
)
account.status = AccountStatus.ACTIVE.value if not status else status.value
account.initialized_at = datetime.now(UTC).replace(tzinfo=None)
account.initialized_at = naive_utc_now()
if open_id is not None and provider is not None:
AccountService.link_account_integrate(provider, open_id, account)

View File

@ -1,6 +1,5 @@
import json
import logging
from datetime import UTC, datetime
from typing import Optional, cast
from flask_login import current_user
@ -17,6 +16,7 @@ from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ToolParameterConfigurationManager
from events.app_event import app_was_created
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.account import Account
from models.model import App, AppMode, AppModelConfig, Site
from models.tools import ApiToolProvider
@ -235,7 +235,7 @@ class AppService:
app.use_icon_as_answer_icon = args.get("use_icon_as_answer_icon", False)
app.max_active_requests = args.get("max_active_requests")
app.updated_by = current_user.id
app.updated_at = datetime.now(UTC).replace(tzinfo=None)
app.updated_at = naive_utc_now()
db.session.commit()
return app
@ -249,7 +249,7 @@ class AppService:
"""
app.name = name
app.updated_by = current_user.id
app.updated_at = datetime.now(UTC).replace(tzinfo=None)
app.updated_at = naive_utc_now()
db.session.commit()
return app
@ -265,7 +265,7 @@ class AppService:
app.icon = icon
app.icon_background = icon_background
app.updated_by = current_user.id
app.updated_at = datetime.now(UTC).replace(tzinfo=None)
app.updated_at = naive_utc_now()
db.session.commit()
return app
@ -282,7 +282,7 @@ class AppService:
app.enable_site = enable_site
app.updated_by = current_user.id
app.updated_at = datetime.now(UTC).replace(tzinfo=None)
app.updated_at = naive_utc_now()
db.session.commit()
return app
@ -299,7 +299,7 @@ class AppService:
app.enable_api = enable_api
app.updated_by = current_user.id
app.updated_at = datetime.now(UTC).replace(tzinfo=None)
app.updated_at = naive_utc_now()
db.session.commit()
return app

View File

@ -1,5 +1,4 @@
from collections.abc import Callable, Sequence
from datetime import UTC, datetime
from typing import Optional, Union
from sqlalchemy import asc, desc, func, or_, select
@ -8,6 +7,7 @@ from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import InvokeFrom
from core.llm_generator.llm_generator import LLMGenerator
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models import ConversationVariable
from models.account import Account
@ -113,7 +113,7 @@ class ConversationService:
return cls.auto_generate_name(app_model, conversation)
else:
conversation.name = name
conversation.updated_at = datetime.now(UTC).replace(tzinfo=None)
conversation.updated_at = naive_utc_now()
db.session.commit()
return conversation
@ -169,7 +169,7 @@ class ConversationService:
conversation = cls.get_conversation(app_model, conversation_id, user)
conversation.is_deleted = True
conversation.updated_at = datetime.now(UTC).replace(tzinfo=None)
conversation.updated_at = naive_utc_now()
db.session.commit()
@classmethod

View File

@ -26,6 +26,7 @@ from events.document_event import document_was_deleted
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs import helper
from libs.datetime_utils import naive_utc_now
from models.account import Account, TenantAccountRole
from models.dataset import (
AppDatasetJoin,
@ -484,7 +485,7 @@ class DatasetService:
# Add metadata fields
filtered_data["updated_by"] = user.id
filtered_data["updated_at"] = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
filtered_data["updated_at"] = naive_utc_now()
# update Retrieval model
filtered_data["retrieval_model"] = data["retrieval_model"]
# update icon info
@ -1175,7 +1176,7 @@ class DocumentService:
# update document to be paused
document.is_paused = True
document.paused_by = current_user.id
document.paused_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
document.paused_at = naive_utc_now()
db.session.add(document)
db.session.commit()

View File

@ -1,6 +1,5 @@
import json
from copy import deepcopy
from datetime import UTC, datetime
from typing import Any, Optional, Union, cast
from urllib.parse import urlparse
@ -11,6 +10,7 @@ from constants import HIDDEN_VALUE
from core.helper import ssrf_proxy
from core.rag.entities.metadata_entities import MetadataCondition
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.dataset import (
Dataset,
ExternalKnowledgeApis,
@ -120,7 +120,7 @@ class ExternalDatasetService:
external_knowledge_api.description = args.get("description", "")
external_knowledge_api.settings = json.dumps(args.get("settings"), ensure_ascii=False)
external_knowledge_api.updated_by = user_id
external_knowledge_api.updated_at = datetime.now(UTC).replace(tzinfo=None)
external_knowledge_api.updated_at = naive_utc_now()
db.session.commit()
return external_knowledge_api

View File

@ -70,16 +70,15 @@ class MCPToolManageService:
MCPToolProvider.server_url_hash == server_url_hash,
MCPToolProvider.server_identifier == server_identifier,
),
MCPToolProvider.tenant_id == tenant_id,
)
.first()
)
if existing_provider:
if existing_provider.name == name:
raise ValueError(f"MCP tool {name} already exists")
elif existing_provider.server_url_hash == server_url_hash:
if existing_provider.server_url_hash == server_url_hash:
raise ValueError(f"MCP tool {server_url} already exists")
elif existing_provider.server_identifier == server_identifier:
if existing_provider.server_identifier == server_identifier:
raise ValueError(f"MCP tool {server_identifier} already exists")
encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url)
mcp_tool = MCPToolProvider(
@ -111,15 +110,14 @@ class MCPToolManageService:
]
@classmethod
def list_mcp_tool_from_remote_server(cls, tenant_id: str, provider_id: str):
def list_mcp_tool_from_remote_server(cls, tenant_id: str, provider_id: str) -> ToolProviderApiEntity:
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
try:
with MCPClient(
mcp_provider.decrypted_server_url, provider_id, tenant_id, authed=mcp_provider.authed, for_list=True
) as mcp_client:
tools = mcp_client.list_tools()
except MCPAuthError as e:
except MCPAuthError:
raise ValueError("Please auth the tool first")
except MCPError as e:
raise ValueError(f"Failed to connect to MCP server: {e}")
@ -184,12 +182,11 @@ class MCPToolManageService:
error_msg = str(e.orig)
if "unique_mcp_provider_name" in error_msg:
raise ValueError(f"MCP tool {name} already exists")
elif "unique_mcp_provider_server_url" in error_msg:
if "unique_mcp_provider_server_url" in error_msg:
raise ValueError(f"MCP tool {server_url} already exists")
elif "unique_mcp_provider_server_identifier" in error_msg:
if "unique_mcp_provider_server_identifier" in error_msg:
raise ValueError(f"MCP tool {server_identifier} already exists")
else:
raise
raise
@classmethod
def update_mcp_provider_credentials(

View File

@ -2,7 +2,6 @@ import json
import time
import uuid
from collections.abc import Callable, Generator, Mapping, Sequence
from datetime import UTC, datetime
from typing import Any, Optional, cast
from uuid import uuid4
@ -33,6 +32,7 @@ from core.workflow.workflow_entry import WorkflowEntry
from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated
from extensions.ext_database import db
from factories.file_factory import build_from_mapping, build_from_mappings
from libs.datetime_utils import naive_utc_now
from models.account import Account
from models.model import App, AppMode
from models.tools import WorkflowToolProvider
@ -232,7 +232,7 @@ class WorkflowService:
workflow.graph = json.dumps(graph)
workflow.features = json.dumps(features)
workflow.updated_by = account.id
workflow.updated_at = datetime.now(UTC).replace(tzinfo=None)
workflow.updated_at = naive_utc_now()
workflow.environment_variables = environment_variables
workflow.conversation_variables = conversation_variables
@ -268,7 +268,7 @@ class WorkflowService:
tenant_id=app_model.tenant_id,
app_id=app_model.id,
type=draft_workflow.type,
version=Workflow.version_from_datetime(datetime.now(UTC).replace(tzinfo=None)),
version=Workflow.version_from_datetime(naive_utc_now()),
graph=draft_workflow.graph,
created_by=account.id,
environment_variables=draft_workflow.environment_variables,
@ -524,8 +524,8 @@ class WorkflowService:
node_type=node.type_,
title=node.title,
elapsed_time=time.perf_counter() - start_at,
created_at=datetime.now(UTC).replace(tzinfo=None),
finished_at=datetime.now(UTC).replace(tzinfo=None),
created_at=naive_utc_now(),
finished_at=naive_utc_now(),
)
if run_succeeded and node_run_result:
@ -622,7 +622,7 @@ class WorkflowService:
setattr(workflow, field, value)
workflow.updated_by = account_id
workflow.updated_at = datetime.now(UTC).replace(tzinfo=None)
workflow.updated_at = naive_utc_now()
return workflow

View File

@ -1,4 +1,3 @@
import datetime
import logging
import time
@ -8,6 +7,7 @@ from celery import shared_task # type: ignore
from configs import dify_config
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, Document
from services.feature_service import FeatureService
@ -53,7 +53,7 @@ def document_indexing_task(dataset_id: str, document_ids: list):
if document:
document.indexing_status = "error"
document.error = str(e)
document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
document.stopped_at = naive_utc_now()
db.session.add(document)
db.session.commit()
db.session.close()
@ -68,7 +68,7 @@ def document_indexing_task(dataset_id: str, document_ids: list):
if document:
document.indexing_status = "parsing"
document.processing_started_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
document.processing_started_at = naive_utc_now()
documents.append(document)
db.session.add(document)
db.session.commit()

View File

@ -26,8 +26,15 @@ redis_mock.hgetall = MagicMock(return_value={})
redis_mock.hdel = MagicMock()
redis_mock.incr = MagicMock(return_value=1)
# Add the API directory to Python path to ensure proper imports
import sys
sys.path.insert(0, PROJECT_DIR)
# apply the mock to the Redis client in the Flask app
redis_patcher = patch("extensions.ext_redis.redis_client", redis_mock)
from extensions import ext_redis
redis_patcher = patch.object(ext_redis, "redis_client", redis_mock)
redis_patcher.start()

View File

@ -0,0 +1 @@
# API authentication service test module

View File

@ -0,0 +1,49 @@
import pytest
from services.auth.api_key_auth_base import ApiKeyAuthBase
class ConcreteApiKeyAuth(ApiKeyAuthBase):
"""Concrete implementation for testing abstract base class"""
def validate_credentials(self):
return True
class TestApiKeyAuthBase:
def test_should_store_credentials_on_init(self):
"""Test that credentials are properly stored during initialization"""
credentials = {"api_key": "test_key", "auth_type": "bearer"}
auth = ConcreteApiKeyAuth(credentials)
assert auth.credentials == credentials
def test_should_not_instantiate_abstract_class(self):
"""Test that ApiKeyAuthBase cannot be instantiated directly"""
credentials = {"api_key": "test_key"}
with pytest.raises(TypeError) as exc_info:
ApiKeyAuthBase(credentials)
assert "Can't instantiate abstract class" in str(exc_info.value)
assert "validate_credentials" in str(exc_info.value)
def test_should_allow_subclass_implementation(self):
"""Test that subclasses can properly implement the abstract method"""
credentials = {"api_key": "test_key", "auth_type": "bearer"}
auth = ConcreteApiKeyAuth(credentials)
# Should not raise any exception
result = auth.validate_credentials()
assert result is True
def test_should_handle_empty_credentials(self):
"""Test initialization with empty credentials"""
credentials = {}
auth = ConcreteApiKeyAuth(credentials)
assert auth.credentials == {}
def test_should_handle_none_credentials(self):
"""Test initialization with None credentials"""
credentials = None
auth = ConcreteApiKeyAuth(credentials)
assert auth.credentials is None

View File

@ -0,0 +1,81 @@
from unittest.mock import MagicMock, patch
import pytest
from services.auth.api_key_auth_factory import ApiKeyAuthFactory
from services.auth.auth_type import AuthType
class TestApiKeyAuthFactory:
"""Test cases for ApiKeyAuthFactory"""
@pytest.mark.parametrize(
("provider", "auth_class_path"),
[
(AuthType.FIRECRAWL, "services.auth.firecrawl.firecrawl.FirecrawlAuth"),
(AuthType.WATERCRAWL, "services.auth.watercrawl.watercrawl.WatercrawlAuth"),
(AuthType.JINA, "services.auth.jina.jina.JinaAuth"),
],
)
def test_get_apikey_auth_factory_valid_providers(self, provider, auth_class_path):
"""Test getting auth factory for all valid providers"""
with patch(auth_class_path) as mock_auth:
auth_class = ApiKeyAuthFactory.get_apikey_auth_factory(provider)
assert auth_class == mock_auth
@pytest.mark.parametrize(
"invalid_provider",
[
"invalid_provider",
"",
None,
123,
"UNSUPPORTED",
],
)
def test_get_apikey_auth_factory_invalid_providers(self, invalid_provider):
"""Test getting auth factory with various invalid providers"""
with pytest.raises(ValueError) as exc_info:
ApiKeyAuthFactory.get_apikey_auth_factory(invalid_provider)
assert str(exc_info.value) == "Invalid provider"
@pytest.mark.parametrize(
("credentials_return_value", "expected_result"),
[
(True, True),
(False, False),
],
)
@patch("services.auth.api_key_auth_factory.ApiKeyAuthFactory.get_apikey_auth_factory")
def test_validate_credentials_delegates_to_auth_instance(
self, mock_get_factory, credentials_return_value, expected_result
):
"""Test that validate_credentials delegates to auth instance correctly"""
# Arrange
mock_auth_instance = MagicMock()
mock_auth_instance.validate_credentials.return_value = credentials_return_value
mock_auth_class = MagicMock(return_value=mock_auth_instance)
mock_get_factory.return_value = mock_auth_class
# Act
factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, {"api_key": "test_key"})
result = factory.validate_credentials()
# Assert
assert result is expected_result
mock_auth_instance.validate_credentials.assert_called_once()
@patch("services.auth.api_key_auth_factory.ApiKeyAuthFactory.get_apikey_auth_factory")
def test_validate_credentials_propagates_exceptions(self, mock_get_factory):
"""Test that exceptions from auth instance are propagated"""
# Arrange
mock_auth_instance = MagicMock()
mock_auth_instance.validate_credentials.side_effect = Exception("Authentication error")
mock_auth_class = MagicMock(return_value=mock_auth_instance)
mock_get_factory.return_value = mock_auth_class
# Act & Assert
factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, {"api_key": "test_key"})
with pytest.raises(Exception) as exc_info:
factory.validate_credentials()
assert str(exc_info.value) == "Authentication error"

View File

@ -0,0 +1,382 @@
import json
from unittest.mock import Mock, patch
import pytest
from models.source import DataSourceApiKeyAuthBinding
from services.auth.api_key_auth_service import ApiKeyAuthService
class TestApiKeyAuthService:
"""API key authentication service security tests"""
def setup_method(self):
"""Setup test fixtures"""
self.tenant_id = "test_tenant_123"
self.category = "search"
self.provider = "google"
self.binding_id = "binding_123"
self.mock_credentials = {"auth_type": "api_key", "config": {"api_key": "test_secret_key_123"}}
self.mock_args = {"category": self.category, "provider": self.provider, "credentials": self.mock_credentials}
@patch("services.auth.api_key_auth_service.db.session")
def test_get_provider_auth_list_success(self, mock_session):
"""Test get provider auth list - success scenario"""
# Mock database query result
mock_binding = Mock()
mock_binding.tenant_id = self.tenant_id
mock_binding.provider = self.provider
mock_binding.disabled = False
mock_session.query.return_value.filter.return_value.all.return_value = [mock_binding]
result = ApiKeyAuthService.get_provider_auth_list(self.tenant_id)
assert len(result) == 1
assert result[0].tenant_id == self.tenant_id
mock_session.query.assert_called_once_with(DataSourceApiKeyAuthBinding)
@patch("services.auth.api_key_auth_service.db.session")
def test_get_provider_auth_list_empty(self, mock_session):
"""Test get provider auth list - empty result"""
mock_session.query.return_value.filter.return_value.all.return_value = []
result = ApiKeyAuthService.get_provider_auth_list(self.tenant_id)
assert result == []
@patch("services.auth.api_key_auth_service.db.session")
def test_get_provider_auth_list_filters_disabled(self, mock_session):
"""Test get provider auth list - filters disabled items"""
mock_session.query.return_value.filter.return_value.all.return_value = []
ApiKeyAuthService.get_provider_auth_list(self.tenant_id)
# Verify filter conditions include disabled.is_(False)
filter_call = mock_session.query.return_value.filter.call_args[0]
assert len(filter_call) == 2 # tenant_id and disabled filter conditions
@patch("services.auth.api_key_auth_service.db.session")
@patch("services.auth.api_key_auth_service.ApiKeyAuthFactory")
@patch("services.auth.api_key_auth_service.encrypter")
def test_create_provider_auth_success(self, mock_encrypter, mock_factory, mock_session):
"""Test create provider auth - success scenario"""
# Mock successful auth validation
mock_auth_instance = Mock()
mock_auth_instance.validate_credentials.return_value = True
mock_factory.return_value = mock_auth_instance
# Mock encryption
encrypted_key = "encrypted_test_key_123"
mock_encrypter.encrypt_token.return_value = encrypted_key
# Mock database operations
mock_session.add = Mock()
mock_session.commit = Mock()
ApiKeyAuthService.create_provider_auth(self.tenant_id, self.mock_args)
# Verify factory class calls
mock_factory.assert_called_once_with(self.provider, self.mock_credentials)
mock_auth_instance.validate_credentials.assert_called_once()
# Verify encryption calls
mock_encrypter.encrypt_token.assert_called_once_with(self.tenant_id, "test_secret_key_123")
# Verify database operations
mock_session.add.assert_called_once()
mock_session.commit.assert_called_once()
@patch("services.auth.api_key_auth_service.db.session")
@patch("services.auth.api_key_auth_service.ApiKeyAuthFactory")
def test_create_provider_auth_validation_failed(self, mock_factory, mock_session):
"""Test create provider auth - validation failed"""
# Mock failed auth validation
mock_auth_instance = Mock()
mock_auth_instance.validate_credentials.return_value = False
mock_factory.return_value = mock_auth_instance
ApiKeyAuthService.create_provider_auth(self.tenant_id, self.mock_args)
# Verify no database operations when validation fails
mock_session.add.assert_not_called()
mock_session.commit.assert_not_called()
@patch("services.auth.api_key_auth_service.db.session")
@patch("services.auth.api_key_auth_service.ApiKeyAuthFactory")
@patch("services.auth.api_key_auth_service.encrypter")
def test_create_provider_auth_encrypts_api_key(self, mock_encrypter, mock_factory, mock_session):
"""Test create provider auth - ensures API key is encrypted"""
# Mock successful auth validation
mock_auth_instance = Mock()
mock_auth_instance.validate_credentials.return_value = True
mock_factory.return_value = mock_auth_instance
# Mock encryption
encrypted_key = "encrypted_test_key_123"
mock_encrypter.encrypt_token.return_value = encrypted_key
# Mock database operations
mock_session.add = Mock()
mock_session.commit = Mock()
args_copy = self.mock_args.copy()
original_key = args_copy["credentials"]["config"]["api_key"] # type: ignore
ApiKeyAuthService.create_provider_auth(self.tenant_id, args_copy)
# Verify original key is replaced with encrypted key
assert args_copy["credentials"]["config"]["api_key"] == encrypted_key # type: ignore
assert args_copy["credentials"]["config"]["api_key"] != original_key # type: ignore
# Verify encryption function is called correctly
mock_encrypter.encrypt_token.assert_called_once_with(self.tenant_id, original_key)
@patch("services.auth.api_key_auth_service.db.session")
def test_get_auth_credentials_success(self, mock_session):
"""Test get auth credentials - success scenario"""
# Mock database query result
mock_binding = Mock()
mock_binding.credentials = json.dumps(self.mock_credentials)
mock_session.query.return_value.filter.return_value.first.return_value = mock_binding
result = ApiKeyAuthService.get_auth_credentials(self.tenant_id, self.category, self.provider)
assert result == self.mock_credentials
mock_session.query.assert_called_once_with(DataSourceApiKeyAuthBinding)
@patch("services.auth.api_key_auth_service.db.session")
def test_get_auth_credentials_not_found(self, mock_session):
"""Test get auth credentials - not found"""
mock_session.query.return_value.filter.return_value.first.return_value = None
result = ApiKeyAuthService.get_auth_credentials(self.tenant_id, self.category, self.provider)
assert result is None
@patch("services.auth.api_key_auth_service.db.session")
def test_get_auth_credentials_filters_correctly(self, mock_session):
"""Test get auth credentials - applies correct filters"""
mock_session.query.return_value.filter.return_value.first.return_value = None
ApiKeyAuthService.get_auth_credentials(self.tenant_id, self.category, self.provider)
# Verify filter conditions are correct
filter_call = mock_session.query.return_value.filter.call_args[0]
assert len(filter_call) == 4 # tenant_id, category, provider, disabled
@patch("services.auth.api_key_auth_service.db.session")
def test_get_auth_credentials_json_parsing(self, mock_session):
"""Test get auth credentials - JSON parsing"""
# Mock credentials with special characters
special_credentials = {"auth_type": "api_key", "config": {"api_key": "key_with_中文_and_special_chars_!@#$%"}}
mock_binding = Mock()
mock_binding.credentials = json.dumps(special_credentials, ensure_ascii=False)
mock_session.query.return_value.filter.return_value.first.return_value = mock_binding
result = ApiKeyAuthService.get_auth_credentials(self.tenant_id, self.category, self.provider)
assert result == special_credentials
assert result["config"]["api_key"] == "key_with_中文_and_special_chars_!@#$%"
@patch("services.auth.api_key_auth_service.db.session")
def test_delete_provider_auth_success(self, mock_session):
"""Test delete provider auth - success scenario"""
# Mock database query result
mock_binding = Mock()
mock_session.query.return_value.filter.return_value.first.return_value = mock_binding
ApiKeyAuthService.delete_provider_auth(self.tenant_id, self.binding_id)
# Verify delete operations
mock_session.delete.assert_called_once_with(mock_binding)
mock_session.commit.assert_called_once()
@patch("services.auth.api_key_auth_service.db.session")
def test_delete_provider_auth_not_found(self, mock_session):
"""Test delete provider auth - not found"""
mock_session.query.return_value.filter.return_value.first.return_value = None
ApiKeyAuthService.delete_provider_auth(self.tenant_id, self.binding_id)
# Verify no delete operations when not found
mock_session.delete.assert_not_called()
mock_session.commit.assert_not_called()
@patch("services.auth.api_key_auth_service.db.session")
def test_delete_provider_auth_filters_by_tenant(self, mock_session):
"""Test delete provider auth - filters by tenant"""
mock_session.query.return_value.filter.return_value.first.return_value = None
ApiKeyAuthService.delete_provider_auth(self.tenant_id, self.binding_id)
# Verify filter conditions include tenant_id and binding_id
filter_call = mock_session.query.return_value.filter.call_args[0]
assert len(filter_call) == 2
def test_validate_api_key_auth_args_success(self):
"""Test API key auth args validation - success scenario"""
# Should not raise any exception
ApiKeyAuthService.validate_api_key_auth_args(self.mock_args)
def test_validate_api_key_auth_args_missing_category(self):
"""Test API key auth args validation - missing category"""
args = self.mock_args.copy()
del args["category"]
with pytest.raises(ValueError, match="category is required"):
ApiKeyAuthService.validate_api_key_auth_args(args)
def test_validate_api_key_auth_args_empty_category(self):
"""Test API key auth args validation - empty category"""
args = self.mock_args.copy()
args["category"] = ""
with pytest.raises(ValueError, match="category is required"):
ApiKeyAuthService.validate_api_key_auth_args(args)
def test_validate_api_key_auth_args_missing_provider(self):
"""Test API key auth args validation - missing provider"""
args = self.mock_args.copy()
del args["provider"]
with pytest.raises(ValueError, match="provider is required"):
ApiKeyAuthService.validate_api_key_auth_args(args)
def test_validate_api_key_auth_args_empty_provider(self):
"""Test API key auth args validation - empty provider"""
args = self.mock_args.copy()
args["provider"] = ""
with pytest.raises(ValueError, match="provider is required"):
ApiKeyAuthService.validate_api_key_auth_args(args)
def test_validate_api_key_auth_args_missing_credentials(self):
"""Test API key auth args validation - missing credentials"""
args = self.mock_args.copy()
del args["credentials"]
with pytest.raises(ValueError, match="credentials is required"):
ApiKeyAuthService.validate_api_key_auth_args(args)
def test_validate_api_key_auth_args_empty_credentials(self):
"""Test API key auth args validation - empty credentials"""
args = self.mock_args.copy()
args["credentials"] = None # type: ignore
with pytest.raises(ValueError, match="credentials is required"):
ApiKeyAuthService.validate_api_key_auth_args(args)
def test_validate_api_key_auth_args_invalid_credentials_type(self):
"""Test API key auth args validation - invalid credentials type"""
args = self.mock_args.copy()
args["credentials"] = "not_a_dict"
with pytest.raises(ValueError, match="credentials must be a dictionary"):
ApiKeyAuthService.validate_api_key_auth_args(args)
def test_validate_api_key_auth_args_missing_auth_type(self):
"""Test API key auth args validation - missing auth_type"""
args = self.mock_args.copy()
del args["credentials"]["auth_type"] # type: ignore
with pytest.raises(ValueError, match="auth_type is required"):
ApiKeyAuthService.validate_api_key_auth_args(args)
def test_validate_api_key_auth_args_empty_auth_type(self):
"""Test API key auth args validation - empty auth_type"""
args = self.mock_args.copy()
args["credentials"]["auth_type"] = "" # type: ignore
with pytest.raises(ValueError, match="auth_type is required"):
ApiKeyAuthService.validate_api_key_auth_args(args)
@pytest.mark.parametrize(
"malicious_input",
[
"<script>alert('xss')</script>",
"'; DROP TABLE users; --",
"../../../etc/passwd",
"\\x00\\x00", # null bytes
"A" * 10000, # very long input
],
)
def test_validate_api_key_auth_args_malicious_input(self, malicious_input):
"""Test API key auth args validation - malicious input"""
args = self.mock_args.copy()
args["category"] = malicious_input
# Verify parameter validator doesn't crash on malicious input
# Should validate normally rather than raising security-related exceptions
ApiKeyAuthService.validate_api_key_auth_args(args)
@patch("services.auth.api_key_auth_service.db.session")
@patch("services.auth.api_key_auth_service.ApiKeyAuthFactory")
@patch("services.auth.api_key_auth_service.encrypter")
def test_create_provider_auth_database_error_handling(self, mock_encrypter, mock_factory, mock_session):
"""Test create provider auth - database error handling"""
# Mock successful auth validation
mock_auth_instance = Mock()
mock_auth_instance.validate_credentials.return_value = True
mock_factory.return_value = mock_auth_instance
# Mock encryption
mock_encrypter.encrypt_token.return_value = "encrypted_key"
# Mock database error
mock_session.commit.side_effect = Exception("Database error")
with pytest.raises(Exception, match="Database error"):
ApiKeyAuthService.create_provider_auth(self.tenant_id, self.mock_args)
@patch("services.auth.api_key_auth_service.db.session")
def test_get_auth_credentials_invalid_json(self, mock_session):
"""Test get auth credentials - invalid JSON"""
# Mock database returning invalid JSON
mock_binding = Mock()
mock_binding.credentials = "invalid json content"
mock_session.query.return_value.filter.return_value.first.return_value = mock_binding
with pytest.raises(json.JSONDecodeError):
ApiKeyAuthService.get_auth_credentials(self.tenant_id, self.category, self.provider)
@patch("services.auth.api_key_auth_service.db.session")
@patch("services.auth.api_key_auth_service.ApiKeyAuthFactory")
def test_create_provider_auth_factory_exception(self, mock_factory, mock_session):
"""Test create provider auth - factory exception"""
# Mock factory raising exception
mock_factory.side_effect = Exception("Factory error")
with pytest.raises(Exception, match="Factory error"):
ApiKeyAuthService.create_provider_auth(self.tenant_id, self.mock_args)
@patch("services.auth.api_key_auth_service.db.session")
@patch("services.auth.api_key_auth_service.ApiKeyAuthFactory")
@patch("services.auth.api_key_auth_service.encrypter")
def test_create_provider_auth_encryption_exception(self, mock_encrypter, mock_factory, mock_session):
"""Test create provider auth - encryption exception"""
# Mock successful auth validation
mock_auth_instance = Mock()
mock_auth_instance.validate_credentials.return_value = True
mock_factory.return_value = mock_auth_instance
# Mock encryption exception
mock_encrypter.encrypt_token.side_effect = Exception("Encryption error")
with pytest.raises(Exception, match="Encryption error"):
ApiKeyAuthService.create_provider_auth(self.tenant_id, self.mock_args)
def test_validate_api_key_auth_args_none_input(self):
"""Test API key auth args validation - None input"""
with pytest.raises(TypeError):
ApiKeyAuthService.validate_api_key_auth_args(None)
def test_validate_api_key_auth_args_dict_credentials_with_list_auth_type(self):
"""Test API key auth args validation - dict credentials with list auth_type"""
args = self.mock_args.copy()
args["credentials"]["auth_type"] = ["api_key"] # type: ignore # list instead of string
# Current implementation checks if auth_type exists and is truthy, list ["api_key"] is truthy
# So this should not raise exception, this test should pass
ApiKeyAuthService.validate_api_key_auth_args(args)

View File

@ -0,0 +1,191 @@
from unittest.mock import MagicMock, patch
import pytest
import requests
from services.auth.firecrawl.firecrawl import FirecrawlAuth
class TestFirecrawlAuth:
@pytest.fixture
def valid_credentials(self):
"""Fixture for valid bearer credentials"""
return {"auth_type": "bearer", "config": {"api_key": "test_api_key_123"}}
@pytest.fixture
def auth_instance(self, valid_credentials):
"""Fixture for FirecrawlAuth instance with valid credentials"""
return FirecrawlAuth(valid_credentials)
def test_should_initialize_with_valid_bearer_credentials(self, valid_credentials):
"""Test successful initialization with valid bearer credentials"""
auth = FirecrawlAuth(valid_credentials)
assert auth.api_key == "test_api_key_123"
assert auth.base_url == "https://api.firecrawl.dev"
assert auth.credentials == valid_credentials
def test_should_initialize_with_custom_base_url(self):
"""Test initialization with custom base URL"""
credentials = {
"auth_type": "bearer",
"config": {"api_key": "test_api_key_123", "base_url": "https://custom.firecrawl.dev"},
}
auth = FirecrawlAuth(credentials)
assert auth.api_key == "test_api_key_123"
assert auth.base_url == "https://custom.firecrawl.dev"
@pytest.mark.parametrize(
("auth_type", "expected_error"),
[
("basic", "Invalid auth type, Firecrawl auth type must be Bearer"),
("x-api-key", "Invalid auth type, Firecrawl auth type must be Bearer"),
("", "Invalid auth type, Firecrawl auth type must be Bearer"),
],
)
def test_should_raise_error_for_invalid_auth_type(self, auth_type, expected_error):
"""Test that non-bearer auth types raise ValueError"""
credentials = {"auth_type": auth_type, "config": {"api_key": "test_api_key_123"}}
with pytest.raises(ValueError) as exc_info:
FirecrawlAuth(credentials)
assert str(exc_info.value) == expected_error
@pytest.mark.parametrize(
("credentials", "expected_error"),
[
({"auth_type": "bearer", "config": {}}, "No API key provided"),
({"auth_type": "bearer"}, "No API key provided"),
({"auth_type": "bearer", "config": {"api_key": ""}}, "No API key provided"),
({"auth_type": "bearer", "config": {"api_key": None}}, "No API key provided"),
],
)
def test_should_raise_error_for_missing_api_key(self, credentials, expected_error):
"""Test that missing or empty API key raises ValueError"""
with pytest.raises(ValueError) as exc_info:
FirecrawlAuth(credentials)
assert str(exc_info.value) == expected_error
@patch("services.auth.firecrawl.firecrawl.requests.post")
def test_should_validate_valid_credentials_successfully(self, mock_post, auth_instance):
"""Test successful credential validation"""
mock_response = MagicMock()
mock_response.status_code = 200
mock_post.return_value = mock_response
result = auth_instance.validate_credentials()
assert result is True
expected_data = {
"url": "https://example.com",
"includePaths": [],
"excludePaths": [],
"limit": 1,
"scrapeOptions": {"onlyMainContent": True},
}
mock_post.assert_called_once_with(
"https://api.firecrawl.dev/v1/crawl",
headers={"Content-Type": "application/json", "Authorization": "Bearer test_api_key_123"},
json=expected_data,
)
@pytest.mark.parametrize(
("status_code", "error_message"),
[
(402, "Payment required"),
(409, "Conflict error"),
(500, "Internal server error"),
],
)
@patch("services.auth.firecrawl.firecrawl.requests.post")
def test_should_handle_http_errors(self, mock_post, status_code, error_message, auth_instance):
"""Test handling of various HTTP error codes"""
mock_response = MagicMock()
mock_response.status_code = status_code
mock_response.json.return_value = {"error": error_message}
mock_post.return_value = mock_response
with pytest.raises(Exception) as exc_info:
auth_instance.validate_credentials()
assert str(exc_info.value) == f"Failed to authorize. Status code: {status_code}. Error: {error_message}"
@pytest.mark.parametrize(
("status_code", "response_text", "has_json_error", "expected_error_contains"),
[
(403, '{"error": "Forbidden"}', True, "Failed to authorize. Status code: 403. Error: Forbidden"),
(404, "", True, "Unexpected error occurred while trying to authorize. Status code: 404"),
(401, "Not JSON", True, "Expecting value"), # JSON decode error
],
)
@patch("services.auth.firecrawl.firecrawl.requests.post")
def test_should_handle_unexpected_errors(
self, mock_post, status_code, response_text, has_json_error, expected_error_contains, auth_instance
):
"""Test handling of unexpected errors with various response formats"""
mock_response = MagicMock()
mock_response.status_code = status_code
mock_response.text = response_text
if has_json_error:
mock_response.json.side_effect = Exception("Not JSON")
mock_post.return_value = mock_response
with pytest.raises(Exception) as exc_info:
auth_instance.validate_credentials()
assert expected_error_contains in str(exc_info.value)
@pytest.mark.parametrize(
("exception_type", "exception_message"),
[
(requests.ConnectionError, "Network error"),
(requests.Timeout, "Request timeout"),
(requests.ReadTimeout, "Read timeout"),
(requests.ConnectTimeout, "Connection timeout"),
],
)
@patch("services.auth.firecrawl.firecrawl.requests.post")
def test_should_handle_network_errors(self, mock_post, exception_type, exception_message, auth_instance):
"""Test handling of various network-related errors including timeouts"""
mock_post.side_effect = exception_type(exception_message)
with pytest.raises(exception_type) as exc_info:
auth_instance.validate_credentials()
assert exception_message in str(exc_info.value)
def test_should_not_expose_api_key_in_error_messages(self):
"""Test that API key is not exposed in error messages"""
credentials = {"auth_type": "bearer", "config": {"api_key": "super_secret_key_12345"}}
auth = FirecrawlAuth(credentials)
# Verify API key is stored but not in any error message
assert auth.api_key == "super_secret_key_12345"
# Test various error scenarios don't expose the key
with pytest.raises(ValueError) as exc_info:
FirecrawlAuth({"auth_type": "basic", "config": {"api_key": "super_secret_key_12345"}})
assert "super_secret_key_12345" not in str(exc_info.value)
@patch("services.auth.firecrawl.firecrawl.requests.post")
def test_should_use_custom_base_url_in_validation(self, mock_post):
"""Test that custom base URL is used in validation"""
mock_response = MagicMock()
mock_response.status_code = 200
mock_post.return_value = mock_response
credentials = {
"auth_type": "bearer",
"config": {"api_key": "test_api_key_123", "base_url": "https://custom.firecrawl.dev"},
}
auth = FirecrawlAuth(credentials)
result = auth.validate_credentials()
assert result is True
assert mock_post.call_args[0][0] == "https://custom.firecrawl.dev/v1/crawl"
@patch("services.auth.firecrawl.firecrawl.requests.post")
def test_should_handle_timeout_with_retry_suggestion(self, mock_post, auth_instance):
"""Test that timeout errors are handled gracefully with appropriate error message"""
mock_post.side_effect = requests.Timeout("The request timed out after 30 seconds")
with pytest.raises(requests.Timeout) as exc_info:
auth_instance.validate_credentials()
# Verify the timeout exception is raised with original message
assert "timed out" in str(exc_info.value)

View File

@ -0,0 +1,155 @@
from unittest.mock import MagicMock, patch
import pytest
import requests
from services.auth.jina.jina import JinaAuth
class TestJinaAuth:
def test_should_initialize_with_valid_bearer_credentials(self):
"""Test successful initialization with valid bearer credentials"""
credentials = {"auth_type": "bearer", "config": {"api_key": "test_api_key_123"}}
auth = JinaAuth(credentials)
assert auth.api_key == "test_api_key_123"
assert auth.credentials == credentials
def test_should_raise_error_for_invalid_auth_type(self):
"""Test that non-bearer auth type raises ValueError"""
credentials = {"auth_type": "basic", "config": {"api_key": "test_api_key_123"}}
with pytest.raises(ValueError) as exc_info:
JinaAuth(credentials)
assert str(exc_info.value) == "Invalid auth type, Jina Reader auth type must be Bearer"
def test_should_raise_error_for_missing_api_key(self):
"""Test that missing API key raises ValueError"""
credentials = {"auth_type": "bearer", "config": {}}
with pytest.raises(ValueError) as exc_info:
JinaAuth(credentials)
assert str(exc_info.value) == "No API key provided"
def test_should_raise_error_for_missing_config(self):
"""Test that missing config section raises ValueError"""
credentials = {"auth_type": "bearer"}
with pytest.raises(ValueError) as exc_info:
JinaAuth(credentials)
assert str(exc_info.value) == "No API key provided"
@patch("services.auth.jina.jina.requests.post")
def test_should_validate_valid_credentials_successfully(self, mock_post):
"""Test successful credential validation"""
mock_response = MagicMock()
mock_response.status_code = 200
mock_post.return_value = mock_response
credentials = {"auth_type": "bearer", "config": {"api_key": "test_api_key_123"}}
auth = JinaAuth(credentials)
result = auth.validate_credentials()
assert result is True
mock_post.assert_called_once_with(
"https://r.jina.ai",
headers={"Content-Type": "application/json", "Authorization": "Bearer test_api_key_123"},
json={"url": "https://example.com"},
)
@patch("services.auth.jina.jina.requests.post")
def test_should_handle_http_402_error(self, mock_post):
"""Test handling of 402 Payment Required error"""
mock_response = MagicMock()
mock_response.status_code = 402
mock_response.json.return_value = {"error": "Payment required"}
mock_post.return_value = mock_response
credentials = {"auth_type": "bearer", "config": {"api_key": "test_api_key_123"}}
auth = JinaAuth(credentials)
with pytest.raises(Exception) as exc_info:
auth.validate_credentials()
assert str(exc_info.value) == "Failed to authorize. Status code: 402. Error: Payment required"
@patch("services.auth.jina.jina.requests.post")
def test_should_handle_http_409_error(self, mock_post):
"""Test handling of 409 Conflict error"""
mock_response = MagicMock()
mock_response.status_code = 409
mock_response.json.return_value = {"error": "Conflict error"}
mock_post.return_value = mock_response
credentials = {"auth_type": "bearer", "config": {"api_key": "test_api_key_123"}}
auth = JinaAuth(credentials)
with pytest.raises(Exception) as exc_info:
auth.validate_credentials()
assert str(exc_info.value) == "Failed to authorize. Status code: 409. Error: Conflict error"
@patch("services.auth.jina.jina.requests.post")
def test_should_handle_http_500_error(self, mock_post):
"""Test handling of 500 Internal Server Error"""
mock_response = MagicMock()
mock_response.status_code = 500
mock_response.json.return_value = {"error": "Internal server error"}
mock_post.return_value = mock_response
credentials = {"auth_type": "bearer", "config": {"api_key": "test_api_key_123"}}
auth = JinaAuth(credentials)
with pytest.raises(Exception) as exc_info:
auth.validate_credentials()
assert str(exc_info.value) == "Failed to authorize. Status code: 500. Error: Internal server error"
@patch("services.auth.jina.jina.requests.post")
def test_should_handle_unexpected_error_with_text_response(self, mock_post):
"""Test handling of unexpected errors with text response"""
mock_response = MagicMock()
mock_response.status_code = 403
mock_response.text = '{"error": "Forbidden"}'
mock_response.json.side_effect = Exception("Not JSON")
mock_post.return_value = mock_response
credentials = {"auth_type": "bearer", "config": {"api_key": "test_api_key_123"}}
auth = JinaAuth(credentials)
with pytest.raises(Exception) as exc_info:
auth.validate_credentials()
assert str(exc_info.value) == "Failed to authorize. Status code: 403. Error: Forbidden"
@patch("services.auth.jina.jina.requests.post")
def test_should_handle_unexpected_error_without_text(self, mock_post):
"""Test handling of unexpected errors without text response"""
mock_response = MagicMock()
mock_response.status_code = 404
mock_response.text = ""
mock_response.json.side_effect = Exception("Not JSON")
mock_post.return_value = mock_response
credentials = {"auth_type": "bearer", "config": {"api_key": "test_api_key_123"}}
auth = JinaAuth(credentials)
with pytest.raises(Exception) as exc_info:
auth.validate_credentials()
assert str(exc_info.value) == "Unexpected error occurred while trying to authorize. Status code: 404"
@patch("services.auth.jina.jina.requests.post")
def test_should_handle_network_errors(self, mock_post):
"""Test handling of network connection errors"""
mock_post.side_effect = requests.ConnectionError("Network error")
credentials = {"auth_type": "bearer", "config": {"api_key": "test_api_key_123"}}
auth = JinaAuth(credentials)
with pytest.raises(requests.ConnectionError):
auth.validate_credentials()
def test_should_not_expose_api_key_in_error_messages(self):
"""Test that API key is not exposed in error messages"""
credentials = {"auth_type": "bearer", "config": {"api_key": "super_secret_key_12345"}}
auth = JinaAuth(credentials)
# Verify API key is stored but not in any error message
assert auth.api_key == "super_secret_key_12345"
# Test various error scenarios don't expose the key
with pytest.raises(ValueError) as exc_info:
JinaAuth({"auth_type": "basic", "config": {"api_key": "super_secret_key_12345"}})
assert "super_secret_key_12345" not in str(exc_info.value)

View File

@ -0,0 +1,205 @@
from unittest.mock import MagicMock, patch
import pytest
import requests
from services.auth.watercrawl.watercrawl import WatercrawlAuth
class TestWatercrawlAuth:
@pytest.fixture
def valid_credentials(self):
"""Fixture for valid x-api-key credentials"""
return {"auth_type": "x-api-key", "config": {"api_key": "test_api_key_123"}}
@pytest.fixture
def auth_instance(self, valid_credentials):
"""Fixture for WatercrawlAuth instance with valid credentials"""
return WatercrawlAuth(valid_credentials)
def test_should_initialize_with_valid_x_api_key_credentials(self, valid_credentials):
"""Test successful initialization with valid x-api-key credentials"""
auth = WatercrawlAuth(valid_credentials)
assert auth.api_key == "test_api_key_123"
assert auth.base_url == "https://app.watercrawl.dev"
assert auth.credentials == valid_credentials
def test_should_initialize_with_custom_base_url(self):
"""Test initialization with custom base URL"""
credentials = {
"auth_type": "x-api-key",
"config": {"api_key": "test_api_key_123", "base_url": "https://custom.watercrawl.dev"},
}
auth = WatercrawlAuth(credentials)
assert auth.api_key == "test_api_key_123"
assert auth.base_url == "https://custom.watercrawl.dev"
@pytest.mark.parametrize(
("auth_type", "expected_error"),
[
("bearer", "Invalid auth type, WaterCrawl auth type must be x-api-key"),
("basic", "Invalid auth type, WaterCrawl auth type must be x-api-key"),
("", "Invalid auth type, WaterCrawl auth type must be x-api-key"),
],
)
def test_should_raise_error_for_invalid_auth_type(self, auth_type, expected_error):
"""Test that non-x-api-key auth types raise ValueError"""
credentials = {"auth_type": auth_type, "config": {"api_key": "test_api_key_123"}}
with pytest.raises(ValueError) as exc_info:
WatercrawlAuth(credentials)
assert str(exc_info.value) == expected_error
@pytest.mark.parametrize(
("credentials", "expected_error"),
[
({"auth_type": "x-api-key", "config": {}}, "No API key provided"),
({"auth_type": "x-api-key"}, "No API key provided"),
({"auth_type": "x-api-key", "config": {"api_key": ""}}, "No API key provided"),
({"auth_type": "x-api-key", "config": {"api_key": None}}, "No API key provided"),
],
)
def test_should_raise_error_for_missing_api_key(self, credentials, expected_error):
"""Test that missing or empty API key raises ValueError"""
with pytest.raises(ValueError) as exc_info:
WatercrawlAuth(credentials)
assert str(exc_info.value) == expected_error
@patch("services.auth.watercrawl.watercrawl.requests.get")
def test_should_validate_valid_credentials_successfully(self, mock_get, auth_instance):
"""Test successful credential validation"""
mock_response = MagicMock()
mock_response.status_code = 200
mock_get.return_value = mock_response
result = auth_instance.validate_credentials()
assert result is True
mock_get.assert_called_once_with(
"https://app.watercrawl.dev/api/v1/core/crawl-requests/",
headers={"Content-Type": "application/json", "X-API-KEY": "test_api_key_123"},
)
@pytest.mark.parametrize(
("status_code", "error_message"),
[
(402, "Payment required"),
(409, "Conflict error"),
(500, "Internal server error"),
],
)
@patch("services.auth.watercrawl.watercrawl.requests.get")
def test_should_handle_http_errors(self, mock_get, status_code, error_message, auth_instance):
"""Test handling of various HTTP error codes"""
mock_response = MagicMock()
mock_response.status_code = status_code
mock_response.json.return_value = {"error": error_message}
mock_get.return_value = mock_response
with pytest.raises(Exception) as exc_info:
auth_instance.validate_credentials()
assert str(exc_info.value) == f"Failed to authorize. Status code: {status_code}. Error: {error_message}"
@pytest.mark.parametrize(
("status_code", "response_text", "has_json_error", "expected_error_contains"),
[
(403, '{"error": "Forbidden"}', True, "Failed to authorize. Status code: 403. Error: Forbidden"),
(404, "", True, "Unexpected error occurred while trying to authorize. Status code: 404"),
(401, "Not JSON", True, "Expecting value"), # JSON decode error
],
)
@patch("services.auth.watercrawl.watercrawl.requests.get")
def test_should_handle_unexpected_errors(
self, mock_get, status_code, response_text, has_json_error, expected_error_contains, auth_instance
):
"""Test handling of unexpected errors with various response formats"""
mock_response = MagicMock()
mock_response.status_code = status_code
mock_response.text = response_text
if has_json_error:
mock_response.json.side_effect = Exception("Not JSON")
mock_get.return_value = mock_response
with pytest.raises(Exception) as exc_info:
auth_instance.validate_credentials()
assert expected_error_contains in str(exc_info.value)
@pytest.mark.parametrize(
("exception_type", "exception_message"),
[
(requests.ConnectionError, "Network error"),
(requests.Timeout, "Request timeout"),
(requests.ReadTimeout, "Read timeout"),
(requests.ConnectTimeout, "Connection timeout"),
],
)
@patch("services.auth.watercrawl.watercrawl.requests.get")
def test_should_handle_network_errors(self, mock_get, exception_type, exception_message, auth_instance):
"""Test handling of various network-related errors including timeouts"""
mock_get.side_effect = exception_type(exception_message)
with pytest.raises(exception_type) as exc_info:
auth_instance.validate_credentials()
assert exception_message in str(exc_info.value)
def test_should_not_expose_api_key_in_error_messages(self):
"""Test that API key is not exposed in error messages"""
credentials = {"auth_type": "x-api-key", "config": {"api_key": "super_secret_key_12345"}}
auth = WatercrawlAuth(credentials)
# Verify API key is stored but not in any error message
assert auth.api_key == "super_secret_key_12345"
# Test various error scenarios don't expose the key
with pytest.raises(ValueError) as exc_info:
WatercrawlAuth({"auth_type": "bearer", "config": {"api_key": "super_secret_key_12345"}})
assert "super_secret_key_12345" not in str(exc_info.value)
@patch("services.auth.watercrawl.watercrawl.requests.get")
def test_should_use_custom_base_url_in_validation(self, mock_get):
"""Test that custom base URL is used in validation"""
mock_response = MagicMock()
mock_response.status_code = 200
mock_get.return_value = mock_response
credentials = {
"auth_type": "x-api-key",
"config": {"api_key": "test_api_key_123", "base_url": "https://custom.watercrawl.dev"},
}
auth = WatercrawlAuth(credentials)
result = auth.validate_credentials()
assert result is True
assert mock_get.call_args[0][0] == "https://custom.watercrawl.dev/api/v1/core/crawl-requests/"
@pytest.mark.parametrize(
("base_url", "expected_url"),
[
("https://app.watercrawl.dev", "https://app.watercrawl.dev/api/v1/core/crawl-requests/"),
("https://app.watercrawl.dev/", "https://app.watercrawl.dev/api/v1/core/crawl-requests/"),
("https://app.watercrawl.dev//", "https://app.watercrawl.dev/api/v1/core/crawl-requests/"),
],
)
@patch("services.auth.watercrawl.watercrawl.requests.get")
def test_should_use_urljoin_for_url_construction(self, mock_get, base_url, expected_url):
"""Test that urljoin is used correctly for URL construction with various base URLs"""
mock_response = MagicMock()
mock_response.status_code = 200
mock_get.return_value = mock_response
credentials = {"auth_type": "x-api-key", "config": {"api_key": "test_api_key_123", "base_url": base_url}}
auth = WatercrawlAuth(credentials)
auth.validate_credentials()
# Verify the correct URL was called
assert mock_get.call_args[0][0] == expected_url
@patch("services.auth.watercrawl.watercrawl.requests.get")
def test_should_handle_timeout_with_retry_suggestion(self, mock_get, auth_instance):
"""Test that timeout errors are handled gracefully with appropriate error message"""
mock_get.side_effect = requests.Timeout("The request timed out after 30 seconds")
with pytest.raises(requests.Timeout) as exc_info:
auth_instance.validate_credentials()
# Verify the timeout exception is raised with original message
assert "timed out" in str(exc_info.value)

View File

@ -102,17 +102,16 @@ class TestDatasetServiceUpdateDataset:
patch("services.dataset_service.DatasetService.get_dataset") as mock_get_dataset,
patch("services.dataset_service.DatasetService.check_dataset_permission") as mock_check_perm,
patch("extensions.ext_database.db.session") as mock_db,
patch("services.dataset_service.datetime") as mock_datetime,
patch("services.dataset_service.naive_utc_now") as mock_naive_utc_now,
):
current_time = datetime.datetime(2023, 1, 1, 12, 0, 0)
mock_datetime.datetime.now.return_value = current_time
mock_datetime.UTC = datetime.UTC
mock_naive_utc_now.return_value = current_time
yield {
"get_dataset": mock_get_dataset,
"check_permission": mock_check_perm,
"db_session": mock_db,
"datetime": mock_datetime,
"naive_utc_now": mock_naive_utc_now,
"current_time": current_time,
}
@ -292,7 +291,7 @@ class TestDatasetServiceUpdateDataset:
"embedding_model_provider": "openai",
"embedding_model": "text-embedding-ada-002",
"updated_by": user.id,
"updated_at": mock_dataset_service_dependencies["current_time"].replace(tzinfo=None),
"updated_at": mock_dataset_service_dependencies["current_time"],
}
self._assert_database_update_called(
@ -327,7 +326,7 @@ class TestDatasetServiceUpdateDataset:
"indexing_technique": "high_quality",
"retrieval_model": "new_model",
"updated_by": user.id,
"updated_at": mock_dataset_service_dependencies["current_time"].replace(tzinfo=None),
"updated_at": mock_dataset_service_dependencies["current_time"],
}
actual_call_args = mock_dataset_service_dependencies[
@ -365,7 +364,7 @@ class TestDatasetServiceUpdateDataset:
"collection_binding_id": None,
"retrieval_model": "new_model",
"updated_by": user.id,
"updated_at": mock_dataset_service_dependencies["current_time"].replace(tzinfo=None),
"updated_at": mock_dataset_service_dependencies["current_time"],
}
self._assert_database_update_called(
@ -422,7 +421,7 @@ class TestDatasetServiceUpdateDataset:
"collection_binding_id": "binding-456",
"retrieval_model": "new_model",
"updated_by": user.id,
"updated_at": mock_dataset_service_dependencies["current_time"].replace(tzinfo=None),
"updated_at": mock_dataset_service_dependencies["current_time"],
}
self._assert_database_update_called(
@ -463,7 +462,7 @@ class TestDatasetServiceUpdateDataset:
"collection_binding_id": "binding-123",
"retrieval_model": "new_model",
"updated_by": user.id,
"updated_at": mock_dataset_service_dependencies["current_time"].replace(tzinfo=None),
"updated_at": mock_dataset_service_dependencies["current_time"],
}
self._assert_database_update_called(
@ -525,7 +524,7 @@ class TestDatasetServiceUpdateDataset:
"collection_binding_id": "binding-789",
"retrieval_model": "new_model",
"updated_by": user.id,
"updated_at": mock_dataset_service_dependencies["current_time"].replace(tzinfo=None),
"updated_at": mock_dataset_service_dependencies["current_time"],
}
self._assert_database_update_called(
@ -568,7 +567,7 @@ class TestDatasetServiceUpdateDataset:
"collection_binding_id": "binding-123",
"retrieval_model": "new_model",
"updated_by": user.id,
"updated_at": mock_dataset_service_dependencies["current_time"].replace(tzinfo=None),
"updated_at": mock_dataset_service_dependencies["current_time"],
}
self._assert_database_update_called(

View File

@ -283,11 +283,12 @@ REDIS_CLUSTERS_PASSWORD=
# Celery Configuration
# ------------------------------
# Use redis as the broker, and redis db 1 for celery broker.
# Format as follows: `redis://<redis_username>:<redis_password>@<redis_host>:<redis_port>/<redis_database>`
# Use standalone redis as the broker, and redis db 1 for celery broker. (redis_username is usually set by defualt as empty)
# Format as follows: `redis://<redis_username>:<redis_password>@<redis_host>:<redis_port>/<redis_database>`.
# Example: redis://:difyai123456@redis:6379/1
# If use Redis Sentinel, format as follows: `sentinel://<sentinel_username>:<sentinel_password>@<sentinel_host>:<sentinel_port>/<redis_database>`
# Example: sentinel://localhost:26379/1;sentinel://localhost:26380/1;sentinel://localhost:26381/1
# If use Redis Sentinel, format as follows: `sentinel://<redis_username>:<redis_password>@<sentinel_host1>:<sentinel_port>/<redis_database>`
# For high availability, you can configure multiple Sentinel nodes (if provided) separated by semicolons like below example:
# Example: sentinel://:difyai123456@localhost:26379/1;sentinel://:difyai12345@localhost:26379/1;sentinel://:difyai12345@localhost:26379/1
CELERY_BROKER_URL=redis://:difyai123456@redis:6379/1
CELERY_BACKEND=redis
BROKER_USE_SSL=false
@ -412,6 +413,8 @@ SUPABASE_URL=your-server-url
# The type of vector store to use.
# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`.
VECTOR_STORE=weaviate
# Prefix used to create collection name in vector database
VECTOR_INDEX_NAME_PREFIX=Vector_index
# The Weaviate endpoint URL. Only available when VECTOR_STORE is `weaviate`.
WEAVIATE_ENDPOINT=http://weaviate:8080

View File

@ -136,6 +136,7 @@ x-shared-env: &shared-api-worker-env
SUPABASE_API_KEY: ${SUPABASE_API_KEY:-your-access-key}
SUPABASE_URL: ${SUPABASE_URL:-your-server-url}
VECTOR_STORE: ${VECTOR_STORE:-weaviate}
VECTOR_INDEX_NAME_PREFIX: ${VECTOR_INDEX_NAME_PREFIX:-Vector_index}
WEAVIATE_ENDPOINT: ${WEAVIATE_ENDPOINT:-http://weaviate:8080}
WEAVIATE_API_KEY: ${WEAVIATE_API_KEY:-WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih}
QDRANT_URL: ${QDRANT_URL:-http://qdrant:6333}

View File

@ -1 +1,7 @@
from dify_client.client import ChatClient, CompletionClient, WorkflowClient, KnowledgeBaseClient, DifyClient
from dify_client.client import (
ChatClient,
CompletionClient,
WorkflowClient,
KnowledgeBaseClient,
DifyClient,
)

View File

@ -6,7 +6,7 @@ LABEL maintainer="takatost@gmail.com"
# RUN sed -i 's/dl-cdn.alpinelinux.org/mirrors.aliyun.com/g' /etc/apk/repositories
RUN apk add --no-cache tzdata
RUN npm install -g pnpm@10.11.1
RUN npm install -g pnpm@10.13.1
ENV PNPM_HOME="/pnpm"
ENV PATH="$PNPM_HOME:$PATH"

View File

@ -0,0 +1,41 @@
'use client'
import { useTranslation } from 'react-i18next'
import {
RiAddLine,
RiArrowRightLine,
} from '@remixicon/react'
import Link from 'next/link'
type CreateAppCardProps = {
ref?: React.Ref<HTMLAnchorElement>
}
const CreateAppCard = ({ ref }: CreateAppCardProps) => {
const { t } = useTranslation()
return (
<div className='bg-background-default-dimm flex min-h-[160px] flex-col rounded-xl border-[0.5px]
border-components-panel-border transition-all duration-200 ease-in-out'
>
<Link ref={ref} className='group flex grow cursor-pointer items-start p-4' href='/datasets/create'>
<div className='flex items-center gap-3'>
<div className='flex h-10 w-10 items-center justify-center rounded-lg border border-dashed border-divider-regular bg-background-default-lighter
p-2 group-hover:border-solid group-hover:border-effects-highlight group-hover:bg-background-default-dodge'
>
<RiAddLine className='h-4 w-4 text-text-tertiary group-hover:text-text-accent' />
</div>
<div className='system-md-semibold text-text-secondary group-hover:text-text-accent'>{t('dataset.createDataset')}</div>
</div>
</Link>
<div className='system-xs-regular p-4 pt-0 text-text-tertiary'>{t('dataset.createDatasetIntro')}</div>
<Link className='group flex cursor-pointer items-center gap-1 rounded-b-xl border-t-[0.5px] border-divider-subtle p-4' href='/datasets/connect'>
<div className='system-xs-medium text-text-tertiary group-hover:text-text-accent'>{t('dataset.connectDataset')}</div>
<RiArrowRightLine className='h-3.5 w-3.5 text-text-tertiary group-hover:text-text-accent' />
</Link>
</div>
)
}
CreateAppCard.displayName = 'CreateAppCard'
export default CreateAppCard

View File

@ -1,6 +1,6 @@
import React from 'react'
import type { ReactNode } from 'react'
import SwrInitor from '@/app/components/swr-initor'
import SwrInitializer from '@/app/components/swr-initializer'
import { AppContextProvider } from '@/context/app-context'
import GA, { GaType } from '@/app/components/base/ga'
import HeaderWrapper from '@/app/components/header/header-wrapper'
@ -13,7 +13,7 @@ const Layout = ({ children }: { children: ReactNode }) => {
return (
<>
<GA gaType={GaType.admin} />
<SwrInitor>
<SwrInitializer>
<AppContextProvider>
<EventEmitterContextProvider>
<ProviderContextProvider>
@ -26,7 +26,7 @@ const Layout = ({ children }: { children: ReactNode }) => {
</ProviderContextProvider>
</EventEmitterContextProvider>
</AppContextProvider>
</SwrInitor>
</SwrInitializer>
</>
)
}

View File

@ -1,5 +1,6 @@
'use client'
import { useState } from 'react'
import useSWR from 'swr'
import { useTranslation } from 'react-i18next'
import {
RiGraduationCapFill,
@ -22,6 +23,8 @@ import PremiumBadge from '@/app/components/base/premium-badge'
import { useGlobalPublicStore } from '@/context/global-public-context'
import EmailChangeModal from './email-change-modal'
import { validPassword } from '@/config'
import { fetchAppList } from '@/service/apps'
import type { App } from '@/types/app'
const titleClassName = `
system-sm-semibold text-text-secondary
@ -33,7 +36,9 @@ const descriptionClassName = `
export default function AccountPage() {
const { t } = useTranslation()
const { systemFeatures } = useGlobalPublicStore()
const { mutateUserProfile, userProfile, apps } = useAppContext()
const { data: appList } = useSWR({ url: '/apps', params: { page: 1, limit: 100, name: '' } }, fetchAppList)
const apps = appList?.data || []
const { mutateUserProfile, userProfile } = useAppContext()
const { isEducationAccount } = useProviderContext()
const { notify } = useContext(ToastContext)
const [editNameModalVisible, setEditNameModalVisible] = useState(false)
@ -202,7 +207,7 @@ export default function AccountPage() {
{!!apps.length && (
<Collapse
title={`${t('common.account.showAppLength', { length: apps.length })}`}
items={apps.map(app => ({ ...app, key: app.id, name: app.name }))}
items={apps.map((app: App) => ({ ...app, key: app.id, name: app.name }))}
renderItem={renderAppItem}
wrapperClassName='mt-2'
/>

View File

@ -1,7 +1,7 @@
import React from 'react'
import type { ReactNode } from 'react'
import Header from './header'
import SwrInitor from '@/app/components/swr-initor'
import SwrInitor from '@/app/components/swr-initializer'
import { AppContextProvider } from '@/context/app-context'
import GA, { GaType } from '@/app/components/base/ga'
import HeaderWrapper from '@/app/components/header/header-wrapper'

View File

@ -1,6 +1,6 @@
import { useTranslation } from 'react-i18next'
import { useRouter } from 'next/navigation'
import { useContext, useContextSelector } from 'use-context-selector'
import { useContext } from 'use-context-selector'
import React, { useCallback, useState } from 'react'
import {
RiDeleteBinLine,
@ -15,7 +15,7 @@ import AppIcon from '../base/app-icon'
import cn from '@/utils/classnames'
import { useStore as useAppStore } from '@/app/components/app/store'
import { ToastContext } from '@/app/components/base/toast'
import AppsContext, { useAppContext } from '@/context/app-context'
import { useAppContext } from '@/context/app-context'
import { useProviderContext } from '@/context/provider-context'
import { copyApp, deleteApp, exportAppConfig, updateAppInfo } from '@/service/apps'
import type { DuplicateAppModalProps } from '@/app/components/app/duplicate-modal'
@ -73,11 +73,6 @@ const AppInfo = ({ expand, onlyShowDetail = false, openState = false, onDetailEx
const [showImportDSLModal, setShowImportDSLModal] = useState<boolean>(false)
const [secretEnvList, setSecretEnvList] = useState<EnvironmentVariable[]>([])
const mutateApps = useContextSelector(
AppsContext,
state => state.mutateApps,
)
const onEdit: CreateAppModalProps['onConfirm'] = useCallback(async ({
name,
icon_type,
@ -106,12 +101,11 @@ const AppInfo = ({ expand, onlyShowDetail = false, openState = false, onDetailEx
message: t('app.editDone'),
})
setAppDetail(app)
mutateApps()
}
catch {
notify({ type: 'error', message: t('app.editFailed') })
}
}, [appDetail, mutateApps, notify, setAppDetail, t])
}, [appDetail, notify, setAppDetail, t])
const onCopy: DuplicateAppModalProps['onConfirm'] = async ({ name, icon_type, icon, icon_background }) => {
if (!appDetail)
@ -131,7 +125,6 @@ const AppInfo = ({ expand, onlyShowDetail = false, openState = false, onDetailEx
message: t('app.newApp.appCreated'),
})
localStorage.setItem(NEED_REFRESH_APP_LIST_KEY, '1')
mutateApps()
onPlanInfoChanged()
getRedirection(true, newApp, replace)
}
@ -186,7 +179,6 @@ const AppInfo = ({ expand, onlyShowDetail = false, openState = false, onDetailEx
try {
await deleteApp(appDetail.id)
notify({ type: 'success', message: t('app.appDeleted') })
mutateApps()
onPlanInfoChanged()
setAppDetail()
replace('/apps')
@ -198,7 +190,7 @@ const AppInfo = ({ expand, onlyShowDetail = false, openState = false, onDetailEx
})
}
setShowConfirmDelete(false)
}, [appDetail, mutateApps, notify, onPlanInfoChanged, replace, setAppDetail, t])
}, [appDetail, notify, onPlanInfoChanged, replace, setAppDetail, t])
const { isCurrentWorkspaceEditor } = useAppContext()

View File

@ -13,7 +13,6 @@ import Loading from '@/app/components/base/loading'
import Badge from '@/app/components/base/badge'
import { useKnowledge } from '@/hooks/use-knowledge'
import cn from '@/utils/classnames'
import { basePath } from '@/utils/var'
import AppIcon from '@/app/components/base/app-icon'
export type ISelectDataSetProps = {
@ -113,7 +112,7 @@ const SelectDataSet: FC<ISelectDataSetProps> = ({
}}
>
<span className='text-text-tertiary'>{t('appDebug.feature.dataSet.noDataSet')}</span>
<Link href={`${basePath}/datasets/create`} className='font-normal text-text-accent'>{t('appDebug.feature.dataSet.toCreate')}</Link>
<Link href='/datasets/create' className='font-normal text-text-accent'>{t('appDebug.feature.dataSet.toCreate')}</Link>
</div>
)}

View File

@ -4,7 +4,7 @@ import { useCallback, useRef, useState } from 'react'
import { useTranslation } from 'react-i18next'
import { useRouter } from 'next/navigation'
import { useContext, useContextSelector } from 'use-context-selector'
import { useContext } from 'use-context-selector'
import { RiArrowRightLine, RiArrowRightSLine, RiCommandLine, RiCornerDownLeftLine, RiExchange2Fill } from '@remixicon/react'
import Link from 'next/link'
import { useDebounceFn, useKeyPress } from 'ahooks'
@ -15,7 +15,7 @@ import Button from '@/app/components/base/button'
import Divider from '@/app/components/base/divider'
import cn from '@/utils/classnames'
import { basePath } from '@/utils/var'
import AppsContext, { useAppContext } from '@/context/app-context'
import { useAppContext } from '@/context/app-context'
import { useProviderContext } from '@/context/provider-context'
import { ToastContext } from '@/app/components/base/toast'
import type { AppMode } from '@/types/app'
@ -41,7 +41,6 @@ function CreateApp({ onClose, onSuccess, onCreateFromTemplate }: CreateAppProps)
const { t } = useTranslation()
const { push } = useRouter()
const { notify } = useContext(ToastContext)
const mutateApps = useContextSelector(AppsContext, state => state.mutateApps)
const [appMode, setAppMode] = useState<AppMode>('advanced-chat')
const [appIcon, setAppIcon] = useState<AppIconSelection>({ type: 'emoji', icon: '🤖', background: '#FFEAD5' })
@ -80,7 +79,6 @@ function CreateApp({ onClose, onSuccess, onCreateFromTemplate }: CreateAppProps)
notify({ type: 'success', message: t('app.newApp.appCreated') })
onSuccess()
onClose()
mutateApps()
localStorage.setItem(NEED_REFRESH_APP_LIST_KEY, '1')
getRedirection(isCurrentWorkspaceEditor, app, push)
}
@ -88,7 +86,7 @@ function CreateApp({ onClose, onSuccess, onCreateFromTemplate }: CreateAppProps)
notify({ type: 'error', message: t('app.newApp.appCreateFailed') })
}
isCreatingRef.current = false
}, [name, notify, t, appMode, appIcon, description, onSuccess, onClose, mutateApps, push, isCurrentWorkspaceEditor])
}, [name, notify, t, appMode, appIcon, description, onSuccess, onClose, push, isCurrentWorkspaceEditor])
const { run: handleCreateApp } = useDebounceFn(onCreate, { wait: 300 })
useKeyPress(['meta.enter', 'ctrl.enter'], () => {
@ -298,7 +296,7 @@ function AppTypeCard({ icon, title, description, active, onClick }: AppTypeCardP
>
{icon}
<div className='system-sm-semibold mb-0.5 mt-2 text-text-secondary'>{title}</div>
<div className='system-xs-regular text-text-tertiary'>{description}</div>
<div className='system-xs-regular line-clamp-2 text-text-tertiary' title={description}>{description}</div>
</div>
}

View File

@ -90,10 +90,10 @@ const Embedded = ({ siteInfo, isShow, onClose, appBaseUrl, accessToken, classNam
const [option, setOption] = useState<Option>('iframe')
const [isCopied, setIsCopied] = useState<OptionStatus>({ iframe: false, scripts: false, chromePlugin: false })
const { langeniusVersionInfo } = useAppContext()
const { langGeniusVersionInfo } = useAppContext()
const themeBuilder = useThemeContext()
themeBuilder.buildTheme(siteInfo?.chat_color_theme ?? null, siteInfo?.chat_color_theme_inverted ?? false)
const isTestEnv = langeniusVersionInfo.current_env === 'TESTING' || langeniusVersionInfo.current_env === 'DEVELOPMENT'
const isTestEnv = langGeniusVersionInfo.current_env === 'TESTING' || langGeniusVersionInfo.current_env === 'DEVELOPMENT'
const onClickCopy = () => {
if (option === 'chromePlugin') {
const splitUrl = OPTION_MAP[option].getContent(appBaseUrl, accessToken).split(': ')

View File

@ -1,7 +1,7 @@
'use client'
import React, { useCallback, useEffect, useMemo, useState } from 'react'
import { useContext, useContextSelector } from 'use-context-selector'
import { useContext } from 'use-context-selector'
import { useRouter } from 'next/navigation'
import { useTranslation } from 'react-i18next'
import { RiBuildingLine, RiGlobalLine, RiLockLine, RiMoreFill, RiVerifiedBadgeLine } from '@remixicon/react'
@ -11,7 +11,7 @@ import Toast, { ToastContext } from '@/app/components/base/toast'
import { copyApp, deleteApp, exportAppConfig, updateAppInfo } from '@/service/apps'
import type { DuplicateAppModalProps } from '@/app/components/app/duplicate-modal'
import AppIcon from '@/app/components/base/app-icon'
import AppsContext, { useAppContext } from '@/context/app-context'
import { useAppContext } from '@/context/app-context'
import type { HtmlContentProps } from '@/app/components/base/popover'
import CustomPopover from '@/app/components/base/popover'
import Divider from '@/app/components/base/divider'
@ -65,11 +65,6 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => {
const { onPlanInfoChanged } = useProviderContext()
const { push } = useRouter()
const mutateApps = useContextSelector(
AppsContext,
state => state.mutateApps,
)
const [showEditModal, setShowEditModal] = useState(false)
const [showDuplicateModal, setShowDuplicateModal] = useState(false)
const [showSwitchModal, setShowSwitchModal] = useState<boolean>(false)
@ -83,7 +78,6 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => {
notify({ type: 'success', message: t('app.appDeleted') })
if (onRefresh)
onRefresh()
mutateApps()
onPlanInfoChanged()
}
catch (e: any) {
@ -93,7 +87,7 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => {
})
}
setShowConfirmDelete(false)
}, [app.id, mutateApps, notify, onPlanInfoChanged, onRefresh, t])
}, [app.id, notify, onPlanInfoChanged, onRefresh, t])
const onEdit: CreateAppModalProps['onConfirm'] = useCallback(async ({
name,
@ -122,12 +116,11 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => {
})
if (onRefresh)
onRefresh()
mutateApps()
}
catch {
notify({ type: 'error', message: t('app.editFailed') })
}
}, [app.id, mutateApps, notify, onRefresh, t])
}, [app.id, notify, onRefresh, t])
const onCopy: DuplicateAppModalProps['onConfirm'] = async ({ name, icon_type, icon, icon_background }) => {
try {
@ -147,7 +140,6 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => {
localStorage.setItem(NEED_REFRESH_APP_LIST_KEY, '1')
if (onRefresh)
onRefresh()
mutateApps()
onPlanInfoChanged()
getRedirection(isCurrentWorkspaceEditor, newApp, push)
}
@ -195,16 +187,14 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => {
const onSwitch = () => {
if (onRefresh)
onRefresh()
mutateApps()
setShowSwitchModal(false)
}
const onUpdateAccessControl = useCallback(() => {
if (onRefresh)
onRefresh()
mutateApps()
setShowAccessControl(false)
}, [onRefresh, mutateApps, setShowAccessControl])
}, [onRefresh, setShowAccessControl])
const Operations = (props: HtmlContentProps) => {
const { data: userCanAccessApp, isLoading: isGettingUserCanAccessApp } = useGetUserCanAccessApp({ appId: app?.id, enabled: (!!props?.open && systemFeatures.webapp_auth.enabled) })
@ -325,7 +315,6 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => {
dateFormat: `${t('datasetDocuments.segment.dateTimeFormat')}`,
})
return `${t('datasetDocuments.segment.editedAt')} ${timeText}`
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [app.updated_at, app.created_at])
return (

View File

@ -45,7 +45,7 @@ const Avatar = ({
className={cn(textClassName, 'scale-[0.4] text-center text-white')}
style={style}
>
{name[0].toLocaleUpperCase()}
{name && name[0].toLocaleUpperCase()}
</div>
</div>
)

View File

@ -117,7 +117,7 @@ const Question: FC<QuestionProps> = ({
</div>
<div
ref={contentRef}
className='bg-background-gradient-bg-fill-chat-bubble-bg-3 w-full rounded-2xl px-4 py-3 text-sm text-text-primary'
className='w-full rounded-2xl bg-background-gradient-bg-fill-chat-bubble-bg-3 px-4 py-3 text-sm text-text-primary'
style={theme?.chatBubbleColorStyle ? CssTransform(theme.chatBubbleColorStyle) : {}}
>
{

View File

@ -21,7 +21,7 @@ const AppsFull: FC<{ loc: string; className?: string; }> = ({
}) => {
const { t } = useTranslation()
const { plan } = useProviderContext()
const { userProfile, langeniusVersionInfo } = useAppContext()
const { userProfile, langGeniusVersionInfo } = useAppContext()
const isTeam = plan.type === Plan.team
const usage = plan.usage.buildApps
const total = plan.total.buildApps
@ -62,7 +62,7 @@ const AppsFull: FC<{ loc: string; className?: string; }> = ({
)}
{plan.type !== Plan.sandbox && plan.type !== Plan.professional && (
<Button variant='secondary-accent'>
<a target='_blank' rel='noopener noreferrer' href={mailToSupport(userProfile.email, plan.type, langeniusVersionInfo.current_version)}>
<a target='_blank' rel='noopener noreferrer' href={mailToSupport(userProfile.email, plan.type, langGeniusVersionInfo.current_version)}>
{t('billing.apps.contactUs')}
</a>
</Button>

View File

@ -43,10 +43,10 @@ Object.defineProperty(globalThis, 'sessionStorage', {
value: sessionStorage,
})
const BrowserInitor = ({
const BrowserInitializer = ({
children,
}: { children: React.ReactNode }) => {
}: { children: React.ReactElement }) => {
return children
}
export default BrowserInitor
export default BrowserInitializer

View File

@ -83,7 +83,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
- <code>subchunk_segmentation</code> (object) 子チャンクルール
- <code>separator</code> セグメンテーション識別子。現在は 1 つの区切り文字のみ許可。デフォルトは <code>***</code>
- <code>max_tokens</code> 最大長 (トークン) は親チャンクの長さより短いことを検証する必要があります
- <code>chunk_overlap</code> 隣接するチャンク間の重を定義 (オプション)
- <code>chunk_overlap</code> 隣接するチャンク間の重なりを定義 (オプション)
</Property>
<PropertyInstruction>ナレッジベースにパラメータが設定されていない場合、最初のアップロードには以下のパラメータを提供する必要があります。提供されない場合、デフォルトパラメータが使用されます。</PropertyInstruction>
<Property name='retrieval_model' type='object' key='retrieval_model'>
@ -218,7 +218,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
- <code>subchunk_segmentation</code> (object) 子チャンクルール
- <code>separator</code> セグメンテーション識別子。現在は 1 つの区切り文字のみ許可。デフォルトは <code>***</code>
- <code>max_tokens</code> 最大長 (トークン) は親チャンクの長さより短いことを検証する必要があります
- <code>chunk_overlap</code> 隣接するチャンク間の重を定義 (オプション)
- <code>chunk_overlap</code> 隣接するチャンク間の重なりを定義 (オプション)
</Property>
<Property name='file' type='multipart/form-data' key='file'>
アップロードする必要があるファイル。
@ -555,7 +555,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
- <code>subchunk_segmentation</code> (object) 子チャンクルール
- <code>separator</code> セグメンテーション識別子。現在は 1 つの区切り文字のみ許可。デフォルトは <code>***</code>
- <code>max_tokens</code> 最大長 (トークン) は親チャンクの長さより短いことを検証する必要があります
- <code>chunk_overlap</code> 隣接するチャンク間の重を定義 (オプション)
- <code>chunk_overlap</code> 隣接するチャンク間の重なりを定義 (オプション)
</Property>
</Properties>
</Col>
@ -657,7 +657,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
- <code>subchunk_segmentation</code> (object) 子チャンクルール
- <code>separator</code> セグメンテーション識別子。現在は 1 つの区切り文字のみ許可。デフォルトは <code>***</code>
- <code>max_tokens</code> 最大長 (トークン) は親チャンクの長さより短いことを検証する必要があります
- <code>chunk_overlap</code> 隣接するチャンク間の重を定義 (オプション)
- <code>chunk_overlap</code> 隣接するチャンク間の重なりを定義 (オプション)
</Property>
</Properties>
</Col>

View File

@ -333,7 +333,7 @@ Workflow 应用无会话支持,适合用于翻译/文章写作/总结 AI 等
<Col>
根据 workflow 执行 ID 获取 workflow 任务当前执行结果
### Path
- `workflow_run_id` (string) workflow_run_id,可在流式返回 Chunk 中获取
- `workflow_run_id` (string) workflow 执行 ID,可在流式返回 Chunk 中获取
### Response
- `id` (string) workflow 执行 ID
- `workflow_id` (string) 关联的 Workflow ID

View File

@ -12,16 +12,16 @@ import { noop } from 'lodash-es'
import { useGlobalPublicStore } from '@/context/global-public-context'
type IAccountSettingProps = {
langeniusVersionInfo: LangGeniusVersionResponse
langGeniusVersionInfo: LangGeniusVersionResponse
onCancel: () => void
}
export default function AccountAbout({
langeniusVersionInfo,
langGeniusVersionInfo,
onCancel,
}: IAccountSettingProps) {
const { t } = useTranslation()
const isLatest = langeniusVersionInfo.current_version === langeniusVersionInfo.latest_version
const isLatest = langGeniusVersionInfo.current_version === langGeniusVersionInfo.latest_version
const systemFeatures = useGlobalPublicStore(s => s.systemFeatures)
return (
@ -43,7 +43,7 @@ export default function AccountAbout({
/>
: <DifyLogo size='large' className='mx-auto' />}
<div className='text-center text-xs font-normal text-text-tertiary'>Version {langeniusVersionInfo?.current_version}</div>
<div className='text-center text-xs font-normal text-text-tertiary'>Version {langGeniusVersionInfo?.current_version}</div>
<div className='flex flex-col items-center gap-2 text-center text-xs font-normal text-text-secondary'>
<div>© {dayjs().year()} LangGenius, Inc., Contributors.</div>
<div className='text-text-accent'>
@ -63,8 +63,8 @@ export default function AccountAbout({
<div className='text-xs font-medium text-text-tertiary'>
{
isLatest
? t('common.about.latestAvailable', { version: langeniusVersionInfo.latest_version })
: t('common.about.nowAvailable', { version: langeniusVersionInfo.latest_version })
? t('common.about.latestAvailable', { version: langGeniusVersionInfo.latest_version })
: t('common.about.nowAvailable', { version: langGeniusVersionInfo.latest_version })
}
</div>
<div className='flex items-center'>
@ -80,7 +80,7 @@ export default function AccountAbout({
!isLatest && !IS_CE_EDITION && (
<Button variant='primary' size='small'>
<Link
href={langeniusVersionInfo.release_notes}
href={langGeniusVersionInfo.release_notes}
target='_blank' rel='noopener noreferrer'
>
{t('common.about.updateNow')}

View File

@ -45,7 +45,7 @@ export default function AppSelector() {
const { t } = useTranslation()
const docLink = useDocLink()
const { userProfile, langeniusVersionInfo, isCurrentWorkspaceOwner } = useAppContext()
const { userProfile, langGeniusVersionInfo, isCurrentWorkspaceOwner } = useAppContext()
const { isEducationAccount } = useProviderContext()
const { setShowAccountSettingModal } = useModalContext()
@ -180,8 +180,8 @@ export default function AppSelector() {
<RiInformation2Line className='size-4 shrink-0 text-text-tertiary' />
<div className='system-md-regular grow px-1 text-text-secondary'>{t('common.userProfile.about')}</div>
<div className='flex shrink-0 items-center'>
<div className='system-xs-regular mr-2 text-text-tertiary'>{langeniusVersionInfo.current_version}</div>
<Indicator color={langeniusVersionInfo.current_version === langeniusVersionInfo.latest_version ? 'green' : 'orange'} />
<div className='system-xs-regular mr-2 text-text-tertiary'>{langGeniusVersionInfo.current_version}</div>
<Indicator color={langGeniusVersionInfo.current_version === langGeniusVersionInfo.latest_version ? 'green' : 'orange'} />
</div>
</div>
</MenuItem>
@ -217,7 +217,7 @@ export default function AppSelector() {
}
</Menu>
{
aboutVisible && <AccountAbout onCancel={() => setAboutVisible(false)} langeniusVersionInfo={langeniusVersionInfo} />
aboutVisible && <AccountAbout onCancel={() => setAboutVisible(false)} langGeniusVersionInfo={langGeniusVersionInfo} />
}
</div >
)

View File

@ -16,7 +16,7 @@ export default function Support() {
`
const { t } = useTranslation()
const { plan } = useProviderContext()
const { userProfile, langeniusVersionInfo } = useAppContext()
const { userProfile, langGeniusVersionInfo } = useAppContext()
const canEmailSupport = plan.type === Plan.professional || plan.type === Plan.team || plan.type === Plan.enterprise
return <Menu as="div" className="relative h-full w-full">
@ -53,7 +53,7 @@ export default function Support() {
className={cn(itemClassName, 'group justify-between',
'data-[active]:bg-state-base-hover',
)}
href={mailToSupport(userProfile.email, plan.type, langeniusVersionInfo.current_version)}
href={mailToSupport(userProfile.email, plan.type, langGeniusVersionInfo.current_version)}
target='_blank' rel='noopener noreferrer'>
<RiMailSendLine className='size-4 shrink-0 text-text-tertiary' />
<div className='system-md-regular grow px-1 text-text-secondary'>{t('common.userProfile.emailSupport')}</div>

View File

@ -12,8 +12,8 @@ const headerEnvClassName: { [k: string]: string } = {
const EnvNav = () => {
const { t } = useTranslation()
const { langeniusVersionInfo } = useAppContext()
const showEnvTag = langeniusVersionInfo.current_env === 'TESTING' || langeniusVersionInfo.current_env === 'DEVELOPMENT'
const { langGeniusVersionInfo } = useAppContext()
const showEnvTag = langGeniusVersionInfo.current_env === 'TESTING' || langGeniusVersionInfo.current_env === 'DEVELOPMENT'
if (!showEnvTag)
return null
@ -21,10 +21,10 @@ const EnvNav = () => {
return (
<div className={`
mr-1 flex h-[22px] items-center rounded-md border px-2 text-xs font-medium
${headerEnvClassName[langeniusVersionInfo.current_env]}
${headerEnvClassName[langGeniusVersionInfo.current_env]}
`}>
{
langeniusVersionInfo.current_env === 'TESTING' && (
langGeniusVersionInfo.current_env === 'TESTING' && (
<>
<Beaker02 className='h-3 w-3' />
<div className='ml-1 max-[1280px]:hidden'>{t('common.environment.testing')}</div>
@ -32,7 +32,7 @@ const EnvNav = () => {
)
}
{
langeniusVersionInfo.current_env === 'DEVELOPMENT' && (
langGeniusVersionInfo.current_env === 'DEVELOPMENT' && (
<>
<TerminalSquare className='h-3 w-3' />
<div className='ml-1 max-[1280px]:hidden'>{t('common.environment.development')}</div>

View File

@ -48,7 +48,6 @@ const Installed: FC<Props> = ({
useEffect(() => {
if (hasInstalled && uniqueIdentifier === installedInfoPayload.uniqueIdentifier)
onInstalled()
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [hasInstalled])
const [isInstalling, setIsInstalling] = React.useState(false)
@ -105,12 +104,12 @@ const Installed: FC<Props> = ({
}
}
const { langeniusVersionInfo } = useAppContext()
const { langGeniusVersionInfo } = useAppContext()
const isDifyVersionCompatible = useMemo(() => {
if (!langeniusVersionInfo.current_version)
if (!langGeniusVersionInfo.current_version)
return true
return gte(langeniusVersionInfo.current_version, payload.meta.minimum_dify_version ?? '0.0.0')
}, [langeniusVersionInfo.current_version, payload.meta.minimum_dify_version])
return gte(langGeniusVersionInfo.current_version, payload.meta.minimum_dify_version ?? '0.0.0')
}, [langGeniusVersionInfo.current_version, payload.meta.minimum_dify_version])
return (
<>

View File

@ -59,7 +59,6 @@ const Installed: FC<Props> = ({
useEffect(() => {
if (hasInstalled && uniqueIdentifier === installedInfoPayload.uniqueIdentifier)
onInstalled()
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [hasInstalled])
const handleCancel = () => {
@ -120,12 +119,12 @@ const Installed: FC<Props> = ({
}
}
const { langeniusVersionInfo } = useAppContext()
const { langGeniusVersionInfo } = useAppContext()
const { data: pluginDeclaration } = usePluginDeclarationFromMarketPlace(uniqueIdentifier)
const isDifyVersionCompatible = useMemo(() => {
if (!pluginDeclaration || !langeniusVersionInfo.current_version) return true
return gte(langeniusVersionInfo.current_version, pluginDeclaration?.manifest.meta.minimum_dify_version ?? '0.0.0')
}, [langeniusVersionInfo.current_version, pluginDeclaration])
if (!pluginDeclaration || !langGeniusVersionInfo.current_version) return true
return gte(langGeniusVersionInfo.current_version, pluginDeclaration?.manifest.meta.minimum_dify_version ?? '0.0.0')
}, [langGeniusVersionInfo.current_version, pluginDeclaration])
const { canInstall } = useInstallPluginLimit({ ...payload, from: 'marketplace' })
return (

View File

@ -265,7 +265,7 @@ const ToolSelector: FC<Props> = ({
/>
)}
</PortalToFollowElemTrigger>
<PortalToFollowElemContent>
<PortalToFollowElemContent className='z-10'>
<div className={cn('relative max-h-[642px] min-h-20 w-[361px] rounded-xl border-[0.5px] border-components-panel-border bg-components-panel-bg-blur pb-4 shadow-lg backdrop-blur-sm', 'overflow-y-auto pb-2')}>
<>
<div className='system-xl-semibold px-4 pb-1 pt-3.5 text-text-primary'>{t(`plugin.detailPanel.toolSelector.${isEdit ? 'toolSetting' : 'title'}`)}</div>
@ -309,15 +309,15 @@ const ToolSelector: FC<Props> = ({
{currentProvider && currentProvider.type === CollectionType.builtIn && currentProvider.allow_delete && (
<>
<Divider className='my-1 w-full' />
<div className='px-4 py-2'>
<PluginAuthInAgent
pluginPayload={{
provider: currentProvider.name,
category: AuthCategory.tool,
}}
credentialId={value?.credential_id}
onAuthorizationItemClick={handleAuthorizationItemClick}
/>
<div className='px-4 py-2'>
<PluginAuthInAgent
pluginPayload={{
provider: currentProvider.name,
category: AuthCategory.tool,
}}
credentialId={value?.credential_id}
onAuthorizationItemClick={handleAuthorizationItemClick}
/>
</div>
</>
)}

View File

@ -62,13 +62,13 @@ const PluginItem: FC<Props> = ({
return [PluginSource.github, PluginSource.marketplace].includes(source) ? author : ''
}, [source, author])
const { langeniusVersionInfo } = useAppContext()
const { langGeniusVersionInfo } = useAppContext()
const isDifyVersionCompatible = useMemo(() => {
if (!langeniusVersionInfo.current_version)
if (!langGeniusVersionInfo.current_version)
return true
return gte(langeniusVersionInfo.current_version, declarationMeta.minimum_dify_version ?? '0.0.0')
}, [declarationMeta.minimum_dify_version, langeniusVersionInfo.current_version])
return gte(langGeniusVersionInfo.current_version, declarationMeta.minimum_dify_version ?? '0.0.0')
}, [declarationMeta.minimum_dify_version, langGeniusVersionInfo.current_version])
const handleDelete = () => {
refreshPluginList({ category } as any)

View File

@ -5,9 +5,9 @@ import * as Sentry from '@sentry/react'
const isDevelopment = process.env.NODE_ENV === 'development'
const SentryInit = ({
const SentryInitializer = ({
children,
}: { children: React.ReactNode }) => {
}: { children: React.ReactElement }) => {
useEffect(() => {
const SENTRY_DSN = document?.body?.getAttribute('data-public-sentry-dsn')
if (!isDevelopment && SENTRY_DSN) {
@ -26,4 +26,4 @@ const SentryInit = ({
return children
}
export default SentryInit
export default SentryInitializer

View File

@ -10,12 +10,12 @@ import {
EDUCATION_VERIFY_URL_SEARCHPARAMS_ACTION,
} from '@/app/education-apply/constants'
type SwrInitorProps = {
type SwrInitializerProps = {
children: ReactNode
}
const SwrInitor = ({
const SwrInitializer = ({
children,
}: SwrInitorProps) => {
}: SwrInitializerProps) => {
const router = useRouter()
const searchParams = useSearchParams()
const consoleToken = decodeURIComponent(searchParams.get('access_token') || '')
@ -86,4 +86,4 @@ const SwrInitor = ({
: null
}
export default SwrInitor
export default SwrInitializer

View File

@ -697,7 +697,7 @@ const getIterationItemType = ({
case VarType.arrayObject:
return VarType.object
case VarType.array:
return VarType.any
return VarType.arrayObject // Use more specific type instead of any
case VarType.arrayFile:
return VarType.file
default:

View File

@ -13,7 +13,7 @@ const metaData = genNodeMetaData({
const nodeDefault: NodeDefault<AgentNodeType> = {
metaData,
defaultValue: {
version: '2',
tool_node_version: '2',
},
checkValid(payload, t, moreDataForCheckValid: {
strategyProvider?: StrategyPluginDetail,
@ -58,27 +58,29 @@ const nodeDefault: NodeDefault<AgentNodeType> = {
const userSettings = toolValue.settings
const reasoningConfig = toolValue.parameters
const version = payload.version
const toolNodeVersion = payload.tool_node_version
const mergeVersion = version || toolNodeVersion
schemas.forEach((schema: any) => {
if (schema?.required) {
if (schema.form === 'form' && !version && !userSettings[schema.name]?.value) {
if (schema.form === 'form' && !mergeVersion && !userSettings[schema.name]?.value) {
return {
isValid: false,
errorMessage: t('workflow.errorMsg.toolParameterRequired', { field: renderI18nObject(param.label, language), param: renderI18nObject(schema.label, language) }),
}
}
if (schema.form === 'form' && version && !userSettings[schema.name]?.value.value) {
if (schema.form === 'form' && mergeVersion && !userSettings[schema.name]?.value.value) {
return {
isValid: false,
errorMessage: t('workflow.errorMsg.toolParameterRequired', { field: renderI18nObject(param.label, language), param: renderI18nObject(schema.label, language) }),
}
}
if (schema.form === 'llm' && !version && reasoningConfig[schema.name].auto === 0 && !reasoningConfig[schema.name]?.value) {
if (schema.form === 'llm' && !mergeVersion && reasoningConfig[schema.name].auto === 0 && !reasoningConfig[schema.name]?.value) {
return {
isValid: false,
errorMessage: t('workflow.errorMsg.toolParameterRequired', { field: renderI18nObject(param.label, language), param: renderI18nObject(schema.label, language) }),
}
}
if (schema.form === 'llm' && version && reasoningConfig[schema.name].auto === 0 && !reasoningConfig[schema.name]?.value.value) {
if (schema.form === 'llm' && mergeVersion && reasoningConfig[schema.name].auto === 0 && !reasoningConfig[schema.name]?.value.value) {
return {
isValid: false,
errorMessage: t('workflow.errorMsg.toolParameterRequired', { field: renderI18nObject(param.label, language), param: renderI18nObject(schema.label, language) }),

View File

@ -12,6 +12,7 @@ export type AgentNodeType = CommonNodeType & {
plugin_unique_identifier?: string
memory?: Memory
version?: string
tool_node_version?: string
}
export enum AgentFeature {

View File

@ -129,7 +129,7 @@ const useConfig = (id: string, payload: AgentNodeType) => {
}
const formattingLegacyData = () => {
if (inputs.version)
if (inputs.version || inputs.tool_node_version)
return inputs
const newData = produce(inputs, (draft) => {
const schemas = currentStrategy?.parameters || []
@ -140,7 +140,7 @@ const useConfig = (id: string, payload: AgentNodeType) => {
if (targetSchema?.type === FormTypeEnum.multiToolSelector)
draft.agent_parameters![key].value = draft.agent_parameters![key].value.map((tool: any) => formattingToolData(tool))
})
draft.version = '2'
draft.tool_node_version = '2'
})
return newData
}
@ -151,7 +151,6 @@ const useConfig = (id: string, payload: AgentNodeType) => {
return
const newData = formattingLegacyData()
setInputs(newData)
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [currentStrategy])
// vars

View File

@ -16,7 +16,7 @@ const nodeDefault: NodeDefault<ToolNodeType> = {
defaultValue: {
tool_parameters: {},
tool_configurations: {},
version: '2',
tool_node_version: '2',
},
checkValid(payload: ToolNodeType, t: any, moreDataForCheckValid: any) {
const { toolInputsSchema, toolSettingSchema, language, notAuthed } = moreDataForCheckValid

View File

@ -23,4 +23,5 @@ export type ToolNodeType = CommonNodeType & {
output_schema: Record<string, any>
paramSchemas?: Record<string, any>[]
version?: string
tool_node_version?: string
}

View File

@ -286,8 +286,8 @@ export const initialNodes = (originNodes: Node[], originEdges: Edge[]) => {
}
}
if (node.data.type === BlockEnum.Tool && !(node as Node<ToolNodeType>).data.version) {
(node as Node<ToolNodeType>).data.version = '2'
if (node.data.type === BlockEnum.Tool && !(node as Node<ToolNodeType>).data.version && !(node as Node<ToolNodeType>).data.tool_node_version) {
(node as Node<ToolNodeType>).data.tool_node_version = '2'
const toolConfigurations = (node as Node<ToolNodeType>).data.tool_configurations
if (toolConfigurations && Object.keys(toolConfigurations).length > 0) {

View File

@ -1,10 +1,10 @@
import RoutePrefixHandle from './routePrefixHandle'
import type { Viewport } from 'next'
import I18nServer from './components/i18n-server'
import BrowserInitor from './components/browser-initor'
import SentryInitor from './components/sentry-initor'
import BrowserInitializer from './components/browser-initializer'
import SentryInitializer from './components/sentry-initializer'
import { getLocaleOnServer } from '@/i18n/server'
import { TanstackQueryIniter } from '@/context/query-client'
import { TanstackQueryInitializer } from '@/context/query-client'
import { ThemeProvider } from 'next-themes'
import './styles/globals.css'
import './styles/markdown.scss'
@ -62,9 +62,9 @@ const LocaleLayout = async ({
className="color-scheme h-full select-auto"
{...datasetMap}
>
<BrowserInitor>
<SentryInitor>
<TanstackQueryIniter>
<BrowserInitializer>
<SentryInitializer>
<TanstackQueryInitializer>
<ThemeProvider
attribute='data-theme'
defaultTheme='system'
@ -77,9 +77,9 @@ const LocaleLayout = async ({
</GlobalPublicStoreProvider>
</I18nServer>
</ThemeProvider>
</TanstackQueryIniter>
</SentryInitor>
</BrowserInitor>
</TanstackQueryInitializer>
</SentryInitializer>
</BrowserInitializer>
<RoutePrefixHandle />
</body>
</html>

View File

@ -1,20 +1,15 @@
'use client'
import { createRef, useCallback, useEffect, useMemo, useRef, useState } from 'react'
import { useCallback, useEffect, useMemo, useState } from 'react'
import useSWR from 'swr'
import { createContext, useContext, useContextSelector } from 'use-context-selector'
import type { FC, ReactNode } from 'react'
import { fetchAppList } from '@/service/apps'
import Loading from '@/app/components/base/loading'
import { fetchCurrentWorkspace, fetchLanggeniusVersion, fetchUserProfile } from '@/service/common'
import type { App } from '@/types/app'
import { fetchCurrentWorkspace, fetchLangGeniusVersion, fetchUserProfile } from '@/service/common'
import type { ICurrentWorkspace, LangGeniusVersionResponse, UserProfileResponse } from '@/models/common'
import MaintenanceNotice from '@/app/components/header/maintenance-notice'
import { noop } from 'lodash-es'
export type AppContextValue = {
apps: App[]
mutateApps: VoidFunction
userProfile: UserProfileResponse
mutateUserProfile: VoidFunction
currentWorkspace: ICurrentWorkspace
@ -23,13 +18,21 @@ export type AppContextValue = {
isCurrentWorkspaceEditor: boolean
isCurrentWorkspaceDatasetOperator: boolean
mutateCurrentWorkspace: VoidFunction
pageContainerRef: React.RefObject<HTMLDivElement>
langeniusVersionInfo: LangGeniusVersionResponse
langGeniusVersionInfo: LangGeniusVersionResponse
useSelector: typeof useSelector
isLoadingCurrentWorkspace: boolean
}
const initialLangeniusVersionInfo = {
const userProfilePlaceholder = {
id: '',
name: '',
email: '',
avatar: '',
avatar_url: '',
is_password_set: false,
}
const initialLangGeniusVersionInfo = {
current_env: '',
current_version: '',
latest_version: '',
@ -50,16 +53,7 @@ const initialWorkspaceInfo: ICurrentWorkspace = {
}
const AppContext = createContext<AppContextValue>({
apps: [],
mutateApps: noop,
userProfile: {
id: '',
name: '',
email: '',
avatar: '',
avatar_url: '',
is_password_set: false,
},
userProfile: userProfilePlaceholder,
currentWorkspace: initialWorkspaceInfo,
isCurrentWorkspaceManager: false,
isCurrentWorkspaceOwner: false,
@ -67,8 +61,7 @@ const AppContext = createContext<AppContextValue>({
isCurrentWorkspaceDatasetOperator: false,
mutateUserProfile: noop,
mutateCurrentWorkspace: noop,
pageContainerRef: createRef(),
langeniusVersionInfo: initialLangeniusVersionInfo,
langGeniusVersionInfo: initialLangGeniusVersionInfo,
useSelector,
isLoadingCurrentWorkspace: false,
})
@ -82,14 +75,11 @@ export type AppContextProviderProps = {
}
export const AppContextProvider: FC<AppContextProviderProps> = ({ children }) => {
const pageContainerRef = useRef<HTMLDivElement>(null)
const { data: appList, mutate: mutateApps } = useSWR({ url: '/apps', params: { page: 1, limit: 30, name: '' } }, fetchAppList)
const { data: userProfileResponse, mutate: mutateUserProfile } = useSWR({ url: '/account/profile', params: {} }, fetchUserProfile)
const { data: currentWorkspaceResponse, mutate: mutateCurrentWorkspace, isLoading: isLoadingCurrentWorkspace } = useSWR({ url: '/workspaces/current', params: {} }, fetchCurrentWorkspace)
const [userProfile, setUserProfile] = useState<UserProfileResponse>()
const [langeniusVersionInfo, setLangeniusVersionInfo] = useState<LangGeniusVersionResponse>(initialLangeniusVersionInfo)
const [userProfile, setUserProfile] = useState<UserProfileResponse>(userProfilePlaceholder)
const [langGeniusVersionInfo, setLangGeniusVersionInfo] = useState<LangGeniusVersionResponse>(initialLangGeniusVersionInfo)
const [currentWorkspace, setCurrentWorkspace] = useState<ICurrentWorkspace>(initialWorkspaceInfo)
const isCurrentWorkspaceManager = useMemo(() => ['owner', 'admin'].includes(currentWorkspace.role), [currentWorkspace.role])
const isCurrentWorkspaceOwner = useMemo(() => currentWorkspace.role === 'owner', [currentWorkspace.role])
@ -101,8 +91,8 @@ export const AppContextProvider: FC<AppContextProviderProps> = ({ children }) =>
setUserProfile(result)
const current_version = userProfileResponse.headers.get('x-version')
const current_env = process.env.NODE_ENV === 'development' ? 'DEVELOPMENT' : userProfileResponse.headers.get('x-env')
const versionData = await fetchLanggeniusVersion({ url: '/version', params: { current_version } })
setLangeniusVersionInfo({ ...versionData, current_version, latest_version: versionData.version, current_env })
const versionData = await fetchLangGeniusVersion({ url: '/version', params: { current_version } })
setLangGeniusVersionInfo({ ...versionData, current_version, latest_version: versionData.version, current_env })
}
}, [userProfileResponse])
@ -115,17 +105,11 @@ export const AppContextProvider: FC<AppContextProviderProps> = ({ children }) =>
setCurrentWorkspace(currentWorkspaceResponse)
}, [currentWorkspaceResponse])
if (!appList || !userProfile)
return <Loading type='app' />
return (
<AppContext.Provider value={{
apps: appList.data,
mutateApps,
userProfile,
mutateUserProfile,
pageContainerRef,
langeniusVersionInfo,
langGeniusVersionInfo,
useSelector,
currentWorkspace,
isCurrentWorkspaceManager,
@ -137,7 +121,7 @@ export const AppContextProvider: FC<AppContextProviderProps> = ({ children }) =>
}}>
<div className='flex h-full flex-col overflow-y-auto'>
{globalThis.document?.body?.getAttribute('data-public-maintenance-notice') && <MaintenanceNotice />}
<div ref={pageContainerRef} className='relative flex grow flex-col overflow-y-auto overflow-x-hidden bg-background-body'>
<div className='relative flex grow flex-col overflow-y-auto overflow-x-hidden bg-background-body'>
{children}
</div>
</div>

View File

@ -14,7 +14,7 @@ const client = new QueryClient({
},
})
export const TanstackQueryIniter: FC<PropsWithChildren> = (props) => {
export const TanstackQueryInitializer: FC<PropsWithChildren> = (props) => {
const { children } = props
return <QueryClientProvider client={client}>
{children}

View File

@ -8,6 +8,7 @@ import storybook from 'eslint-plugin-storybook'
import tailwind from 'eslint-plugin-tailwindcss'
import reactHooks from 'eslint-plugin-react-hooks'
import sonar from 'eslint-plugin-sonarjs'
import oxlint from 'eslint-plugin-oxlint'
// import reactRefresh from 'eslint-plugin-react-refresh'
@ -245,4 +246,5 @@ export default combine(
'tailwindcss/migration-from-tailwind-2': 'warn',
},
},
oxlint.configs['flat/recommended'],
)

View File

@ -78,7 +78,7 @@ const translation = {
optional: 'Wahlfrei',
noTemplateFound: 'Keine Vorlagen gefunden',
workflowUserDescription: 'Autonome KI-Arbeitsabläufe visuell per Drag-and-Drop erstellen.',
foundResults: '{{Anzahl}} Befund',
foundResults: '{{count}} Befund',
chatbotShortDescription: 'LLM-basierter Chatbot mit einfacher Einrichtung',
completionUserDescription: 'Erstellen Sie schnell einen KI-Assistenten für Textgenerierungsaufgaben mit einfacher Konfiguration.',
noAppsFound: 'Keine Apps gefunden',
@ -92,7 +92,7 @@ const translation = {
noTemplateFoundTip: 'Versuchen Sie, mit verschiedenen Schlüsselwörtern zu suchen.',
advancedUserDescription: 'Workflow mit Speicherfunktionen und Chatbot-Oberfläche.',
chatbotUserDescription: 'Erstellen Sie schnell einen LLM-basierten Chatbot mit einfacher Konfiguration. Sie können später zu Chatflow wechseln.',
foundResult: '{{Anzahl}} Ergebnis',
foundResult: '{{count}} Ergebnis',
agentUserDescription: 'Ein intelligenter Agent, der in der Lage ist, iteratives Denken zu führen und autonome Werkzeuge zu verwenden, um Aufgabenziele zu erreichen.',
agentShortDescription: 'Intelligenter Agent mit logischem Denken und autonomer Werkzeugnutzung',
dropDSLToCreateApp: 'Ziehen Sie die DSL-Datei hierher, um die App zu erstellen',

Some files were not shown because too many files have changed in this diff Show More