mirror of https://github.com/langgenius/dify.git

commit 4d36e784b7
merge main
@@ -1,6 +1,6 @@
 #!/bin/bash

-npm add -g pnpm@10.11.1
+npm add -g pnpm@10.13.1
 cd web && pnpm install
 pipx install uv
@@ -12,3 +12,4 @@ echo 'alias start-containers="cd /workspaces/dify/docker && docker-compose -f do
 echo 'alias stop-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env down"' >> ~/.bashrc

+source /home/vscode/.bashrc
@@ -28,7 +28,7 @@ jobs:
 - name: Check changed files
 id: changed-files
-uses: tj-actions/changed-files@v45
+uses: tj-actions/changed-files@v46
 with:
 files: |
 api/**

@@ -75,7 +75,7 @@ jobs:
 - name: Check changed files
 id: changed-files
-uses: tj-actions/changed-files@v45
+uses: tj-actions/changed-files@v46
 with:
 files: web/**

@@ -113,7 +113,7 @@ jobs:
 - name: Check changed files
 id: changed-files
-uses: tj-actions/changed-files@v45
+uses: tj-actions/changed-files@v46
 with:
 files: |
 docker/generate_docker_compose

@@ -144,7 +144,7 @@ jobs:
 - name: Check changed files
 id: changed-files
-uses: tj-actions/changed-files@v45
+uses: tj-actions/changed-files@v46
 with:
 files: |
 **.sh

@@ -152,13 +152,15 @@ jobs:
 **.yml
 **Dockerfile
 dev/**
 .editorconfig

 - name: Super-linter
-uses: super-linter/super-linter/slim@v7
+uses: super-linter/super-linter/slim@v8
 if: steps.changed-files.outputs.any_changed == 'true'
 env:
 BASH_SEVERITY: warning
-DEFAULT_BRANCH: main
+DEFAULT_BRANCH: origin/main
 EDITORCONFIG_FILE_NAME: editorconfig-checker.json
 FILTER_REGEX_INCLUDE: pnpm-lock.yaml
 GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
 IGNORE_GENERATED_FILES: true

@@ -168,16 +170,6 @@ jobs:
 # FIXME: temporarily disabled until api-docker.yaml's run script is fixed for shellcheck
 # VALIDATE_GITHUB_ACTIONS: true
 VALIDATE_DOCKERFILE_HADOLINT: true
 VALIDATE_EDITORCONFIG: true
 VALIDATE_XML: true
 VALIDATE_YAML: true

-- name: EditorConfig checks
-uses: super-linter/super-linter/slim@v7
-env:
-DEFAULT_BRANCH: main
-GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
-IGNORE_GENERATED_FILES: true
-IGNORE_GITIGNORED_FILES: true
-# EditorConfig validation
-VALIDATE_EDITORCONFIG: true
-EDITORCONFIG_FILE_NAME: editorconfig-checker.json
@@ -27,7 +27,7 @@ jobs:
 - name: Check changed files
 id: changed-files
-uses: tj-actions/changed-files@v45
+uses: tj-actions/changed-files@v46
 with:
 files: web/**
@@ -142,8 +142,10 @@ WEB_API_CORS_ALLOW_ORIGINS=http://localhost:3000,*
 CONSOLE_CORS_ALLOW_ORIGINS=http://localhost:3000,*

 # Vector database configuration
-# support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, couchbase, vikingdb, upstash, lindorm, oceanbase, opengauss, tablestore, matrixone
+# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`.
 VECTOR_STORE=weaviate
+# Prefix used to create collection name in vector database
+VECTOR_INDEX_NAME_PREFIX=Vector_index

 # Weaviate configuration
 WEAVIATE_ENDPOINT=http://localhost:8080
@@ -85,6 +85,11 @@ class VectorStoreConfig(BaseSettings):
 default=False,
 )

+VECTOR_INDEX_NAME_PREFIX: Optional[str] = Field(
+description="Prefix used to create collection name in vector database",
+default="Vector_index",
+)
+

 class KeywordStoreConfig(BaseSettings):
 KEYWORD_STORE: str = Field(
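The new VECTOR_INDEX_NAME_PREFIX setting is consumed by Dataset.gen_collection_name_by_id further down in this commit; a minimal sketch of the resulting collection name, assuming the default prefix and a hypothetical dataset id:

    dataset_id = "1b2c3d4e-5f60-7a8b-9c0d-e1f203a4b5c6"  # hypothetical
    prefix = "Vector_index"  # default VECTOR_INDEX_NAME_PREFIX
    print(f"{prefix}_{dataset_id.replace('-', '_')}_Node")
    # Vector_index_1b2c3d4e_5f60_7a8b_9c0d_e1f203a4b5c6_Node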
@@ -1,4 +1,4 @@
-from datetime import UTC, datetime
+from datetime import datetime

 import pytz  # pip install pytz
 from flask_login import current_user

@@ -19,6 +19,7 @@ from fields.conversation_fields import (
 conversation_pagination_fields,
 conversation_with_summary_pagination_fields,
 )
+from libs.datetime_utils import naive_utc_now
 from libs.helper import DatetimeString
 from libs.login import login_required
 from models import Conversation, EndUser, Message, MessageAnnotation

@@ -315,7 +316,7 @@ def _get_conversation(app_model, conversation_id):
 raise NotFound("Conversation Not Exists.")

 if not conversation.read_at:
-conversation.read_at = datetime.now(UTC).replace(tzinfo=None)
+conversation.read_at = naive_utc_now()
 conversation.read_account_id = current_user.id
 db.session.commit()
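Most hunks in this commit swap datetime.now(UTC).replace(tzinfo=None) for naive_utc_now(). The helper itself is not part of this diff; presumably libs/datetime_utils.naive_utc_now is equivalent to the expression it replaces, roughly:

    from datetime import UTC, datetime

    def naive_utc_now() -> datetime:
        # Current UTC time as a naive datetime (tzinfo stripped), matching
        # the inline datetime.now(UTC).replace(tzinfo=None) it replaces.
        return datetime.now(UTC).replace(tzinfo=None)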
@@ -1,5 +1,3 @@
-from datetime import UTC, datetime
-
 from flask_login import current_user
 from flask_restful import Resource, marshal_with, reqparse
 from werkzeug.exceptions import Forbidden, NotFound

@@ -10,6 +8,7 @@ from controllers.console.app.wraps import get_app_model
 from controllers.console.wraps import account_initialization_required, setup_required
 from extensions.ext_database import db
 from fields.app_fields import app_site_fields
+from libs.datetime_utils import naive_utc_now
 from libs.login import login_required
 from models import Site

@@ -77,7 +76,7 @@ class AppSite(Resource):
 setattr(site, attr_name, value)

 site.updated_by = current_user.id
-site.updated_at = datetime.now(UTC).replace(tzinfo=None)
+site.updated_at = naive_utc_now()
 db.session.commit()

 return site

@@ -101,7 +100,7 @@ class AppSiteAccessTokenReset(Resource):
 site.code = Site.generate_code(16)
 site.updated_by = current_user.id
-site.updated_at = datetime.now(UTC).replace(tzinfo=None)
+site.updated_at = naive_utc_now()
 db.session.commit()

 return site
@@ -1,5 +1,3 @@
-import datetime
-
 from flask import request
 from flask_restful import Resource, reqparse

@@ -7,6 +5,7 @@ from constants.languages import supported_language
 from controllers.console import api
 from controllers.console.error import AlreadyActivateError
 from extensions.ext_database import db
+from libs.datetime_utils import naive_utc_now
 from libs.helper import StrLen, email, extract_remote_ip, timezone
 from models.account import AccountStatus
 from services.account_service import AccountService, RegisterService

@@ -65,7 +64,7 @@ class ActivateApi(Resource):
 account.timezone = args["timezone"]
 account.interface_theme = "light"
 account.status = AccountStatus.ACTIVE.value
-account.initialized_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+account.initialized_at = naive_utc_now()
 db.session.commit()

 token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
@@ -1,5 +1,4 @@
 import logging
-from datetime import UTC, datetime
 from typing import Optional

 import requests

@@ -13,6 +12,7 @@ from configs import dify_config
 from constants.languages import languages
 from events.tenant_event import tenant_was_created
 from extensions.ext_database import db
+from libs.datetime_utils import naive_utc_now
 from libs.helper import extract_remote_ip
 from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
 from models import Account

@@ -110,7 +110,7 @@ class OAuthCallback(Resource):

 if account.status == AccountStatus.PENDING.value:
 account.status = AccountStatus.ACTIVE.value
-account.initialized_at = datetime.now(UTC).replace(tzinfo=None)
+account.initialized_at = naive_utc_now()
 db.session.commit()

 try:
@@ -1,4 +1,3 @@
-import datetime
 import json

 from flask import request

@@ -15,6 +14,7 @@ from core.rag.extractor.entity.extract_setting import ExtractSetting
 from core.rag.extractor.notion_extractor import NotionExtractor
 from extensions.ext_database import db
 from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields
+from libs.datetime_utils import naive_utc_now
 from libs.login import login_required
 from models import DataSourceOauthBinding, Document
 from services.dataset_service import DatasetService, DocumentService

@@ -88,7 +88,7 @@ class DataSourceApi(Resource):
 if action == "enable":
 if data_source_binding.disabled:
 data_source_binding.disabled = False
-data_source_binding.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+data_source_binding.updated_at = naive_utc_now()
 db.session.add(data_source_binding)
 db.session.commit()
 else:

@@ -97,7 +97,7 @@ class DataSourceApi(Resource):
 if action == "disable":
 if not data_source_binding.disabled:
 data_source_binding.disabled = True
-data_source_binding.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+data_source_binding.updated_at = naive_utc_now()
 db.session.add(data_source_binding)
 db.session.commit()
 else:
@@ -1,7 +1,6 @@
 import json
 import logging
 from argparse import ArgumentTypeError
-from datetime import UTC, datetime
 from typing import cast

 from flask import request

@@ -50,6 +49,7 @@ from fields.document_fields import (
 document_status_fields,
 document_with_segments_fields,
 )
+from libs.datetime_utils import naive_utc_now
 from libs.login import login_required
 from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile
 from models.dataset import DocumentPipelineExecutionLog

@@ -752,7 +752,7 @@ class DocumentProcessingApi(DocumentResource):
 raise InvalidActionError("Document not in indexing state.")

 document.paused_by = current_user.id
-document.paused_at = datetime.now(UTC).replace(tzinfo=None)
+document.paused_at = naive_utc_now()
 document.is_paused = True
 db.session.commit()

@@ -832,7 +832,7 @@ class DocumentMetadataApi(DocumentResource):
 document.doc_metadata[key] = value

 document.doc_type = doc_type
-document.updated_at = datetime.now(UTC).replace(tzinfo=None)
+document.updated_at = naive_utc_now()
 db.session.commit()

 return {"result": "success", "message": "Document metadata updated."}, 200
@@ -1,5 +1,4 @@
 import logging
-from datetime import UTC, datetime

 from flask_login import current_user
 from flask_restful import reqparse

@@ -27,6 +26,7 @@ from core.errors.error import (
 from core.model_runtime.errors.invoke import InvokeError
 from extensions.ext_database import db
 from libs import helper
+from libs.datetime_utils import naive_utc_now
 from libs.helper import uuid_value
 from models.model import AppMode
 from services.app_generate_service import AppGenerateService

@@ -51,7 +51,7 @@ class CompletionApi(InstalledAppResource):
 streaming = args["response_mode"] == "streaming"
 args["auto_generate_name"] = False

-installed_app.last_used_at = datetime.now(UTC).replace(tzinfo=None)
+installed_app.last_used_at = naive_utc_now()
 db.session.commit()

 try:

@@ -111,7 +111,7 @@ class ChatApi(InstalledAppResource):

 args["auto_generate_name"] = False

-installed_app.last_used_at = datetime.now(UTC).replace(tzinfo=None)
+installed_app.last_used_at = naive_utc_now()
 db.session.commit()

 try:
@@ -1,5 +1,4 @@
 import logging
-from datetime import UTC, datetime
 from typing import Any

 from flask import request

@@ -13,6 +12,7 @@ from controllers.console.explore.wraps import InstalledAppResource
 from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
 from extensions.ext_database import db
 from fields.installed_app_fields import installed_app_list_fields
+from libs.datetime_utils import naive_utc_now
 from libs.login import login_required
 from models import App, InstalledApp, RecommendedApp
 from services.account_service import TenantService

@@ -122,7 +122,7 @@ class InstalledAppsListApi(Resource):
 tenant_id=current_tenant_id,
 app_owner_tenant_id=app.tenant_id,
 is_pinned=False,
-last_used_at=datetime.now(UTC).replace(tzinfo=None),
+last_used_at=naive_utc_now(),
 )
 db.session.add(new_installed_app)
 db.session.commit()
@@ -1,5 +1,3 @@
-import datetime
-
 import pytz
 from flask import request
 from flask_login import current_user

@@ -35,6 +33,7 @@ from controllers.console.wraps import (
 )
 from extensions.ext_database import db
 from fields.member_fields import account_fields
+from libs.datetime_utils import naive_utc_now
 from libs.helper import TimestampField, email, extract_remote_ip, timezone
 from libs.login import login_required
 from models import AccountIntegrate, InvitationCode

@@ -80,7 +79,7 @@ class AccountInitApi(Resource):
 raise InvalidInvitationCodeError()

 invitation_code.status = "used"
-invitation_code.used_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+invitation_code.used_at = naive_utc_now()
 invitation_code.used_by_tenant_id = account.current_tenant_id
 invitation_code.used_by_account_id = account.id

@@ -88,7 +87,7 @@ class AccountInitApi(Resource):
 account.timezone = args["timezone"]
 account.interface_theme = "light"
 account.status = "active"
-account.initialized_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+account.initialized_at = naive_utc_now()
 db.session.commit()

 return {"result": "success"}
@@ -29,7 +29,7 @@ from libs.login import login_required
 from services.plugin.oauth_service import OAuthProxyService
 from services.tools.api_tools_manage_service import ApiToolManageService
 from services.tools.builtin_tools_manage_service import BuiltinToolManageService
-from services.tools.mcp_tools_mange_service import MCPToolManageService
+from services.tools.mcp_tools_manage_service import MCPToolManageService
 from services.tools.tool_labels_service import ToolLabelsService
 from services.tools.tools_manage_service import ToolCommonService
 from services.tools.tools_transform_service import ToolTransformService
@@ -1,6 +1,6 @@
 import time
 from collections.abc import Callable
-from datetime import UTC, datetime, timedelta
+from datetime import timedelta
 from enum import Enum
 from functools import wraps
 from typing import Optional

@@ -15,6 +15,7 @@ from werkzeug.exceptions import Forbidden, NotFound, Unauthorized

 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
+from libs.datetime_utils import naive_utc_now
 from libs.login import _get_user
 from models.account import Account, Tenant, TenantAccountJoin, TenantStatus
 from models.dataset import Dataset, RateLimitLog

@@ -256,7 +257,7 @@ def validate_and_get_api_token(scope: str | None = None):
 if auth_scheme != "bearer":
 raise Unauthorized("Authorization scheme must be 'Bearer'")

-current_time = datetime.now(UTC).replace(tzinfo=None)
+current_time = naive_utc_now()
 cutoff_time = current_time - timedelta(minutes=1)
 with Session(db.engine, expire_on_commit=False) as session:
 update_stmt = (
@@ -1,7 +1,6 @@
 import json
 import logging
 from collections.abc import Generator
-from datetime import UTC, datetime
 from typing import Optional, Union, cast

 from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppModelConfigFrom

@@ -25,6 +24,7 @@ from core.app.entities.task_entities import (
 from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline
 from core.prompt.utils.prompt_template_parser import PromptTemplateParser
 from extensions.ext_database import db
+from libs.datetime_utils import naive_utc_now
 from models import Account
 from models.enums import CreatorUserRole
 from models.model import App, AppMode, AppModelConfig, Conversation, EndUser, Message, MessageFile

@@ -184,7 +184,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
 db.session.commit()
 db.session.refresh(conversation)
 else:
-conversation.updated_at = datetime.now(UTC).replace(tzinfo=None)
+conversation.updated_at = naive_utc_now()
 db.session.commit()

 message = Message(
@@ -7,6 +7,7 @@ from core.model_runtime.entities import (
 AudioPromptMessageContent,
 DocumentPromptMessageContent,
 ImagePromptMessageContent,
+TextPromptMessageContent,
 VideoPromptMessageContent,
 )
 from core.model_runtime.entities.message_entities import PromptMessageContentUnionTypes

@@ -44,11 +45,44 @@ def to_prompt_message_content(
 *,
 image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
 ) -> PromptMessageContentUnionTypes:
+"""
+Convert a file to prompt message content.
+
+This function converts files to their appropriate prompt message content types.
+For supported file types (IMAGE, AUDIO, VIDEO, DOCUMENT), it creates the
+corresponding message content with proper encoding/URL.
+
+For unsupported file types, instead of raising an error, it returns a
+TextPromptMessageContent with a descriptive message about the file.
+
+Args:
+    f: The file to convert
+    image_detail_config: Optional detail configuration for image files
+
+Returns:
+    PromptMessageContentUnionTypes: The appropriate message content type
+
+Raises:
+    ValueError: If file extension or mime_type is missing
+"""
 if f.extension is None:
 raise ValueError("Missing file extension")
 if f.mime_type is None:
 raise ValueError("Missing file mime_type")

+prompt_class_map: Mapping[FileType, type[PromptMessageContentUnionTypes]] = {
+FileType.IMAGE: ImagePromptMessageContent,
+FileType.AUDIO: AudioPromptMessageContent,
+FileType.VIDEO: VideoPromptMessageContent,
+FileType.DOCUMENT: DocumentPromptMessageContent,
+}
+
+# Check if file type is supported
+if f.type not in prompt_class_map:
+# For unsupported file types, return a text description
+return TextPromptMessageContent(data=f"[Unsupported file type: {f.filename} ({f.type.value})]")
+
+# Process supported file types
 params = {
 "base64_data": _get_encoded_string(f) if dify_config.MULTIMODAL_SEND_FORMAT == "base64" else "",
 "url": _to_url(f) if dify_config.MULTIMODAL_SEND_FORMAT == "url" else "",

@@ -58,17 +92,7 @@ def to_prompt_message_content(
 if f.type == FileType.IMAGE:
 params["detail"] = image_detail_config or ImagePromptMessageContent.DETAIL.LOW

-prompt_class_map: Mapping[FileType, type[PromptMessageContentUnionTypes]] = {
-FileType.IMAGE: ImagePromptMessageContent,
-FileType.AUDIO: AudioPromptMessageContent,
-FileType.VIDEO: VideoPromptMessageContent,
-FileType.DOCUMENT: DocumentPromptMessageContent,
-}
-
-try:
-return prompt_class_map[f.type].model_validate(params)
-except KeyError:
-raise ValueError(f"file type {f.type} is not supported")
+return prompt_class_map[f.type].model_validate(params)


 def download(f: File, /):
@@ -8,7 +8,7 @@ from core.mcp.types import (
 OAuthTokens,
 )
 from models.tools import MCPToolProvider
-from services.tools.mcp_tools_mange_service import MCPToolManageService
+from services.tools.mcp_tools_manage_service import MCPToolManageService

 LATEST_PROTOCOL_VERSION = "1.0"
@@ -68,15 +68,17 @@ class MCPClient:
 }

 parsed_url = urlparse(self.server_url)
-path = parsed_url.path
-method_name = path.rstrip("/").split("/")[-1] if path else ""
-try:
+path = parsed_url.path or ""
+method_name = path.removesuffix("/").lower()
+if method_name in connection_methods:
 client_factory = connection_methods[method_name]
 self.connect_server(client_factory, method_name)
-except KeyError:
+else:
+try:
 logger.debug(f"Not supported method {method_name} found in URL path, trying default 'mcp' method.")
 self.connect_server(sse_client, "sse")
+except MCPConnectionError:
+logger.debug("MCP connection failed with 'sse', falling back to 'mcp' method.")
 self.connect_server(streamablehttp_client, "mcp")

 def connect_server(

@@ -91,7 +93,7 @@ class MCPClient:
 else {}
 )
 self._streams_context = client_factory(url=self.server_url, headers=headers)
-if self._streams_context is None:
+if not self._streams_context:
 raise MCPConnectionError("Failed to create connection context")

 # Use exit_stack to manage context managers properly

@@ -141,10 +143,11 @@ class MCPClient:
 try:
 # ExitStack will handle proper cleanup of all managed context managers
 self.exit_stack.close()
-except Exception as e:
-logging.exception("Error during cleanup")
-raise ValueError(f"Error during cleanup: {e}")
 finally:
 self._session = None
 self._session_context = None
 self._streams_context = None
 self._initialized = False
+except Exception as e:
+logging.exception("Error during cleanup")
+raise ValueError(f"Error during cleanup: {e}")
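The dispatch rewrite above also changes what method_name holds; a short illustration of the stdlib difference (path value hypothetical):

    path = "/server/mcp/"
    old = path.rstrip("/").split("/")[-1]  # 'mcp' -- last path segment only
    new = path.removesuffix("/").lower()   # '/server/mcp' -- whole path, at most one trailing slash removed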
@@ -3,7 +3,7 @@ import json
 import logging
 import os
 from datetime import datetime, timedelta
-from typing import Optional, Union, cast
+from typing import Any, Optional, Union, cast

 from openinference.semconv.trace import OpenInferenceSpanKindValues, SpanAttributes
 from opentelemetry import trace

@@ -142,11 +142,8 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
 raise

 def workflow_trace(self, trace_info: WorkflowTraceInfo):
-if trace_info.message_data is None:
-return
-
 workflow_metadata = {
 "workflow_id": trace_info.workflow_run_id or "",
 "workflow_run_id": trace_info.workflow_run_id or "",
 "message_id": trace_info.message_id or "",
 "workflow_app_log_id": trace_info.workflow_app_log_id or "",
 "status": trace_info.workflow_run_status or "",

@@ -156,7 +153,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
 }
 workflow_metadata.update(trace_info.metadata)

-trace_id = uuid_to_trace_id(trace_info.message_id)
+trace_id = uuid_to_trace_id(trace_info.workflow_run_id)
 span_id = RandomIdGenerator().generate_span_id()
 context = SpanContext(
 trace_id=trace_id,

@@ -213,7 +210,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
 if model:
 node_metadata["ls_model_name"] = model

-outputs = json.loads(node_execution.outputs).get("usage", {})
+outputs = json.loads(node_execution.outputs).get("usage", {}) if "outputs" in node_execution else {}
 usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
 if usage_data:
 node_metadata["total_tokens"] = usage_data.get("total_tokens", 0)

@@ -236,31 +233,34 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
 SpanAttributes.SESSION_ID: trace_info.conversation_id or "",
 },
 start_time=datetime_to_nanos(created_at),
 context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
 )

 try:
 if node_execution.node_type == "llm":
+llm_attributes: dict[str, Any] = {
+SpanAttributes.INPUT_VALUE: json.dumps(process_data.get("prompts", []), ensure_ascii=False),
+}
 provider = process_data.get("model_provider")
 model = process_data.get("model_name")
 if provider:
-node_span.set_attribute(SpanAttributes.LLM_PROVIDER, provider)
+llm_attributes[SpanAttributes.LLM_PROVIDER] = provider
 if model:
-node_span.set_attribute(SpanAttributes.LLM_MODEL_NAME, model)
-
-outputs = json.loads(node_execution.outputs).get("usage", {})
+llm_attributes[SpanAttributes.LLM_MODEL_NAME] = model
+outputs = (
+json.loads(node_execution.outputs).get("usage", {}) if "outputs" in node_execution else {}
+)
+usage_data = (
+process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
+)
 if usage_data:
-node_span.set_attribute(
-SpanAttributes.LLM_TOKEN_COUNT_TOTAL, usage_data.get("total_tokens", 0)
-)
-node_span.set_attribute(
-SpanAttributes.LLM_TOKEN_COUNT_PROMPT, usage_data.get("prompt_tokens", 0)
-)
-node_span.set_attribute(
-SpanAttributes.LLM_TOKEN_COUNT_COMPLETION, usage_data.get("completion_tokens", 0)
+llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_TOTAL] = usage_data.get("total_tokens", 0)
+llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_PROMPT] = usage_data.get("prompt_tokens", 0)
+llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_COMPLETION] = usage_data.get(
+"completion_tokens", 0
 )
+llm_attributes.update(self._construct_llm_attributes(process_data.get("prompts", [])))
+node_span.set_attributes(llm_attributes)
 finally:
 node_span.end(end_time=datetime_to_nanos(finished_at))
 finally:

@@ -352,25 +352,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
 SpanAttributes.METADATA: json.dumps(message_metadata, ensure_ascii=False),
 SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id,
 }

-if isinstance(trace_info.inputs, list):
-for i, msg in enumerate(trace_info.inputs):
-if isinstance(msg, dict):
-llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.{i}.message.content"] = msg.get("text", "")
-llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.{i}.message.role"] = msg.get(
-"role", "user"
-)
-# todo: handle assistant and tool role messages, as they don't always
-# have a text field, but may have a tool_calls field instead
-# e.g. 'tool_calls': [{'id': '98af3a29-b066-45a5-b4b1-46c74ddafc58',
-# 'type': 'function', 'function': {'name': 'current_time', 'arguments': '{}'}}]}
-elif isinstance(trace_info.inputs, dict):
-llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.content"] = json.dumps(trace_info.inputs)
-llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.role"] = "user"
-elif isinstance(trace_info.inputs, str):
-llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.content"] = trace_info.inputs
-llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.role"] = "user"
-
+llm_attributes.update(self._construct_llm_attributes(trace_info.inputs))
 if trace_info.total_tokens is not None and trace_info.total_tokens > 0:
 llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_TOTAL] = trace_info.total_tokens
 if trace_info.message_tokens is not None and trace_info.message_tokens > 0:

@@ -724,3 +706,24 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
 .all()
 )
 return workflow_nodes
+
+def _construct_llm_attributes(self, prompts: dict | list | str | None) -> dict[str, str]:
+"""Helper method to construct LLM attributes with passed prompts."""
+attributes = {}
+if isinstance(prompts, list):
+for i, msg in enumerate(prompts):
+if isinstance(msg, dict):
+attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.{i}.message.content"] = msg.get("text", "")
+attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.{i}.message.role"] = msg.get("role", "user")
+# todo: handle assistant and tool role messages, as they don't always
+# have a text field, but may have a tool_calls field instead
+# e.g. 'tool_calls': [{'id': '98af3a29-b066-45a5-b4b1-46c74ddafc58',
+# 'type': 'function', 'function': {'name': 'current_time', 'arguments': '{}'}}]}
+elif isinstance(prompts, dict):
+attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.content"] = json.dumps(prompts)
+attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.role"] = "user"
+elif isinstance(prompts, str):
+attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.content"] = prompts
+attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.role"] = "user"
+
+return attributes
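The new _construct_llm_attributes helper flattens prompts into indexed OpenInference keys. Assuming SpanAttributes.LLM_INPUT_MESSAGES resolves to "llm.input_messages" (the usual openinference value), a two-message prompt list yields attributes like this sketch:

    prompts = [{"role": "system", "text": "You are helpful."}, {"role": "user", "text": "Hi"}]
    base = "llm.input_messages"  # assumed value of SpanAttributes.LLM_INPUT_MESSAGES
    attrs = {}
    for i, msg in enumerate(prompts):
        attrs[f"{base}.{i}.message.content"] = msg.get("text", "")
        attrs[f"{base}.{i}.message.role"] = msg.get("role", "user")
    # attrs == {"llm.input_messages.0.message.content": "You are helpful.",
    #           "llm.input_messages.0.message.role": "system", ...}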
@@ -233,6 +233,12 @@ class AnalyticdbVectorOpenAPI:
 def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
 from alibabacloud_gpdb20160503 import models as gpdb_20160503_models

+document_ids_filter = kwargs.get("document_ids_filter")
+where_clause = ""
+if document_ids_filter:
+document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
+where_clause += f"metadata_->>'document_id' IN ({document_ids})"
+
 score_threshold = kwargs.get("score_threshold") or 0.0
 request = gpdb_20160503_models.QueryCollectionDataRequest(
 dbinstance_id=self.config.instance_id,

@@ -245,7 +251,7 @@ class AnalyticdbVectorOpenAPI:
 vector=query_vector,
 content=None,
 top_k=kwargs.get("top_k", 4),
-filter=None,
+filter=where_clause,
 )
 response = self._client.query_collection_data(request)
 documents = []

@@ -265,6 +271,11 @@ class AnalyticdbVectorOpenAPI:
 def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
 from alibabacloud_gpdb20160503 import models as gpdb_20160503_models

+document_ids_filter = kwargs.get("document_ids_filter")
+where_clause = ""
+if document_ids_filter:
+document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
+where_clause += f"metadata_->>'document_id' IN ({document_ids})"
 score_threshold = float(kwargs.get("score_threshold") or 0.0)
 request = gpdb_20160503_models.QueryCollectionDataRequest(
 dbinstance_id=self.config.instance_id,

@@ -277,7 +288,7 @@ class AnalyticdbVectorOpenAPI:
 vector=None,
 content=query,
 top_k=kwargs.get("top_k", 4),
-filter=None,
+filter=where_clause,
 )
 response = self._client.query_collection_data(request)
 documents = []
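For reference, the where_clause assembled above is a plain SQL fragment; for a hypothetical document_ids_filter of ["d1", "d2"] it evaluates as:

    document_ids = ", ".join(f"'{id}'" for id in ["d1", "d2"])
    where_clause = f"metadata_->>'document_id' IN ({document_ids})"
    print(where_clause)  # metadata_->>'document_id' IN ('d1', 'd2')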
@@ -147,10 +147,17 @@ class ElasticSearchVector(BaseVector):
 return docs

 def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
-query_str = {"match": {Field.CONTENT_KEY.value: query}}
+query_str: dict[str, Any] = {"match": {Field.CONTENT_KEY.value: query}}
 document_ids_filter = kwargs.get("document_ids_filter")

 if document_ids_filter:
-query_str["filter"] = {"terms": {"metadata.document_id": document_ids_filter}}  # type: ignore
+query_str = {
+"bool": {
+"must": {"match": {Field.CONTENT_KEY.value: query}},
+"filter": {"terms": {"metadata.document_id": document_ids_filter}},
+}
+}

 results = self._client.search(index=self._collection_name, query=query_str, size=kwargs.get("top_k", 4))
 docs = []
 for hit in results["hits"]["hits"]:
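With a document filter, the rewritten branch issues a bool query rather than patching the match query; a sketch of the body it builds, assuming Field.CONTENT_KEY.value is "page_content" and hypothetical values:

    query_str = {
        "bool": {
            "must": {"match": {"page_content": "what is dify"}},
            "filter": {"terms": {"metadata.document_id": ["doc-1", "doc-2"]}},
        }
    }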
@@ -102,6 +102,7 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
 splits = text.split()
 else:
 splits = text.split(separator)
+splits = [item + separator if i < len(splits) else item for i, item in enumerate(splits)]
 else:
 splits = list(text)
 splits = [s for s in splits if (s not in {"", "\n"})]
@@ -21,7 +21,7 @@ from core.tools.plugin_tool.tool import PluginTool
 from core.tools.utils.uuid_utils import is_valid_uuid
 from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
 from core.workflow.entities.variable_pool import VariablePool
-from services.tools.mcp_tools_mange_service import MCPToolManageService
+from services.tools.mcp_tools_manage_service import MCPToolManageService

 if TYPE_CHECKING:
 from core.workflow.nodes.tool.entities import ToolEntity
@@ -270,7 +270,14 @@ class AgentNode(BaseNode):
 )

 extra = tool.get("extra", {})
-runtime_variable_pool = variable_pool if self._node_data.version != "1" else None
+
+# This is an issue that caused problems before.
+# Logically, we shouldn't use the node_data.version field for judgment
+# But for backward compatibility with historical data
+# this version field judgment is still preserved here.
+runtime_variable_pool: VariablePool | None = None
+if node_data.version != "1" or node_data.tool_node_version != "1":
+runtime_variable_pool = variable_pool
 tool_runtime = ToolManager.get_agent_tool_runtime(
 self.tenant_id, self.app_id, entity, self.invoke_from, runtime_variable_pool
 )
@@ -13,6 +13,10 @@ class AgentNodeData(BaseNodeData):
 agent_strategy_name: str
 agent_strategy_label: str  # redundancy
 memory: MemoryConfig | None = None
+# The version of the tool parameter.
+# If this value is None, it indicates this is a previous version
+# and requires using the legacy parameter parsing rules.
+tool_node_version: str | None = None

 class AgentInput(BaseModel):
 value: Union[list[str], list[ToolSelector], Any]
@@ -117,7 +117,7 @@ class KnowledgeRetrievalNodeData(BaseNodeData):
 multiple_retrieval_config: Optional[MultipleRetrievalConfig] = None
 single_retrieval_config: Optional[SingleRetrievalConfig] = None
 metadata_filtering_mode: Optional[Literal["disabled", "automatic", "manual"]] = "disabled"
-metadata_model_config: ModelConfig
+metadata_model_config: Optional[ModelConfig] = None
 metadata_filtering_conditions: Optional[MetadataFilteringCondition] = None
 vision: VisionConfig = Field(default_factory=VisionConfig)
@@ -509,6 +509,8 @@ class KnowledgeRetrievalNode(BaseNode):
 # get all metadata field
 metadata_fields = db.session.query(DatasetMetadata).filter(DatasetMetadata.dataset_id.in_(dataset_ids)).all()
 all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields]
+if node_data.metadata_model_config is None:
+raise ValueError("metadata_model_config is required")
 # get metadata model instance and fetch model config
 model_instance, model_config = self.get_model_config(node_data.metadata_model_config)
 # fetch prompt messages

@@ -701,7 +703,7 @@ class KnowledgeRetrievalNode(BaseNode):
 )

 def _get_prompt_template(self, node_data: KnowledgeRetrievalNodeData, metadata_fields: list, query: str):
-model_mode = ModelMode(node_data.metadata_model_config.mode)
+model_mode = ModelMode(node_data.metadata_model_config.mode)  # type: ignore
 input_text = query

 prompt_messages: list[LLMNodeChatModelMessage] = []
@@ -75,6 +75,9 @@ NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[BaseNode]]] = {
 },
 NodeType.TOOL: {
 LATEST_VERSION: ToolNode,
+# This is an issue that caused problems before.
+# Logically, we shouldn't use two different versions to point to the same class here,
+# but in order to maintain compatibility with historical data, this approach has been retained.
+"2": ToolNode,
 "1": ToolNode,
 },

@@ -125,6 +128,9 @@ NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[BaseNode]]] = {
 },
 NodeType.AGENT: {
 LATEST_VERSION: AgentNode,
+# This is an issue that caused problems before.
+# Logically, we shouldn't use two different versions to point to the same class here,
+# but in order to maintain compatibility with historical data, this approach has been retained.
+"2": AgentNode,
 "1": AgentNode,
 },
@@ -59,6 +59,10 @@ class ToolNodeData(BaseNodeData, ToolEntity):
 return typ

 tool_parameters: dict[str, ToolInput]
+# The version of the tool parameter.
+# If this value is None, it indicates this is a previous version
+# and requires using the legacy parameter parsing rules.
+tool_node_version: str | None = None

 @field_validator("tool_parameters", mode="before")
 @classmethod
@@ -70,7 +70,13 @@ class ToolNode(BaseNode):
 try:
 from core.tools.tool_manager import ToolManager

-variable_pool = self.graph_runtime_state.variable_pool if self._node_data.version != "1" else None
+# This is an issue that caused problems before.
+# Logically, we shouldn't use the node_data.version field for judgment
+# But for backward compatibility with historical data
+# this version field judgment is still preserved here.
+variable_pool: VariablePool | None = None
+if node_data.version != "1" or node_data.tool_node_version != "1":
+variable_pool = self.graph_runtime_state.variable_pool
 tool_runtime = ToolManager.get_workflow_tool_runtime(
 self.tenant_id, self.app_id, self.node_id, self._node_data, self.invoke_from, variable_pool
 )
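The version gate added to both ToolNode and AgentNode reduces to a single predicate; a runnable restatement of that condition (function name hypothetical):

    def use_variable_pool(version: str, tool_node_version: str | None) -> bool:
        # Only legacy nodes where BOTH fields are "1" skip the variable pool.
        return version != "1" or tool_node_version != "1"

    assert use_variable_pool("1", "1") is False
    assert use_variable_pool("2", "1") is True
    assert use_variable_pool("1", None) is True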
@@ -1,6 +1,6 @@
 from collections.abc import Mapping
 from dataclasses import dataclass
-from datetime import UTC, datetime
+from datetime import datetime
 from typing import Any, Optional, Union
 from uuid import uuid4

@@ -71,7 +71,7 @@ class WorkflowCycleManager:
 workflow_version=self._workflow_info.version,
 graph=self._workflow_info.graph_data,
 inputs=inputs,
-started_at=datetime.now(UTC).replace(tzinfo=None),
+started_at=naive_utc_now(),
 )

 return self._save_and_cache_workflow_execution(execution)

@@ -356,7 +356,7 @@ class WorkflowCycleManager:
 created_at: Optional[datetime] = None,
 ) -> WorkflowNodeExecution:
 """Create a node execution from an event."""
-now = datetime.now(UTC).replace(tzinfo=None)
+now = naive_utc_now()
 created_at = created_at or now

 metadata = {

@@ -403,7 +403,7 @@ class WorkflowCycleManager:
 handle_special_values: bool = False,
 ) -> None:
 """Update node execution with completion data."""
-finished_at = datetime.now(UTC).replace(tzinfo=None)
+finished_at = naive_utc_now()
 elapsed_time = (finished_at - event.start_at).total_seconds()

 # Process data
@@ -1,4 +1,3 @@
-import datetime
 import logging
 import time

@@ -8,6 +7,7 @@ from werkzeug.exceptions import NotFound
 from core.indexing_runner import DocumentIsPausedError, IndexingRunner
 from events.event_handlers.document_index_event import document_index_created
 from extensions.ext_database import db
+from libs.datetime_utils import naive_utc_now
 from models.dataset import Document

@@ -33,7 +33,7 @@ def handle(sender, **kwargs):
 raise NotFound("Document not found")

 document.indexing_status = "parsing"
-document.processing_started_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+document.processing_started_at = naive_utc_now()
 documents.append(document)
 db.session.add(document)
 db.session.commit()
@@ -1,5 +1,5 @@
 from collections.abc import Generator
-from datetime import UTC, datetime, timedelta
+from datetime import timedelta
 from typing import Optional

 from azure.identity import ChainedTokenCredential, DefaultAzureCredential

@@ -8,6 +8,7 @@ from azure.storage.blob import AccountSasPermissions, BlobServiceClient, Resourc
 from configs import dify_config
 from extensions.ext_redis import redis_client
 from extensions.storage.base_storage import BaseStorage
+from libs.datetime_utils import naive_utc_now


 class AzureBlobStorage(BaseStorage):

@@ -78,7 +79,7 @@ class AzureBlobStorage(BaseStorage):
 account_key=self.account_key or "",
 resource_types=ResourceTypes(service=True, container=True, object=True),
 permission=AccountSasPermissions(read=True, write=True, delete=True, list=True, add=True, create=True),
-expiry=datetime.now(UTC).replace(tzinfo=None) + timedelta(hours=1),
+expiry=naive_utc_now() + timedelta(hours=1),
 )
 redis_client.set(cache_key, sas_token, ex=3000)
 return BlobServiceClient(account_url=self.account_url or "", credential=sas_token)
@@ -149,9 +149,7 @@ def _build_from_local_file(
 if strict_type_validation and detected_file_type.value != specified_type:
 raise ValueError("Detected file type does not match the specified type. Please verify the file.")

-file_type = (
-FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM.value else detected_file_type
-)
+file_type = FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM else detected_file_type

 return File(
 id=mapping.get("id"),

@@ -200,9 +198,7 @@ def _build_from_remote_url(
 raise ValueError("Detected file type does not match the specified type. Please verify the file.")

 file_type = (
-FileType(specified_type)
-if specified_type and specified_type != FileType.CUSTOM.value
-else detected_file_type
+FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM else detected_file_type
 )

 return File(

@@ -287,9 +283,7 @@ def _build_from_tool_file(
 if strict_type_validation and specified_type and detected_file_type.value != specified_type:
 raise ValueError("Detected file type does not match the specified type. Please verify the file.")

-file_type = (
-FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM.value else detected_file_type
-)
+file_type = FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM else detected_file_type

 return File(
 id=mapping.get("id"),
@@ -1,4 +1,3 @@
-import datetime
 import urllib.parse
 from typing import Any

@@ -6,6 +5,7 @@ import requests
 from flask_login import current_user

 from extensions.ext_database import db
+from libs.datetime_utils import naive_utc_now
 from models.source import DataSourceOauthBinding

@@ -75,7 +75,7 @@ class NotionOAuth(OAuthDataSource):
 if data_source_binding:
 data_source_binding.source_info = source_info
 data_source_binding.disabled = False
-data_source_binding.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+data_source_binding.updated_at = naive_utc_now()
 db.session.commit()
 else:
 new_data_source_binding = DataSourceOauthBinding(

@@ -115,7 +115,7 @@ class NotionOAuth(OAuthDataSource):
 if data_source_binding:
 data_source_binding.source_info = source_info
 data_source_binding.disabled = False
-data_source_binding.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+data_source_binding.updated_at = naive_utc_now()
 db.session.commit()
 else:
 new_data_source_binding = DataSourceOauthBinding(

@@ -154,7 +154,7 @@ class NotionOAuth(OAuthDataSource):
 }
 data_source_binding.source_info = new_source_info
 data_source_binding.disabled = False
-data_source_binding.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+data_source_binding.updated_at = naive_utc_now()
 db.session.commit()
 else:
 raise ValueError("Data source binding not found")
@@ -0,0 +1,51 @@
+"""update models
+
+Revision ID: 1a83934ad6d1
+Revises: 71f5020c6470
+Create Date: 2025-07-21 09:35:48.774794
+
+"""
+from alembic import op
+import models as models
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision = '1a83934ad6d1'
+down_revision = '71f5020c6470'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('tool_mcp_providers', schema=None) as batch_op:
+        batch_op.alter_column('server_identifier',
+               existing_type=sa.VARCHAR(length=24),
+               type_=sa.String(length=64),
+               existing_nullable=False)
+
+    with op.batch_alter_table('tool_model_invokes', schema=None) as batch_op:
+        batch_op.alter_column('tool_name',
+               existing_type=sa.VARCHAR(length=40),
+               type_=sa.String(length=128),
+               existing_nullable=False)
+
+    # ### end Alembic commands ###
+
+
+def downgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('tool_model_invokes', schema=None) as batch_op:
+        batch_op.alter_column('tool_name',
+               existing_type=sa.String(length=128),
+               type_=sa.VARCHAR(length=40),
+               existing_nullable=False)
+
+    with op.batch_alter_table('tool_mcp_providers', schema=None) as batch_op:
+        batch_op.alter_column('server_identifier',
+               existing_type=sa.String(length=64),
+               type_=sa.VARCHAR(length=24),
+               existing_nullable=False)
+
+    # ### end Alembic commands ###
@@ -287,7 +287,7 @@ class Dataset(Base):
 @staticmethod
 def gen_collection_name_by_id(dataset_id: str) -> str:
 normalized_dataset_id = dataset_id.replace("-", "_")
-return f"Vector_index_{normalized_dataset_id}_Node"
+return f"{dify_config.VECTOR_INDEX_NAME_PREFIX}_{normalized_dataset_id}_Node"


 class DatasetProcessRule(Base):
@@ -1,7 +1,6 @@
-from datetime import UTC, datetime
-
 from celery import states  # type: ignore

+from libs.datetime_utils import naive_utc_now
 from models.base import Base

 from .engine import db

@@ -18,8 +17,8 @@ class CeleryTask(Base):
 result = db.Column(db.PickleType, nullable=True)
 date_done = db.Column(
 db.DateTime,
-default=lambda: datetime.now(UTC).replace(tzinfo=None),
-onupdate=lambda: datetime.now(UTC).replace(tzinfo=None),
+default=lambda: naive_utc_now(),
+onupdate=lambda: naive_utc_now(),
 nullable=True,
 )
 traceback = db.Column(db.Text, nullable=True)

@@ -39,4 +38,4 @@ class CeleryTaskSet(Base):
 id = db.Column(db.Integer, db.Sequence("taskset_id_sequence"), autoincrement=True, primary_key=True)
 taskset_id = db.Column(db.String(155), unique=True)
 result = db.Column(db.PickleType, nullable=True)
-date_done = db.Column(db.DateTime, default=lambda: datetime.now(UTC).replace(tzinfo=None), nullable=True)
+date_done = db.Column(db.DateTime, default=lambda: naive_utc_now(), nullable=True)
@@ -253,7 +253,7 @@ class MCPToolProvider(Base):
 # name of the mcp provider
 name: Mapped[str] = mapped_column(db.String(40), nullable=False)
 # server identifier of the mcp provider
-server_identifier: Mapped[str] = mapped_column(db.String(24), nullable=False)
+server_identifier: Mapped[str] = mapped_column(db.String(64), nullable=False)
 # encrypted url of the mcp provider
 server_url: Mapped[str] = mapped_column(db.Text, nullable=False)
 # hash of server_url for uniqueness check

@@ -357,7 +357,7 @@ class ToolModelInvoke(Base):
 # type
 tool_type = db.Column(db.String(40), nullable=False)
 # tool name
-tool_name = db.Column(db.String(40), nullable=False)
+tool_name = db.Column(db.String(128), nullable=False)
 # invoke parameters
 model_parameters = db.Column(db.Text, nullable=False)
 # prompt messages
@@ -1,7 +1,7 @@
 import json
 import logging
 from collections.abc import Mapping, Sequence
-from datetime import UTC, datetime
+from datetime import datetime
 from enum import Enum, StrEnum
 from typing import TYPE_CHECKING, Any, Optional, Union
 from uuid import uuid4

@@ -16,6 +16,7 @@ from core.variables.variables import FloatVariable, IntegerVariable, StringVaria
 from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
 from core.workflow.nodes.enums import NodeType
 from factories.variable_factory import TypeMismatchError, build_segment_with_type
+from libs.datetime_utils import naive_utc_now
 from libs.helper import extract_tenant_id

 from ._workflow_exc import NodeNotFoundError, WorkflowDataError

@@ -139,7 +140,7 @@ class Workflow(Base):
 updated_at: Mapped[datetime] = mapped_column(
 db.DateTime,
 nullable=False,
-default=datetime.now(UTC).replace(tzinfo=None),
+default=naive_utc_now(),
 server_onupdate=func.current_timestamp(),
 )
 _environment_variables: Mapped[str] = mapped_column(

@@ -185,7 +186,7 @@ class Workflow(Base):
 workflow.rag_pipeline_variables = rag_pipeline_variables or []
 workflow.marked_name = marked_name
 workflow.marked_comment = marked_comment
-workflow.created_at = datetime.now(UTC).replace(tzinfo=None)
+workflow.created_at = naive_utc_now()
 workflow.updated_at = workflow.created_at
 return workflow

@@ -938,7 +939,7 @@ _EDITABLE_SYSTEM_VARIABLE = frozenset(["query", "files"])


 def _naive_utc_datetime():
-return datetime.now(UTC).replace(tzinfo=None)
+return naive_utc_now()


 class WorkflowDraftVariable(Base):
@@ -17,6 +17,7 @@ from constants.languages import language_timezone_mapping, languages
 from events.tenant_event import tenant_was_created
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client, redis_fallback
+from libs.datetime_utils import naive_utc_now
 from libs.helper import RateLimiter, TokenManager
 from libs.passport import PassportService
 from libs.password import compare_password, hash_password, valid_password

@@ -135,8 +136,8 @@ class AccountService:
 available_ta.current = True
 db.session.commit()

-if datetime.now(UTC).replace(tzinfo=None) - account.last_active_at > timedelta(minutes=10):
-account.last_active_at = datetime.now(UTC).replace(tzinfo=None)
+if naive_utc_now() - account.last_active_at > timedelta(minutes=10):
+account.last_active_at = naive_utc_now()
 db.session.commit()

 return cast(Account, account)

@@ -180,7 +181,7 @@ class AccountService:

 if account.status == AccountStatus.PENDING.value:
 account.status = AccountStatus.ACTIVE.value
-account.initialized_at = datetime.now(UTC).replace(tzinfo=None)
+account.initialized_at = naive_utc_now()

 db.session.commit()

@@ -318,7 +319,7 @@ class AccountService:
 # If it exists, update the record
 account_integrate.open_id = open_id
 account_integrate.encrypted_token = ""  # todo
-account_integrate.updated_at = datetime.now(UTC).replace(tzinfo=None)
+account_integrate.updated_at = naive_utc_now()
 else:
 # If it does not exist, create a new record
 account_integrate = AccountIntegrate(

@@ -353,7 +354,7 @@ class AccountService:
 @staticmethod
 def update_login_info(account: Account, *, ip_address: str) -> None:
 """Update last login time and ip"""
-account.last_login_at = datetime.now(UTC).replace(tzinfo=None)
+account.last_login_at = naive_utc_now()
 account.last_login_ip = ip_address
 db.session.add(account)
 db.session.commit()

@@ -1066,15 +1067,6 @@ class TenantService:
 target_member_join.role = new_role
 db.session.commit()

-@staticmethod
-def dissolve_tenant(tenant: Tenant, operator: Account) -> None:
-"""Dissolve tenant"""
-if not TenantService.check_member_permission(tenant, operator, operator, "remove"):
-raise NoPermissionError("No permission to dissolve tenant.")
-db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id).delete()
-db.session.delete(tenant)
-db.session.commit()
-
 @staticmethod
 def get_custom_config(tenant_id: str) -> dict:
 tenant = db.get_or_404(Tenant, tenant_id)

@@ -1117,7 +1109,7 @@ class RegisterService:
 )

 account.last_login_ip = ip_address
-account.initialized_at = datetime.now(UTC).replace(tzinfo=None)
+account.initialized_at = naive_utc_now()

 TenantService.create_owner_tenant_if_not_exist(account=account, is_setup=True)

@@ -1158,7 +1150,7 @@ class RegisterService:
 is_setup=is_setup,
 )
 account.status = AccountStatus.ACTIVE.value if not status else status.value
-account.initialized_at = datetime.now(UTC).replace(tzinfo=None)
+account.initialized_at = naive_utc_now()

 if open_id is not None and provider is not None:
 AccountService.link_account_integrate(provider, open_id, account)
@@ -1,6 +1,5 @@
 import json
 import logging
-from datetime import UTC, datetime
 from typing import Optional, cast

 from flask_login import current_user

@@ -17,6 +16,7 @@ from core.tools.tool_manager import ToolManager
 from core.tools.utils.configuration import ToolParameterConfigurationManager
 from events.app_event import app_was_created
 from extensions.ext_database import db
+from libs.datetime_utils import naive_utc_now
 from models.account import Account
 from models.model import App, AppMode, AppModelConfig, Site
 from models.tools import ApiToolProvider

@@ -235,7 +235,7 @@ class AppService:
 app.use_icon_as_answer_icon = args.get("use_icon_as_answer_icon", False)
 app.max_active_requests = args.get("max_active_requests")
 app.updated_by = current_user.id
-app.updated_at = datetime.now(UTC).replace(tzinfo=None)
+app.updated_at = naive_utc_now()
 db.session.commit()

 return app

@@ -249,7 +249,7 @@ class AppService:
 """
 app.name = name
 app.updated_by = current_user.id
-app.updated_at = datetime.now(UTC).replace(tzinfo=None)
+app.updated_at = naive_utc_now()
 db.session.commit()

 return app

@@ -265,7 +265,7 @@ class AppService:
 app.icon = icon
 app.icon_background = icon_background
 app.updated_by = current_user.id
-app.updated_at = datetime.now(UTC).replace(tzinfo=None)
+app.updated_at = naive_utc_now()
 db.session.commit()

 return app

@@ -282,7 +282,7 @@ class AppService:

 app.enable_site = enable_site
 app.updated_by = current_user.id
-app.updated_at = datetime.now(UTC).replace(tzinfo=None)
+app.updated_at = naive_utc_now()
 db.session.commit()

 return app

@@ -299,7 +299,7 @@ class AppService:

 app.enable_api = enable_api
 app.updated_by = current_user.id
-app.updated_at = datetime.now(UTC).replace(tzinfo=None)
+app.updated_at = naive_utc_now()
 db.session.commit()

 return app
@@ -1,5 +1,4 @@
 from collections.abc import Callable, Sequence
-from datetime import UTC, datetime
 from typing import Optional, Union

 from sqlalchemy import asc, desc, func, or_, select

@@ -8,6 +7,7 @@ from sqlalchemy.orm import Session
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.llm_generator.llm_generator import LLMGenerator
 from extensions.ext_database import db
+from libs.datetime_utils import naive_utc_now
 from libs.infinite_scroll_pagination import InfiniteScrollPagination
 from models import ConversationVariable
 from models.account import Account

@@ -113,7 +113,7 @@ class ConversationService:
             return cls.auto_generate_name(app_model, conversation)
         else:
             conversation.name = name
-            conversation.updated_at = datetime.now(UTC).replace(tzinfo=None)
+            conversation.updated_at = naive_utc_now()
             db.session.commit()

         return conversation

@@ -169,7 +169,7 @@ class ConversationService:
         conversation = cls.get_conversation(app_model, conversation_id, user)

         conversation.is_deleted = True
-        conversation.updated_at = datetime.now(UTC).replace(tzinfo=None)
+        conversation.updated_at = naive_utc_now()
         db.session.commit()

     @classmethod

@@ -26,6 +26,7 @@ from events.document_event import document_was_deleted
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from libs import helper
+from libs.datetime_utils import naive_utc_now
 from models.account import Account, TenantAccountRole
 from models.dataset import (
     AppDatasetJoin,

@@ -484,7 +485,7 @@ class DatasetService:

         # Add metadata fields
         filtered_data["updated_by"] = user.id
-        filtered_data["updated_at"] = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+        filtered_data["updated_at"] = naive_utc_now()
         # update Retrieval model
         filtered_data["retrieval_model"] = data["retrieval_model"]
         # update icon info

@@ -1175,7 +1176,7 @@ class DocumentService:
         # update document to be paused
         document.is_paused = True
         document.paused_by = current_user.id
-        document.paused_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+        document.paused_at = naive_utc_now()

         db.session.add(document)
         db.session.commit()

@@ -1,6 +1,5 @@
 import json
 from copy import deepcopy
-from datetime import UTC, datetime
 from typing import Any, Optional, Union, cast
 from urllib.parse import urlparse


@@ -11,6 +10,7 @@ from constants import HIDDEN_VALUE
 from core.helper import ssrf_proxy
 from core.rag.entities.metadata_entities import MetadataCondition
 from extensions.ext_database import db
+from libs.datetime_utils import naive_utc_now
 from models.dataset import (
     Dataset,
     ExternalKnowledgeApis,

@@ -120,7 +120,7 @@ class ExternalDatasetService:
         external_knowledge_api.description = args.get("description", "")
         external_knowledge_api.settings = json.dumps(args.get("settings"), ensure_ascii=False)
         external_knowledge_api.updated_by = user_id
-        external_knowledge_api.updated_at = datetime.now(UTC).replace(tzinfo=None)
+        external_knowledge_api.updated_at = naive_utc_now()
         db.session.commit()

         return external_knowledge_api

@@ -70,16 +70,15 @@ class MCPToolManageService:
                     MCPToolProvider.server_url_hash == server_url_hash,
                     MCPToolProvider.server_identifier == server_identifier,
                 ),
                 MCPToolProvider.tenant_id == tenant_id,
             )
             .first()
         )
         if existing_provider:
             if existing_provider.name == name:
                 raise ValueError(f"MCP tool {name} already exists")
-            elif existing_provider.server_url_hash == server_url_hash:
+            if existing_provider.server_url_hash == server_url_hash:
                 raise ValueError(f"MCP tool {server_url} already exists")
-            elif existing_provider.server_identifier == server_identifier:
+            if existing_provider.server_identifier == server_identifier:
                 raise ValueError(f"MCP tool {server_identifier} already exists")
         encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url)
         mcp_tool = MCPToolProvider(

@@ -111,15 +110,14 @@ class MCPToolManageService:
         ]

     @classmethod
-    def list_mcp_tool_from_remote_server(cls, tenant_id: str, provider_id: str):
+    def list_mcp_tool_from_remote_server(cls, tenant_id: str, provider_id: str) -> ToolProviderApiEntity:
         mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)

         try:
             with MCPClient(
                 mcp_provider.decrypted_server_url, provider_id, tenant_id, authed=mcp_provider.authed, for_list=True
             ) as mcp_client:
                 tools = mcp_client.list_tools()
-        except MCPAuthError as e:
+        except MCPAuthError:
             raise ValueError("Please auth the tool first")
         except MCPError as e:
             raise ValueError(f"Failed to connect to MCP server: {e}")

@@ -184,12 +182,11 @@ class MCPToolManageService:
             error_msg = str(e.orig)
             if "unique_mcp_provider_name" in error_msg:
                 raise ValueError(f"MCP tool {name} already exists")
-            elif "unique_mcp_provider_server_url" in error_msg:
+            if "unique_mcp_provider_server_url" in error_msg:
                 raise ValueError(f"MCP tool {server_url} already exists")
-            elif "unique_mcp_provider_server_identifier" in error_msg:
+            if "unique_mcp_provider_server_identifier" in error_msg:
                 raise ValueError(f"MCP tool {server_identifier} already exists")
-            else:
-                raise
+            raise

     @classmethod
     def update_mcp_provider_credentials(

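Note: the `elif` → `if` rewrites above are behavior-preserving because every branch ends in a `raise`; once a guard fires, control never reaches the next one, so the chain can be flattened into independent guard clauses, and the trailing bare `raise` re-raises the original error when no known constraint matched. A minimal illustration of the same pattern (names here are made up for the example):

from sqlalchemy.exc import IntegrityError


def save(session, obj):
    try:
        session.add(obj)
        session.commit()
    except IntegrityError as e:
        msg = str(e.orig)
        # Each guard raises, so plain `if` is equivalent to the old `elif` chain
        if "unique_name" in msg:
            raise ValueError("name already exists")
        if "unique_url" in msg:
            raise ValueError("url already exists")
        raise  # unrecognized constraint: propagate the original error
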
@@ -2,7 +2,6 @@ import json
 import time
 import uuid
 from collections.abc import Callable, Generator, Mapping, Sequence
-from datetime import UTC, datetime
 from typing import Any, Optional, cast
 from uuid import uuid4


@@ -33,6 +32,7 @@ from core.workflow.workflow_entry import WorkflowEntry
 from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated
 from extensions.ext_database import db
 from factories.file_factory import build_from_mapping, build_from_mappings
+from libs.datetime_utils import naive_utc_now
 from models.account import Account
 from models.model import App, AppMode
 from models.tools import WorkflowToolProvider

@@ -232,7 +232,7 @@ class WorkflowService:
         workflow.graph = json.dumps(graph)
         workflow.features = json.dumps(features)
         workflow.updated_by = account.id
-        workflow.updated_at = datetime.now(UTC).replace(tzinfo=None)
+        workflow.updated_at = naive_utc_now()
         workflow.environment_variables = environment_variables
         workflow.conversation_variables = conversation_variables

@@ -268,7 +268,7 @@ class WorkflowService:
             tenant_id=app_model.tenant_id,
             app_id=app_model.id,
             type=draft_workflow.type,
-            version=Workflow.version_from_datetime(datetime.now(UTC).replace(tzinfo=None)),
+            version=Workflow.version_from_datetime(naive_utc_now()),
             graph=draft_workflow.graph,
             created_by=account.id,
             environment_variables=draft_workflow.environment_variables,

@@ -524,8 +524,8 @@ class WorkflowService:
                 node_type=node.type_,
                 title=node.title,
                 elapsed_time=time.perf_counter() - start_at,
-                created_at=datetime.now(UTC).replace(tzinfo=None),
-                finished_at=datetime.now(UTC).replace(tzinfo=None),
+                created_at=naive_utc_now(),
+                finished_at=naive_utc_now(),
             )

             if run_succeeded and node_run_result:

@@ -622,7 +622,7 @@ class WorkflowService:
             setattr(workflow, field, value)

         workflow.updated_by = account_id
-        workflow.updated_at = datetime.now(UTC).replace(tzinfo=None)
+        workflow.updated_at = naive_utc_now()

         return workflow

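Note: in the `@@ -268` hunk only the timestamp source changes; `Workflow.version_from_datetime(...)` itself is untouched and not shown in this diff. Assuming it simply renders the datetime as a version string (for example, an ISO-8601 timestamp), the call is equivalent to something like this sketch:

# Sketch; the real formatting lives on the Workflow model and is not shown here
from datetime import datetime


def version_from_datetime(dt: datetime) -> str:
    return dt.isoformat()  # assumed format
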
@@ -1,4 +1,3 @@
-import datetime
 import logging
 import time


@@ -8,6 +7,7 @@ from celery import shared_task  # type: ignore
 from configs import dify_config
 from core.indexing_runner import DocumentIsPausedError, IndexingRunner
 from extensions.ext_database import db
+from libs.datetime_utils import naive_utc_now
 from models.dataset import Dataset, Document
 from services.feature_service import FeatureService


@@ -53,7 +53,7 @@ def document_indexing_task(dataset_id: str, document_ids: list):
         if document:
             document.indexing_status = "error"
             document.error = str(e)
-            document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+            document.stopped_at = naive_utc_now()
             db.session.add(document)
             db.session.commit()
         db.session.close()

@@ -68,7 +68,7 @@ def document_indexing_task(dataset_id: str, document_ids: list):

         if document:
             document.indexing_status = "parsing"
-            document.processing_started_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+            document.processing_started_at = naive_utc_now()
             documents.append(document)
             db.session.add(document)
     db.session.commit()

@@ -26,8 +26,15 @@ redis_mock.hgetall = MagicMock(return_value={})
 redis_mock.hdel = MagicMock()
 redis_mock.incr = MagicMock(return_value=1)

+# Add the API directory to Python path to ensure proper imports
+import sys
+
+sys.path.insert(0, PROJECT_DIR)
+
 # apply the mock to the Redis client in the Flask app
-redis_patcher = patch("extensions.ext_redis.redis_client", redis_mock)
+from extensions import ext_redis
+
+redis_patcher = patch.object(ext_redis, "redis_client", redis_mock)
 redis_patcher.start()

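Note: switching from `patch("extensions.ext_redis.redis_client", ...)` to `patch.object(ext_redis, "redis_client", ...)` makes the target explicit. The string form resolves the dotted path at patch time, which can pick up a different module object once `sys.path` is being manipulated; `patch.object` binds directly to the `ext_redis` module imported just above, so the mock is guaranteed to land on the attribute the application code actually uses.
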
@@ -0,0 +1 @@
# API authentication service test module

@@ -0,0 +1,49 @@
import pytest

from services.auth.api_key_auth_base import ApiKeyAuthBase


class ConcreteApiKeyAuth(ApiKeyAuthBase):
    """Concrete implementation for testing abstract base class"""

    def validate_credentials(self):
        return True


class TestApiKeyAuthBase:
    def test_should_store_credentials_on_init(self):
        """Test that credentials are properly stored during initialization"""
        credentials = {"api_key": "test_key", "auth_type": "bearer"}
        auth = ConcreteApiKeyAuth(credentials)
        assert auth.credentials == credentials

    def test_should_not_instantiate_abstract_class(self):
        """Test that ApiKeyAuthBase cannot be instantiated directly"""
        credentials = {"api_key": "test_key"}

        with pytest.raises(TypeError) as exc_info:
            ApiKeyAuthBase(credentials)

        assert "Can't instantiate abstract class" in str(exc_info.value)
        assert "validate_credentials" in str(exc_info.value)

    def test_should_allow_subclass_implementation(self):
        """Test that subclasses can properly implement the abstract method"""
        credentials = {"api_key": "test_key", "auth_type": "bearer"}
        auth = ConcreteApiKeyAuth(credentials)

        # Should not raise any exception
        result = auth.validate_credentials()
        assert result is True

    def test_should_handle_empty_credentials(self):
        """Test initialization with empty credentials"""
        credentials = {}
        auth = ConcreteApiKeyAuth(credentials)
        assert auth.credentials == {}

    def test_should_handle_none_credentials(self):
        """Test initialization with None credentials"""
        credentials = None
        auth = ConcreteApiKeyAuth(credentials)
        assert auth.credentials is None

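Note: the base class under test is not shown in this diff; the behavior the tests pin down (credentials stored verbatim, `validate_credentials` abstract) implies something like this sketch:

# services/auth/api_key_auth_base.py — sketch inferred from the tests above
from abc import ABC, abstractmethod


class ApiKeyAuthBase(ABC):
    def __init__(self, credentials):
        self.credentials = credentials  # stored as-is, even None or {}

    @abstractmethod
    def validate_credentials(self):
        raise NotImplementedError
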
@@ -0,0 +1,81 @@
from unittest.mock import MagicMock, patch

import pytest

from services.auth.api_key_auth_factory import ApiKeyAuthFactory
from services.auth.auth_type import AuthType


class TestApiKeyAuthFactory:
    """Test cases for ApiKeyAuthFactory"""

    @pytest.mark.parametrize(
        ("provider", "auth_class_path"),
        [
            (AuthType.FIRECRAWL, "services.auth.firecrawl.firecrawl.FirecrawlAuth"),
            (AuthType.WATERCRAWL, "services.auth.watercrawl.watercrawl.WatercrawlAuth"),
            (AuthType.JINA, "services.auth.jina.jina.JinaAuth"),
        ],
    )
    def test_get_apikey_auth_factory_valid_providers(self, provider, auth_class_path):
        """Test getting auth factory for all valid providers"""
        with patch(auth_class_path) as mock_auth:
            auth_class = ApiKeyAuthFactory.get_apikey_auth_factory(provider)
            assert auth_class == mock_auth

    @pytest.mark.parametrize(
        "invalid_provider",
        [
            "invalid_provider",
            "",
            None,
            123,
            "UNSUPPORTED",
        ],
    )
    def test_get_apikey_auth_factory_invalid_providers(self, invalid_provider):
        """Test getting auth factory with various invalid providers"""
        with pytest.raises(ValueError) as exc_info:
            ApiKeyAuthFactory.get_apikey_auth_factory(invalid_provider)
        assert str(exc_info.value) == "Invalid provider"

    @pytest.mark.parametrize(
        ("credentials_return_value", "expected_result"),
        [
            (True, True),
            (False, False),
        ],
    )
    @patch("services.auth.api_key_auth_factory.ApiKeyAuthFactory.get_apikey_auth_factory")
    def test_validate_credentials_delegates_to_auth_instance(
        self, mock_get_factory, credentials_return_value, expected_result
    ):
        """Test that validate_credentials delegates to auth instance correctly"""
        # Arrange
        mock_auth_instance = MagicMock()
        mock_auth_instance.validate_credentials.return_value = credentials_return_value
        mock_auth_class = MagicMock(return_value=mock_auth_instance)
        mock_get_factory.return_value = mock_auth_class

        # Act
        factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, {"api_key": "test_key"})
        result = factory.validate_credentials()

        # Assert
        assert result is expected_result
        mock_auth_instance.validate_credentials.assert_called_once()

    @patch("services.auth.api_key_auth_factory.ApiKeyAuthFactory.get_apikey_auth_factory")
    def test_validate_credentials_propagates_exceptions(self, mock_get_factory):
        """Test that exceptions from auth instance are propagated"""
        # Arrange
        mock_auth_instance = MagicMock()
        mock_auth_instance.validate_credentials.side_effect = Exception("Authentication error")
        mock_auth_class = MagicMock(return_value=mock_auth_instance)
        mock_get_factory.return_value = mock_auth_class

        # Act & Assert
        factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, {"api_key": "test_key"})
        with pytest.raises(Exception) as exc_info:
            factory.validate_credentials()
        assert str(exc_info.value) == "Authentication error"

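Note: the factory itself is not shown in this diff. Two details are implied by the tests: `get_apikey_auth_factory` must import each provider class lazily (the tests patch the class at its defining module and still see the mock returned), and any unknown provider raises `ValueError("Invalid provider")`. A sketch consistent with that:

# services/auth/api_key_auth_factory.py — sketch inferred from the tests above
class ApiKeyAuthFactory:
    def __init__(self, provider, credentials):
        auth_factory = self.get_apikey_auth_factory(provider)
        self.auth = auth_factory(credentials)

    def validate_credentials(self):
        return self.auth.validate_credentials()

    @staticmethod
    def get_apikey_auth_factory(provider):
        from services.auth.auth_type import AuthType

        # Imports are deferred so tests can patch the provider classes
        if provider == AuthType.FIRECRAWL:
            from services.auth.firecrawl.firecrawl import FirecrawlAuth

            return FirecrawlAuth
        if provider == AuthType.WATERCRAWL:
            from services.auth.watercrawl.watercrawl import WatercrawlAuth

            return WatercrawlAuth
        if provider == AuthType.JINA:
            from services.auth.jina.jina import JinaAuth

            return JinaAuth
        raise ValueError("Invalid provider")
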
@@ -0,0 +1,382 @@
import json
from unittest.mock import Mock, patch

import pytest

from models.source import DataSourceApiKeyAuthBinding
from services.auth.api_key_auth_service import ApiKeyAuthService


class TestApiKeyAuthService:
    """API key authentication service security tests"""

    def setup_method(self):
        """Setup test fixtures"""
        self.tenant_id = "test_tenant_123"
        self.category = "search"
        self.provider = "google"
        self.binding_id = "binding_123"
        self.mock_credentials = {"auth_type": "api_key", "config": {"api_key": "test_secret_key_123"}}
        self.mock_args = {"category": self.category, "provider": self.provider, "credentials": self.mock_credentials}

    @patch("services.auth.api_key_auth_service.db.session")
    def test_get_provider_auth_list_success(self, mock_session):
        """Test get provider auth list - success scenario"""
        # Mock database query result
        mock_binding = Mock()
        mock_binding.tenant_id = self.tenant_id
        mock_binding.provider = self.provider
        mock_binding.disabled = False

        mock_session.query.return_value.filter.return_value.all.return_value = [mock_binding]

        result = ApiKeyAuthService.get_provider_auth_list(self.tenant_id)

        assert len(result) == 1
        assert result[0].tenant_id == self.tenant_id
        mock_session.query.assert_called_once_with(DataSourceApiKeyAuthBinding)

    @patch("services.auth.api_key_auth_service.db.session")
    def test_get_provider_auth_list_empty(self, mock_session):
        """Test get provider auth list - empty result"""
        mock_session.query.return_value.filter.return_value.all.return_value = []

        result = ApiKeyAuthService.get_provider_auth_list(self.tenant_id)

        assert result == []

    @patch("services.auth.api_key_auth_service.db.session")
    def test_get_provider_auth_list_filters_disabled(self, mock_session):
        """Test get provider auth list - filters disabled items"""
        mock_session.query.return_value.filter.return_value.all.return_value = []

        ApiKeyAuthService.get_provider_auth_list(self.tenant_id)

        # Verify filter conditions include disabled.is_(False)
        filter_call = mock_session.query.return_value.filter.call_args[0]
        assert len(filter_call) == 2  # tenant_id and disabled filter conditions

    @patch("services.auth.api_key_auth_service.db.session")
    @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory")
    @patch("services.auth.api_key_auth_service.encrypter")
    def test_create_provider_auth_success(self, mock_encrypter, mock_factory, mock_session):
        """Test create provider auth - success scenario"""
        # Mock successful auth validation
        mock_auth_instance = Mock()
        mock_auth_instance.validate_credentials.return_value = True
        mock_factory.return_value = mock_auth_instance

        # Mock encryption
        encrypted_key = "encrypted_test_key_123"
        mock_encrypter.encrypt_token.return_value = encrypted_key

        # Mock database operations
        mock_session.add = Mock()
        mock_session.commit = Mock()

        ApiKeyAuthService.create_provider_auth(self.tenant_id, self.mock_args)

        # Verify factory class calls
        mock_factory.assert_called_once_with(self.provider, self.mock_credentials)
        mock_auth_instance.validate_credentials.assert_called_once()

        # Verify encryption calls
        mock_encrypter.encrypt_token.assert_called_once_with(self.tenant_id, "test_secret_key_123")

        # Verify database operations
        mock_session.add.assert_called_once()
        mock_session.commit.assert_called_once()

    @patch("services.auth.api_key_auth_service.db.session")
    @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory")
    def test_create_provider_auth_validation_failed(self, mock_factory, mock_session):
        """Test create provider auth - validation failed"""
        # Mock failed auth validation
        mock_auth_instance = Mock()
        mock_auth_instance.validate_credentials.return_value = False
        mock_factory.return_value = mock_auth_instance

        ApiKeyAuthService.create_provider_auth(self.tenant_id, self.mock_args)

        # Verify no database operations when validation fails
        mock_session.add.assert_not_called()
        mock_session.commit.assert_not_called()

    @patch("services.auth.api_key_auth_service.db.session")
    @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory")
    @patch("services.auth.api_key_auth_service.encrypter")
    def test_create_provider_auth_encrypts_api_key(self, mock_encrypter, mock_factory, mock_session):
        """Test create provider auth - ensures API key is encrypted"""
        # Mock successful auth validation
        mock_auth_instance = Mock()
        mock_auth_instance.validate_credentials.return_value = True
        mock_factory.return_value = mock_auth_instance

        # Mock encryption
        encrypted_key = "encrypted_test_key_123"
        mock_encrypter.encrypt_token.return_value = encrypted_key

        # Mock database operations
        mock_session.add = Mock()
        mock_session.commit = Mock()

        args_copy = self.mock_args.copy()
        original_key = args_copy["credentials"]["config"]["api_key"]  # type: ignore

        ApiKeyAuthService.create_provider_auth(self.tenant_id, args_copy)

        # Verify original key is replaced with encrypted key
        assert args_copy["credentials"]["config"]["api_key"] == encrypted_key  # type: ignore
        assert args_copy["credentials"]["config"]["api_key"] != original_key  # type: ignore

        # Verify encryption function is called correctly
        mock_encrypter.encrypt_token.assert_called_once_with(self.tenant_id, original_key)

    @patch("services.auth.api_key_auth_service.db.session")
    def test_get_auth_credentials_success(self, mock_session):
        """Test get auth credentials - success scenario"""
        # Mock database query result
        mock_binding = Mock()
        mock_binding.credentials = json.dumps(self.mock_credentials)
        mock_session.query.return_value.filter.return_value.first.return_value = mock_binding

        result = ApiKeyAuthService.get_auth_credentials(self.tenant_id, self.category, self.provider)

        assert result == self.mock_credentials
        mock_session.query.assert_called_once_with(DataSourceApiKeyAuthBinding)

    @patch("services.auth.api_key_auth_service.db.session")
    def test_get_auth_credentials_not_found(self, mock_session):
        """Test get auth credentials - not found"""
        mock_session.query.return_value.filter.return_value.first.return_value = None

        result = ApiKeyAuthService.get_auth_credentials(self.tenant_id, self.category, self.provider)

        assert result is None

    @patch("services.auth.api_key_auth_service.db.session")
    def test_get_auth_credentials_filters_correctly(self, mock_session):
        """Test get auth credentials - applies correct filters"""
        mock_session.query.return_value.filter.return_value.first.return_value = None

        ApiKeyAuthService.get_auth_credentials(self.tenant_id, self.category, self.provider)

        # Verify filter conditions are correct
        filter_call = mock_session.query.return_value.filter.call_args[0]
        assert len(filter_call) == 4  # tenant_id, category, provider, disabled

    @patch("services.auth.api_key_auth_service.db.session")
    def test_get_auth_credentials_json_parsing(self, mock_session):
        """Test get auth credentials - JSON parsing"""
        # Mock credentials with special characters
        special_credentials = {"auth_type": "api_key", "config": {"api_key": "key_with_中文_and_special_chars_!@#$%"}}

        mock_binding = Mock()
        mock_binding.credentials = json.dumps(special_credentials, ensure_ascii=False)
        mock_session.query.return_value.filter.return_value.first.return_value = mock_binding

        result = ApiKeyAuthService.get_auth_credentials(self.tenant_id, self.category, self.provider)

        assert result == special_credentials
        assert result["config"]["api_key"] == "key_with_中文_and_special_chars_!@#$%"

    @patch("services.auth.api_key_auth_service.db.session")
    def test_delete_provider_auth_success(self, mock_session):
        """Test delete provider auth - success scenario"""
        # Mock database query result
        mock_binding = Mock()
        mock_session.query.return_value.filter.return_value.first.return_value = mock_binding

        ApiKeyAuthService.delete_provider_auth(self.tenant_id, self.binding_id)

        # Verify delete operations
        mock_session.delete.assert_called_once_with(mock_binding)
        mock_session.commit.assert_called_once()

    @patch("services.auth.api_key_auth_service.db.session")
    def test_delete_provider_auth_not_found(self, mock_session):
        """Test delete provider auth - not found"""
        mock_session.query.return_value.filter.return_value.first.return_value = None

        ApiKeyAuthService.delete_provider_auth(self.tenant_id, self.binding_id)

        # Verify no delete operations when not found
        mock_session.delete.assert_not_called()
        mock_session.commit.assert_not_called()

    @patch("services.auth.api_key_auth_service.db.session")
    def test_delete_provider_auth_filters_by_tenant(self, mock_session):
        """Test delete provider auth - filters by tenant"""
        mock_session.query.return_value.filter.return_value.first.return_value = None

        ApiKeyAuthService.delete_provider_auth(self.tenant_id, self.binding_id)

        # Verify filter conditions include tenant_id and binding_id
        filter_call = mock_session.query.return_value.filter.call_args[0]
        assert len(filter_call) == 2

    def test_validate_api_key_auth_args_success(self):
        """Test API key auth args validation - success scenario"""
        # Should not raise any exception
        ApiKeyAuthService.validate_api_key_auth_args(self.mock_args)

    def test_validate_api_key_auth_args_missing_category(self):
        """Test API key auth args validation - missing category"""
        args = self.mock_args.copy()
        del args["category"]

        with pytest.raises(ValueError, match="category is required"):
            ApiKeyAuthService.validate_api_key_auth_args(args)

    def test_validate_api_key_auth_args_empty_category(self):
        """Test API key auth args validation - empty category"""
        args = self.mock_args.copy()
        args["category"] = ""

        with pytest.raises(ValueError, match="category is required"):
            ApiKeyAuthService.validate_api_key_auth_args(args)

    def test_validate_api_key_auth_args_missing_provider(self):
        """Test API key auth args validation - missing provider"""
        args = self.mock_args.copy()
        del args["provider"]

        with pytest.raises(ValueError, match="provider is required"):
            ApiKeyAuthService.validate_api_key_auth_args(args)

    def test_validate_api_key_auth_args_empty_provider(self):
        """Test API key auth args validation - empty provider"""
        args = self.mock_args.copy()
        args["provider"] = ""

        with pytest.raises(ValueError, match="provider is required"):
            ApiKeyAuthService.validate_api_key_auth_args(args)

    def test_validate_api_key_auth_args_missing_credentials(self):
        """Test API key auth args validation - missing credentials"""
        args = self.mock_args.copy()
        del args["credentials"]

        with pytest.raises(ValueError, match="credentials is required"):
            ApiKeyAuthService.validate_api_key_auth_args(args)

    def test_validate_api_key_auth_args_empty_credentials(self):
        """Test API key auth args validation - empty credentials"""
        args = self.mock_args.copy()
        args["credentials"] = None  # type: ignore

        with pytest.raises(ValueError, match="credentials is required"):
            ApiKeyAuthService.validate_api_key_auth_args(args)

    def test_validate_api_key_auth_args_invalid_credentials_type(self):
        """Test API key auth args validation - invalid credentials type"""
        args = self.mock_args.copy()
        args["credentials"] = "not_a_dict"

        with pytest.raises(ValueError, match="credentials must be a dictionary"):
            ApiKeyAuthService.validate_api_key_auth_args(args)

    def test_validate_api_key_auth_args_missing_auth_type(self):
        """Test API key auth args validation - missing auth_type"""
        args = self.mock_args.copy()
        del args["credentials"]["auth_type"]  # type: ignore

        with pytest.raises(ValueError, match="auth_type is required"):
            ApiKeyAuthService.validate_api_key_auth_args(args)

    def test_validate_api_key_auth_args_empty_auth_type(self):
        """Test API key auth args validation - empty auth_type"""
        args = self.mock_args.copy()
        args["credentials"]["auth_type"] = ""  # type: ignore

        with pytest.raises(ValueError, match="auth_type is required"):
            ApiKeyAuthService.validate_api_key_auth_args(args)

    @pytest.mark.parametrize(
        "malicious_input",
        [
            "<script>alert('xss')</script>",
            "'; DROP TABLE users; --",
            "../../../etc/passwd",
            "\\x00\\x00",  # null bytes
            "A" * 10000,  # very long input
        ],
    )
    def test_validate_api_key_auth_args_malicious_input(self, malicious_input):
        """Test API key auth args validation - malicious input"""
        args = self.mock_args.copy()
        args["category"] = malicious_input

        # Verify parameter validator doesn't crash on malicious input
        # Should validate normally rather than raising security-related exceptions
        ApiKeyAuthService.validate_api_key_auth_args(args)

    @patch("services.auth.api_key_auth_service.db.session")
    @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory")
    @patch("services.auth.api_key_auth_service.encrypter")
    def test_create_provider_auth_database_error_handling(self, mock_encrypter, mock_factory, mock_session):
        """Test create provider auth - database error handling"""
        # Mock successful auth validation
        mock_auth_instance = Mock()
        mock_auth_instance.validate_credentials.return_value = True
        mock_factory.return_value = mock_auth_instance

        # Mock encryption
        mock_encrypter.encrypt_token.return_value = "encrypted_key"

        # Mock database error
        mock_session.commit.side_effect = Exception("Database error")

        with pytest.raises(Exception, match="Database error"):
            ApiKeyAuthService.create_provider_auth(self.tenant_id, self.mock_args)

    @patch("services.auth.api_key_auth_service.db.session")
    def test_get_auth_credentials_invalid_json(self, mock_session):
        """Test get auth credentials - invalid JSON"""
        # Mock database returning invalid JSON
        mock_binding = Mock()
        mock_binding.credentials = "invalid json content"
        mock_session.query.return_value.filter.return_value.first.return_value = mock_binding

        with pytest.raises(json.JSONDecodeError):
            ApiKeyAuthService.get_auth_credentials(self.tenant_id, self.category, self.provider)

    @patch("services.auth.api_key_auth_service.db.session")
    @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory")
    def test_create_provider_auth_factory_exception(self, mock_factory, mock_session):
        """Test create provider auth - factory exception"""
        # Mock factory raising exception
        mock_factory.side_effect = Exception("Factory error")

        with pytest.raises(Exception, match="Factory error"):
            ApiKeyAuthService.create_provider_auth(self.tenant_id, self.mock_args)

    @patch("services.auth.api_key_auth_service.db.session")
    @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory")
    @patch("services.auth.api_key_auth_service.encrypter")
    def test_create_provider_auth_encryption_exception(self, mock_encrypter, mock_factory, mock_session):
        """Test create provider auth - encryption exception"""
        # Mock successful auth validation
        mock_auth_instance = Mock()
        mock_auth_instance.validate_credentials.return_value = True
        mock_factory.return_value = mock_auth_instance

        # Mock encryption exception
        mock_encrypter.encrypt_token.side_effect = Exception("Encryption error")

        with pytest.raises(Exception, match="Encryption error"):
            ApiKeyAuthService.create_provider_auth(self.tenant_id, self.mock_args)

    def test_validate_api_key_auth_args_none_input(self):
        """Test API key auth args validation - None input"""
        with pytest.raises(TypeError):
            ApiKeyAuthService.validate_api_key_auth_args(None)

    def test_validate_api_key_auth_args_dict_credentials_with_list_auth_type(self):
        """Test API key auth args validation - dict credentials with list auth_type"""
        args = self.mock_args.copy()
        args["credentials"]["auth_type"] = ["api_key"]  # type: ignore # list instead of string

        # Current implementation only checks that auth_type exists and is truthy; the list
        # ["api_key"] is truthy, so this should not raise an exception and the test passes
        ApiKeyAuthService.validate_api_key_auth_args(args)

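Note: the validation rules these tests pin down imply a checker along the following lines; a sketch inferred from the assertions, not the implementation from this commit (in the real service it is a static method on ApiKeyAuthService):

def validate_api_key_auth_args(args):
    # Passing None raises TypeError naturally from the `in` checks below
    if "category" not in args or not args["category"]:
        raise ValueError("category is required")
    if "provider" not in args or not args["provider"]:
        raise ValueError("provider is required")
    if "credentials" not in args or not args["credentials"]:
        raise ValueError("credentials is required")
    if not isinstance(args["credentials"], dict):
        raise ValueError("credentials must be a dictionary")
    if "auth_type" not in args["credentials"] or not args["credentials"]["auth_type"]:
        raise ValueError("auth_type is required")
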
@@ -0,0 +1,191 @@
from unittest.mock import MagicMock, patch

import pytest
import requests

from services.auth.firecrawl.firecrawl import FirecrawlAuth


class TestFirecrawlAuth:
    @pytest.fixture
    def valid_credentials(self):
        """Fixture for valid bearer credentials"""
        return {"auth_type": "bearer", "config": {"api_key": "test_api_key_123"}}

    @pytest.fixture
    def auth_instance(self, valid_credentials):
        """Fixture for FirecrawlAuth instance with valid credentials"""
        return FirecrawlAuth(valid_credentials)

    def test_should_initialize_with_valid_bearer_credentials(self, valid_credentials):
        """Test successful initialization with valid bearer credentials"""
        auth = FirecrawlAuth(valid_credentials)
        assert auth.api_key == "test_api_key_123"
        assert auth.base_url == "https://api.firecrawl.dev"
        assert auth.credentials == valid_credentials

    def test_should_initialize_with_custom_base_url(self):
        """Test initialization with custom base URL"""
        credentials = {
            "auth_type": "bearer",
            "config": {"api_key": "test_api_key_123", "base_url": "https://custom.firecrawl.dev"},
        }
        auth = FirecrawlAuth(credentials)
        assert auth.api_key == "test_api_key_123"
        assert auth.base_url == "https://custom.firecrawl.dev"

    @pytest.mark.parametrize(
        ("auth_type", "expected_error"),
        [
            ("basic", "Invalid auth type, Firecrawl auth type must be Bearer"),
            ("x-api-key", "Invalid auth type, Firecrawl auth type must be Bearer"),
            ("", "Invalid auth type, Firecrawl auth type must be Bearer"),
        ],
    )
    def test_should_raise_error_for_invalid_auth_type(self, auth_type, expected_error):
        """Test that non-bearer auth types raise ValueError"""
        credentials = {"auth_type": auth_type, "config": {"api_key": "test_api_key_123"}}
        with pytest.raises(ValueError) as exc_info:
            FirecrawlAuth(credentials)
        assert str(exc_info.value) == expected_error

    @pytest.mark.parametrize(
        ("credentials", "expected_error"),
        [
            ({"auth_type": "bearer", "config": {}}, "No API key provided"),
            ({"auth_type": "bearer"}, "No API key provided"),
            ({"auth_type": "bearer", "config": {"api_key": ""}}, "No API key provided"),
            ({"auth_type": "bearer", "config": {"api_key": None}}, "No API key provided"),
        ],
    )
    def test_should_raise_error_for_missing_api_key(self, credentials, expected_error):
        """Test that missing or empty API key raises ValueError"""
        with pytest.raises(ValueError) as exc_info:
            FirecrawlAuth(credentials)
        assert str(exc_info.value) == expected_error

    @patch("services.auth.firecrawl.firecrawl.requests.post")
    def test_should_validate_valid_credentials_successfully(self, mock_post, auth_instance):
        """Test successful credential validation"""
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_post.return_value = mock_response

        result = auth_instance.validate_credentials()

        assert result is True
        expected_data = {
            "url": "https://example.com",
            "includePaths": [],
            "excludePaths": [],
            "limit": 1,
            "scrapeOptions": {"onlyMainContent": True},
        }
        mock_post.assert_called_once_with(
            "https://api.firecrawl.dev/v1/crawl",
            headers={"Content-Type": "application/json", "Authorization": "Bearer test_api_key_123"},
            json=expected_data,
        )

    @pytest.mark.parametrize(
        ("status_code", "error_message"),
        [
            (402, "Payment required"),
            (409, "Conflict error"),
            (500, "Internal server error"),
        ],
    )
    @patch("services.auth.firecrawl.firecrawl.requests.post")
    def test_should_handle_http_errors(self, mock_post, status_code, error_message, auth_instance):
        """Test handling of various HTTP error codes"""
        mock_response = MagicMock()
        mock_response.status_code = status_code
        mock_response.json.return_value = {"error": error_message}
        mock_post.return_value = mock_response

        with pytest.raises(Exception) as exc_info:
            auth_instance.validate_credentials()
        assert str(exc_info.value) == f"Failed to authorize. Status code: {status_code}. Error: {error_message}"

    @pytest.mark.parametrize(
        ("status_code", "response_text", "has_json_error", "expected_error_contains"),
        [
            (403, '{"error": "Forbidden"}', True, "Failed to authorize. Status code: 403. Error: Forbidden"),
            (404, "", True, "Unexpected error occurred while trying to authorize. Status code: 404"),
            (401, "Not JSON", True, "Expecting value"),  # JSON decode error
        ],
    )
    @patch("services.auth.firecrawl.firecrawl.requests.post")
    def test_should_handle_unexpected_errors(
        self, mock_post, status_code, response_text, has_json_error, expected_error_contains, auth_instance
    ):
        """Test handling of unexpected errors with various response formats"""
        mock_response = MagicMock()
        mock_response.status_code = status_code
        mock_response.text = response_text
        if has_json_error:
            mock_response.json.side_effect = Exception("Not JSON")
        mock_post.return_value = mock_response

        with pytest.raises(Exception) as exc_info:
            auth_instance.validate_credentials()
        assert expected_error_contains in str(exc_info.value)

    @pytest.mark.parametrize(
        ("exception_type", "exception_message"),
        [
            (requests.ConnectionError, "Network error"),
            (requests.Timeout, "Request timeout"),
            (requests.ReadTimeout, "Read timeout"),
            (requests.ConnectTimeout, "Connection timeout"),
        ],
    )
    @patch("services.auth.firecrawl.firecrawl.requests.post")
    def test_should_handle_network_errors(self, mock_post, exception_type, exception_message, auth_instance):
        """Test handling of various network-related errors including timeouts"""
        mock_post.side_effect = exception_type(exception_message)

        with pytest.raises(exception_type) as exc_info:
            auth_instance.validate_credentials()
        assert exception_message in str(exc_info.value)

    def test_should_not_expose_api_key_in_error_messages(self):
        """Test that API key is not exposed in error messages"""
        credentials = {"auth_type": "bearer", "config": {"api_key": "super_secret_key_12345"}}
        auth = FirecrawlAuth(credentials)

        # Verify API key is stored but not in any error message
        assert auth.api_key == "super_secret_key_12345"

        # Test various error scenarios don't expose the key
        with pytest.raises(ValueError) as exc_info:
            FirecrawlAuth({"auth_type": "basic", "config": {"api_key": "super_secret_key_12345"}})
        assert "super_secret_key_12345" not in str(exc_info.value)

    @patch("services.auth.firecrawl.firecrawl.requests.post")
    def test_should_use_custom_base_url_in_validation(self, mock_post):
        """Test that custom base URL is used in validation"""
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_post.return_value = mock_response

        credentials = {
            "auth_type": "bearer",
            "config": {"api_key": "test_api_key_123", "base_url": "https://custom.firecrawl.dev"},
        }
        auth = FirecrawlAuth(credentials)
        result = auth.validate_credentials()

        assert result is True
        assert mock_post.call_args[0][0] == "https://custom.firecrawl.dev/v1/crawl"

    @patch("services.auth.firecrawl.firecrawl.requests.post")
    def test_should_handle_timeout_with_retry_suggestion(self, mock_post, auth_instance):
        """Test that timeout errors are handled gracefully with appropriate error message"""
        mock_post.side_effect = requests.Timeout("The request timed out after 30 seconds")

        with pytest.raises(requests.Timeout) as exc_info:
            auth_instance.validate_credentials()

        # Verify the timeout exception is raised with original message
        assert "timed out" in str(exc_info.value)

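Note: the provider class itself is not part of this diff; the behavior the tests fix (bearer-only auth, default base URL, the exact validation request and error-message formats) implies something like this sketch:

# services/auth/firecrawl/firecrawl.py — sketch inferred from the tests above
import json

import requests

from services.auth.api_key_auth_base import ApiKeyAuthBase


class FirecrawlAuth(ApiKeyAuthBase):
    def __init__(self, credentials):
        super().__init__(credentials)
        if credentials.get("auth_type") != "bearer":
            raise ValueError("Invalid auth type, Firecrawl auth type must be Bearer")
        self.api_key = credentials.get("config", {}).get("api_key")
        self.base_url = credentials.get("config", {}).get("base_url", "https://api.firecrawl.dev")
        if not self.api_key:
            raise ValueError("No API key provided")

    def validate_credentials(self):
        headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
        options = {
            "url": "https://example.com",
            "includePaths": [],
            "excludePaths": [],
            "limit": 1,
            "scrapeOptions": {"onlyMainContent": True},
        }
        response = requests.post(f"{self.base_url}/v1/crawl", headers=headers, json=options)
        if response.status_code == 200:
            return True
        self._handle_error(response)

    def _handle_error(self, response):
        if response.status_code in {402, 409, 500}:
            error_message = response.json().get("error", "Unknown error occurred")
            raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}")
        if response.text:
            # json.loads on non-JSON text raises "Expecting value", matching the 401 test case
            error_message = json.loads(response.text).get("error", "Unknown error occurred")
            raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}")
        raise Exception(f"Unexpected error occurred while trying to authorize. Status code: {response.status_code}")
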
@@ -0,0 +1,155 @@
from unittest.mock import MagicMock, patch

import pytest
import requests

from services.auth.jina.jina import JinaAuth


class TestJinaAuth:
    def test_should_initialize_with_valid_bearer_credentials(self):
        """Test successful initialization with valid bearer credentials"""
        credentials = {"auth_type": "bearer", "config": {"api_key": "test_api_key_123"}}
        auth = JinaAuth(credentials)
        assert auth.api_key == "test_api_key_123"
        assert auth.credentials == credentials

    def test_should_raise_error_for_invalid_auth_type(self):
        """Test that non-bearer auth type raises ValueError"""
        credentials = {"auth_type": "basic", "config": {"api_key": "test_api_key_123"}}
        with pytest.raises(ValueError) as exc_info:
            JinaAuth(credentials)
        assert str(exc_info.value) == "Invalid auth type, Jina Reader auth type must be Bearer"

    def test_should_raise_error_for_missing_api_key(self):
        """Test that missing API key raises ValueError"""
        credentials = {"auth_type": "bearer", "config": {}}
        with pytest.raises(ValueError) as exc_info:
            JinaAuth(credentials)
        assert str(exc_info.value) == "No API key provided"

    def test_should_raise_error_for_missing_config(self):
        """Test that missing config section raises ValueError"""
        credentials = {"auth_type": "bearer"}
        with pytest.raises(ValueError) as exc_info:
            JinaAuth(credentials)
        assert str(exc_info.value) == "No API key provided"

    @patch("services.auth.jina.jina.requests.post")
    def test_should_validate_valid_credentials_successfully(self, mock_post):
        """Test successful credential validation"""
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_post.return_value = mock_response

        credentials = {"auth_type": "bearer", "config": {"api_key": "test_api_key_123"}}
        auth = JinaAuth(credentials)
        result = auth.validate_credentials()

        assert result is True
        mock_post.assert_called_once_with(
            "https://r.jina.ai",
            headers={"Content-Type": "application/json", "Authorization": "Bearer test_api_key_123"},
            json={"url": "https://example.com"},
        )

    @patch("services.auth.jina.jina.requests.post")
    def test_should_handle_http_402_error(self, mock_post):
        """Test handling of 402 Payment Required error"""
        mock_response = MagicMock()
        mock_response.status_code = 402
        mock_response.json.return_value = {"error": "Payment required"}
        mock_post.return_value = mock_response

        credentials = {"auth_type": "bearer", "config": {"api_key": "test_api_key_123"}}
        auth = JinaAuth(credentials)

        with pytest.raises(Exception) as exc_info:
            auth.validate_credentials()
        assert str(exc_info.value) == "Failed to authorize. Status code: 402. Error: Payment required"

    @patch("services.auth.jina.jina.requests.post")
    def test_should_handle_http_409_error(self, mock_post):
        """Test handling of 409 Conflict error"""
        mock_response = MagicMock()
        mock_response.status_code = 409
        mock_response.json.return_value = {"error": "Conflict error"}
        mock_post.return_value = mock_response

        credentials = {"auth_type": "bearer", "config": {"api_key": "test_api_key_123"}}
        auth = JinaAuth(credentials)

        with pytest.raises(Exception) as exc_info:
            auth.validate_credentials()
        assert str(exc_info.value) == "Failed to authorize. Status code: 409. Error: Conflict error"

    @patch("services.auth.jina.jina.requests.post")
    def test_should_handle_http_500_error(self, mock_post):
        """Test handling of 500 Internal Server Error"""
        mock_response = MagicMock()
        mock_response.status_code = 500
        mock_response.json.return_value = {"error": "Internal server error"}
        mock_post.return_value = mock_response

        credentials = {"auth_type": "bearer", "config": {"api_key": "test_api_key_123"}}
        auth = JinaAuth(credentials)

        with pytest.raises(Exception) as exc_info:
            auth.validate_credentials()
        assert str(exc_info.value) == "Failed to authorize. Status code: 500. Error: Internal server error"

    @patch("services.auth.jina.jina.requests.post")
    def test_should_handle_unexpected_error_with_text_response(self, mock_post):
        """Test handling of unexpected errors with text response"""
        mock_response = MagicMock()
        mock_response.status_code = 403
        mock_response.text = '{"error": "Forbidden"}'
        mock_response.json.side_effect = Exception("Not JSON")
        mock_post.return_value = mock_response

        credentials = {"auth_type": "bearer", "config": {"api_key": "test_api_key_123"}}
        auth = JinaAuth(credentials)

        with pytest.raises(Exception) as exc_info:
            auth.validate_credentials()
        assert str(exc_info.value) == "Failed to authorize. Status code: 403. Error: Forbidden"

    @patch("services.auth.jina.jina.requests.post")
    def test_should_handle_unexpected_error_without_text(self, mock_post):
        """Test handling of unexpected errors without text response"""
        mock_response = MagicMock()
        mock_response.status_code = 404
        mock_response.text = ""
        mock_response.json.side_effect = Exception("Not JSON")
        mock_post.return_value = mock_response

        credentials = {"auth_type": "bearer", "config": {"api_key": "test_api_key_123"}}
        auth = JinaAuth(credentials)

        with pytest.raises(Exception) as exc_info:
            auth.validate_credentials()
        assert str(exc_info.value) == "Unexpected error occurred while trying to authorize. Status code: 404"

    @patch("services.auth.jina.jina.requests.post")
    def test_should_handle_network_errors(self, mock_post):
        """Test handling of network connection errors"""
        mock_post.side_effect = requests.ConnectionError("Network error")

        credentials = {"auth_type": "bearer", "config": {"api_key": "test_api_key_123"}}
        auth = JinaAuth(credentials)

        with pytest.raises(requests.ConnectionError):
            auth.validate_credentials()

    def test_should_not_expose_api_key_in_error_messages(self):
        """Test that API key is not exposed in error messages"""
        credentials = {"auth_type": "bearer", "config": {"api_key": "super_secret_key_12345"}}
        auth = JinaAuth(credentials)

        # Verify API key is stored but not in any error message
        assert auth.api_key == "super_secret_key_12345"

        # Test various error scenarios don't expose the key
        with pytest.raises(ValueError) as exc_info:
            JinaAuth({"auth_type": "basic", "config": {"api_key": "super_secret_key_12345"}})
        assert "super_secret_key_12345" not in str(exc_info.value)

@ -0,0 +1,205 @@
|
|||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from services.auth.watercrawl.watercrawl import WatercrawlAuth
|
||||
|
||||
|
||||
class TestWatercrawlAuth:
|
||||
@pytest.fixture
|
||||
def valid_credentials(self):
|
||||
"""Fixture for valid x-api-key credentials"""
|
||||
return {"auth_type": "x-api-key", "config": {"api_key": "test_api_key_123"}}
|
||||
|
||||
@pytest.fixture
|
||||
def auth_instance(self, valid_credentials):
|
||||
"""Fixture for WatercrawlAuth instance with valid credentials"""
|
||||
return WatercrawlAuth(valid_credentials)
|
||||
|
||||
def test_should_initialize_with_valid_x_api_key_credentials(self, valid_credentials):
|
||||
"""Test successful initialization with valid x-api-key credentials"""
|
||||
auth = WatercrawlAuth(valid_credentials)
|
||||
assert auth.api_key == "test_api_key_123"
|
||||
assert auth.base_url == "https://app.watercrawl.dev"
|
||||
assert auth.credentials == valid_credentials
|
||||
|
||||
def test_should_initialize_with_custom_base_url(self):
|
||||
"""Test initialization with custom base URL"""
|
||||
credentials = {
|
||||
"auth_type": "x-api-key",
|
||||
"config": {"api_key": "test_api_key_123", "base_url": "https://custom.watercrawl.dev"},
|
||||
}
|
||||
auth = WatercrawlAuth(credentials)
|
||||
assert auth.api_key == "test_api_key_123"
|
||||
assert auth.base_url == "https://custom.watercrawl.dev"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("auth_type", "expected_error"),
|
||||
[
|
||||
("bearer", "Invalid auth type, WaterCrawl auth type must be x-api-key"),
|
||||
("basic", "Invalid auth type, WaterCrawl auth type must be x-api-key"),
|
||||
("", "Invalid auth type, WaterCrawl auth type must be x-api-key"),
|
||||
],
|
||||
)
|
||||
def test_should_raise_error_for_invalid_auth_type(self, auth_type, expected_error):
|
||||
"""Test that non-x-api-key auth types raise ValueError"""
|
||||
credentials = {"auth_type": auth_type, "config": {"api_key": "test_api_key_123"}}
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
WatercrawlAuth(credentials)
|
||||
assert str(exc_info.value) == expected_error
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("credentials", "expected_error"),
|
||||
[
|
||||
({"auth_type": "x-api-key", "config": {}}, "No API key provided"),
|
||||
({"auth_type": "x-api-key"}, "No API key provided"),
|
||||
({"auth_type": "x-api-key", "config": {"api_key": ""}}, "No API key provided"),
|
||||
({"auth_type": "x-api-key", "config": {"api_key": None}}, "No API key provided"),
|
||||
],
|
||||
)
|
||||
def test_should_raise_error_for_missing_api_key(self, credentials, expected_error):
|
||||
"""Test that missing or empty API key raises ValueError"""
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
WatercrawlAuth(credentials)
|
||||
assert str(exc_info.value) == expected_error
|
||||
|
||||
@patch("services.auth.watercrawl.watercrawl.requests.get")
|
||||
def test_should_validate_valid_credentials_successfully(self, mock_get, auth_instance):
|
||||
"""Test successful credential validation"""
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
result = auth_instance.validate_credentials()
|
||||
|
||||
assert result is True
|
||||
mock_get.assert_called_once_with(
|
||||
"https://app.watercrawl.dev/api/v1/core/crawl-requests/",
|
||||
headers={"Content-Type": "application/json", "X-API-KEY": "test_api_key_123"},
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("status_code", "error_message"),
|
||||
[
|
||||
(402, "Payment required"),
|
||||
(409, "Conflict error"),
|
||||
(500, "Internal server error"),
|
||||
],
|
||||
)
|
||||
@patch("services.auth.watercrawl.watercrawl.requests.get")
|
||||
def test_should_handle_http_errors(self, mock_get, status_code, error_message, auth_instance):
|
||||
"""Test handling of various HTTP error codes"""
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = status_code
|
||||
mock_response.json.return_value = {"error": error_message}
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
auth_instance.validate_credentials()
|
||||
assert str(exc_info.value) == f"Failed to authorize. Status code: {status_code}. Error: {error_message}"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("status_code", "response_text", "has_json_error", "expected_error_contains"),
|
||||
[
|
||||
(403, '{"error": "Forbidden"}', True, "Failed to authorize. Status code: 403. Error: Forbidden"),
|
||||
(404, "", True, "Unexpected error occurred while trying to authorize. Status code: 404"),
|
||||
(401, "Not JSON", True, "Expecting value"), # JSON decode error
|
||||
],
|
||||
)
|
||||
@patch("services.auth.watercrawl.watercrawl.requests.get")
|
||||
def test_should_handle_unexpected_errors(
|
||||
self, mock_get, status_code, response_text, has_json_error, expected_error_contains, auth_instance
|
||||
):
|
||||
"""Test handling of unexpected errors with various response formats"""
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = status_code
|
||||
mock_response.text = response_text
|
||||
if has_json_error:
|
||||
mock_response.json.side_effect = Exception("Not JSON")
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
auth_instance.validate_credentials()
|
||||
assert expected_error_contains in str(exc_info.value)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("exception_type", "exception_message"),
|
||||
[
|
||||
(requests.ConnectionError, "Network error"),
|
||||
(requests.Timeout, "Request timeout"),
|
||||
(requests.ReadTimeout, "Read timeout"),
|
||||
(requests.ConnectTimeout, "Connection timeout"),
|
||||
],
|
||||
)
|
||||
@patch("services.auth.watercrawl.watercrawl.requests.get")
|
||||
def test_should_handle_network_errors(self, mock_get, exception_type, exception_message, auth_instance):
|
||||
"""Test handling of various network-related errors including timeouts"""
|
||||
mock_get.side_effect = exception_type(exception_message)
|
||||
|
||||
with pytest.raises(exception_type) as exc_info:
|
||||
auth_instance.validate_credentials()
|
||||
assert exception_message in str(exc_info.value)
|
||||
|
||||
def test_should_not_expose_api_key_in_error_messages(self):
|
||||
"""Test that API key is not exposed in error messages"""
|
||||
credentials = {"auth_type": "x-api-key", "config": {"api_key": "super_secret_key_12345"}}
|
||||
auth = WatercrawlAuth(credentials)
|
||||
|
||||
# Verify API key is stored but not in any error message
|
||||
assert auth.api_key == "super_secret_key_12345"
|
||||
|
||||
# Test various error scenarios don't expose the key
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
WatercrawlAuth({"auth_type": "bearer", "config": {"api_key": "super_secret_key_12345"}})
|
||||
assert "super_secret_key_12345" not in str(exc_info.value)
|
||||
|
||||
@patch("services.auth.watercrawl.watercrawl.requests.get")
|
||||
def test_should_use_custom_base_url_in_validation(self, mock_get):
|
||||
"""Test that custom base URL is used in validation"""
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
credentials = {
|
||||
"auth_type": "x-api-key",
|
||||
"config": {"api_key": "test_api_key_123", "base_url": "https://custom.watercrawl.dev"},
|
||||
}
|
||||
auth = WatercrawlAuth(credentials)
|
||||
result = auth.validate_credentials()
|
||||
|
||||
assert result is True
|
||||
assert mock_get.call_args[0][0] == "https://custom.watercrawl.dev/api/v1/core/crawl-requests/"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("base_url", "expected_url"),
|
||||
[
|
||||
("https://app.watercrawl.dev", "https://app.watercrawl.dev/api/v1/core/crawl-requests/"),
|
||||
("https://app.watercrawl.dev/", "https://app.watercrawl.dev/api/v1/core/crawl-requests/"),
|
||||
("https://app.watercrawl.dev//", "https://app.watercrawl.dev/api/v1/core/crawl-requests/"),
|
||||
],
|
||||
)
|
||||
@patch("services.auth.watercrawl.watercrawl.requests.get")
|
||||
def test_should_use_urljoin_for_url_construction(self, mock_get, base_url, expected_url):
|
||||
"""Test that urljoin is used correctly for URL construction with various base URLs"""
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
credentials = {"auth_type": "x-api-key", "config": {"api_key": "test_api_key_123", "base_url": base_url}}
|
||||
auth = WatercrawlAuth(credentials)
|
||||
auth.validate_credentials()
|
||||
|
||||
# Verify the correct URL was called
|
||||
assert mock_get.call_args[0][0] == expected_url
|
||||
|
||||
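
    # Editor's aside: the trailing-slash cases above lean on urljoin semantics.
    # A minimal, runnable sketch of the normalization they assume; the helper
    # name below is hypothetical and not part of the module under test.
    from urllib.parse import urljoin

    def _build_crawl_requests_url(base_url: str) -> str:
        # Collapse trailing slashes so urljoin cannot emit a double slash,
        # then resolve the API path against the root.
        normalized = base_url.rstrip("/") + "/"
        return urljoin(normalized, "api/v1/core/crawl-requests/")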
@patch("services.auth.watercrawl.watercrawl.requests.get")
|
||||
def test_should_handle_timeout_with_retry_suggestion(self, mock_get, auth_instance):
|
||||
"""Test that timeout errors are handled gracefully with appropriate error message"""
|
||||
mock_get.side_effect = requests.Timeout("The request timed out after 30 seconds")
|
||||
|
||||
with pytest.raises(requests.Timeout) as exc_info:
|
||||
auth_instance.validate_credentials()
|
||||
|
||||
# Verify the timeout exception is raised with original message
|
||||
assert "timed out" in str(exc_info.value)
|
||||
|
|
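
# Editor's aside: for orientation, a minimal sketch of the contract the tests
# above pin down, inferred from their assertions alone; the real
# services.auth.watercrawl implementation may differ.
import json
import requests
from urllib.parse import urljoin

class WatercrawlAuth:
    def __init__(self, credentials: dict):
        config = credentials.get("config") or {}
        if credentials.get("auth_type") != "x-api-key":
            raise ValueError("Invalid auth type")  # must never echo the key itself
        self.api_key = config.get("api_key")
        if not self.api_key:
            raise ValueError("No API key provided")
        self.base_url = config.get("base_url") or "https://app.watercrawl.dev"

    def validate_credentials(self) -> bool:
        url = urljoin(self.base_url.rstrip("/") + "/", "api/v1/core/crawl-requests/")
        response = requests.get(url, headers={"Content-Type": "application/json", "X-API-KEY": self.api_key})
        if response.status_code == 200:
            return True
        if not response.text:
            raise Exception(f"Unexpected error occurred while trying to authorize. Status code: {response.status_code}")
        data = json.loads(response.text)  # a non-JSON body surfaces json's "Expecting value" error
        raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {data['error']}")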

@@ -102,17 +102,16 @@ class TestDatasetServiceUpdateDataset:
            patch("services.dataset_service.DatasetService.get_dataset") as mock_get_dataset,
            patch("services.dataset_service.DatasetService.check_dataset_permission") as mock_check_perm,
            patch("extensions.ext_database.db.session") as mock_db,
            patch("services.dataset_service.datetime") as mock_datetime,
            patch("services.dataset_service.naive_utc_now") as mock_naive_utc_now,
        ):
            current_time = datetime.datetime(2023, 1, 1, 12, 0, 0)
            mock_datetime.datetime.now.return_value = current_time
            mock_datetime.UTC = datetime.UTC
            mock_naive_utc_now.return_value = current_time

            yield {
                "get_dataset": mock_get_dataset,
                "check_permission": mock_check_perm,
                "db_session": mock_db,
                "datetime": mock_datetime,
                "naive_utc_now": mock_naive_utc_now,
                "current_time": current_time,
            }

@@ -292,7 +291,7 @@ class TestDatasetServiceUpdateDataset:
                "embedding_model_provider": "openai",
                "embedding_model": "text-embedding-ada-002",
                "updated_by": user.id,
                "updated_at": mock_dataset_service_dependencies["current_time"].replace(tzinfo=None),
                "updated_at": mock_dataset_service_dependencies["current_time"],
            }

            self._assert_database_update_called(

@@ -327,7 +326,7 @@ class TestDatasetServiceUpdateDataset:
                "indexing_technique": "high_quality",
                "retrieval_model": "new_model",
                "updated_by": user.id,
                "updated_at": mock_dataset_service_dependencies["current_time"].replace(tzinfo=None),
                "updated_at": mock_dataset_service_dependencies["current_time"],
            }

            actual_call_args = mock_dataset_service_dependencies[

@@ -365,7 +364,7 @@ class TestDatasetServiceUpdateDataset:
                "collection_binding_id": None,
                "retrieval_model": "new_model",
                "updated_by": user.id,
                "updated_at": mock_dataset_service_dependencies["current_time"].replace(tzinfo=None),
                "updated_at": mock_dataset_service_dependencies["current_time"],
            }

            self._assert_database_update_called(

@@ -422,7 +421,7 @@ class TestDatasetServiceUpdateDataset:
                "collection_binding_id": "binding-456",
                "retrieval_model": "new_model",
                "updated_by": user.id,
                "updated_at": mock_dataset_service_dependencies["current_time"].replace(tzinfo=None),
                "updated_at": mock_dataset_service_dependencies["current_time"],
            }

            self._assert_database_update_called(

@@ -463,7 +462,7 @@ class TestDatasetServiceUpdateDataset:
                "collection_binding_id": "binding-123",
                "retrieval_model": "new_model",
                "updated_by": user.id,
                "updated_at": mock_dataset_service_dependencies["current_time"].replace(tzinfo=None),
                "updated_at": mock_dataset_service_dependencies["current_time"],
            }

            self._assert_database_update_called(

@@ -525,7 +524,7 @@ class TestDatasetServiceUpdateDataset:
                "collection_binding_id": "binding-789",
                "retrieval_model": "new_model",
                "updated_by": user.id,
                "updated_at": mock_dataset_service_dependencies["current_time"].replace(tzinfo=None),
                "updated_at": mock_dataset_service_dependencies["current_time"],
            }

            self._assert_database_update_called(

@@ -568,7 +567,7 @@ class TestDatasetServiceUpdateDataset:
                "collection_binding_id": "binding-123",
                "retrieval_model": "new_model",
                "updated_by": user.id,
                "updated_at": mock_dataset_service_dependencies["current_time"].replace(tzinfo=None),
                "updated_at": mock_dataset_service_dependencies["current_time"],
            }

            self._assert_database_update_called(
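
# Editor's aside: these hunks replace `.replace(tzinfo=None)` expectations with
# the patched naive_utc_now helper. A plausible definition, assumed here for
# illustration only (the real helper lives in the api code base):
import datetime

def naive_utc_now() -> datetime.datetime:
    # Current UTC time with tzinfo stripped, suitable for naive DateTime columns.
    return datetime.datetime.now(datetime.UTC).replace(tzinfo=None)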

@@ -283,11 +283,12 @@ REDIS_CLUSTERS_PASSWORD=
# Celery Configuration
# ------------------------------

# Use redis as the broker, and redis db 1 for celery broker.
# Format as follows: `redis://<redis_username>:<redis_password>@<redis_host>:<redis_port>/<redis_database>`
# Use standalone redis as the broker, and redis db 1 for celery broker. (redis_username is usually empty by default)
# Format as follows: `redis://<redis_username>:<redis_password>@<redis_host>:<redis_port>/<redis_database>`.
# Example: redis://:difyai123456@redis:6379/1
# If using Redis Sentinel, format as follows: `sentinel://<sentinel_username>:<sentinel_password>@<sentinel_host>:<sentinel_port>/<redis_database>`
# Example: sentinel://localhost:26379/1;sentinel://localhost:26380/1;sentinel://localhost:26381/1
# If using Redis Sentinel, format as follows: `sentinel://<redis_username>:<redis_password>@<sentinel_host1>:<sentinel_port>/<redis_database>`
# For high availability, you can configure multiple Sentinel nodes (if provided), separated by semicolons as in the example below:
# Example: sentinel://:difyai123456@localhost:26379/1;sentinel://:difyai12345@localhost:26379/1;sentinel://:difyai12345@localhost:26379/1
CELERY_BROKER_URL=redis://:difyai123456@redis:6379/1
CELERY_BACKEND=redis
BROKER_USE_SSL=false
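
For context, a hedged sketch of handing the same Sentinel-style broker URL to Celery directly; the hosts, password, and master name below are illustrative placeholders, not Dify's actual wiring:

import os
from celery import Celery

app = Celery(
    "dify",
    broker=os.environ.get(
        "CELERY_BROKER_URL",
        "sentinel://:difyai123456@localhost:26379/1;sentinel://:difyai123456@localhost:26380/1",
    ),
)
# The redis-sentinel transport also needs the name of the monitored master:
app.conf.broker_transport_options = {"master_name": "mymaster"}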

@@ -412,6 +413,8 @@ SUPABASE_URL=your-server-url
# The type of vector store to use.
# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`, `opengauss`, `tablestore`, `vastbase`, `tidb`, `tidb_on_qdrant`, `baidu`, `lindorm`, `huawei_cloud`, `upstash`, `matrixone`.
VECTOR_STORE=weaviate
# Prefix used to create collection name in vector database
VECTOR_INDEX_NAME_PREFIX=Vector_index

# The Weaviate endpoint URL. Only available when VECTOR_STORE is `weaviate`.
WEAVIATE_ENDPOINT=http://weaviate:8080

@@ -136,6 +136,7 @@ x-shared-env: &shared-api-worker-env
      SUPABASE_API_KEY: ${SUPABASE_API_KEY:-your-access-key}
      SUPABASE_URL: ${SUPABASE_URL:-your-server-url}
      VECTOR_STORE: ${VECTOR_STORE:-weaviate}
      VECTOR_INDEX_NAME_PREFIX: ${VECTOR_INDEX_NAME_PREFIX:-Vector_index}
      WEAVIATE_ENDPOINT: ${WEAVIATE_ENDPOINT:-http://weaviate:8080}
      WEAVIATE_API_KEY: ${WEAVIATE_API_KEY:-WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih}
      QDRANT_URL: ${QDRANT_URL:-http://qdrant:6333}

@@ -1 +1,7 @@
from dify_client.client import ChatClient, CompletionClient, WorkflowClient, KnowledgeBaseClient, DifyClient
from dify_client.client import (
    ChatClient,
    CompletionClient,
    WorkflowClient,
    KnowledgeBaseClient,
    DifyClient,
)
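
A quick usage sketch of the re-exported clients. The constructor and method shown match the SDK's typical shape but are stated as assumptions here, not verified API:

from dify_client import ChatClient

client = ChatClient(api_key="your-app-api-key")
# Blocking mode returns the whole answer in a single response.
response = client.create_chat_message(inputs={}, query="Hello", user="user-1", response_mode="blocking")
response.raise_for_status()
print(response.json()["answer"])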

@@ -6,7 +6,7 @@ LABEL maintainer="takatost@gmail.com"
# RUN sed -i 's/dl-cdn.alpinelinux.org/mirrors.aliyun.com/g' /etc/apk/repositories

RUN apk add --no-cache tzdata
RUN npm install -g pnpm@10.11.1
RUN npm install -g pnpm@10.13.1
ENV PNPM_HOME="/pnpm"
ENV PATH="$PNPM_HOME:$PATH"

@@ -0,0 +1,41 @@
'use client'
import { useTranslation } from 'react-i18next'
import {
  RiAddLine,
  RiArrowRightLine,
} from '@remixicon/react'
import Link from 'next/link'

type CreateAppCardProps = {
  ref?: React.Ref<HTMLAnchorElement>
}

const CreateAppCard = ({ ref }: CreateAppCardProps) => {
  const { t } = useTranslation()

  return (
    <div className='bg-background-default-dimm flex min-h-[160px] flex-col rounded-xl border-[0.5px]
      border-components-panel-border transition-all duration-200 ease-in-out'
    >
      <Link ref={ref} className='group flex grow cursor-pointer items-start p-4' href='/datasets/create'>
        <div className='flex items-center gap-3'>
          <div className='flex h-10 w-10 items-center justify-center rounded-lg border border-dashed border-divider-regular bg-background-default-lighter
            p-2 group-hover:border-solid group-hover:border-effects-highlight group-hover:bg-background-default-dodge'
          >
            <RiAddLine className='h-4 w-4 text-text-tertiary group-hover:text-text-accent' />
          </div>
          <div className='system-md-semibold text-text-secondary group-hover:text-text-accent'>{t('dataset.createDataset')}</div>
        </div>
      </Link>
      <div className='system-xs-regular p-4 pt-0 text-text-tertiary'>{t('dataset.createDatasetIntro')}</div>
      <Link className='group flex cursor-pointer items-center gap-1 rounded-b-xl border-t-[0.5px] border-divider-subtle p-4' href='/datasets/connect'>
        <div className='system-xs-medium text-text-tertiary group-hover:text-text-accent'>{t('dataset.connectDataset')}</div>
        <RiArrowRightLine className='h-3.5 w-3.5 text-text-tertiary group-hover:text-text-accent' />
      </Link>
    </div>
  )
}

CreateAppCard.displayName = 'CreateAppCard'

export default CreateAppCard

@@ -1,6 +1,6 @@
import React from 'react'
import type { ReactNode } from 'react'
import SwrInitor from '@/app/components/swr-initor'
import SwrInitializer from '@/app/components/swr-initializer'
import { AppContextProvider } from '@/context/app-context'
import GA, { GaType } from '@/app/components/base/ga'
import HeaderWrapper from '@/app/components/header/header-wrapper'

@@ -13,7 +13,7 @@ const Layout = ({ children }: { children: ReactNode }) => {
  return (
    <>
      <GA gaType={GaType.admin} />
      <SwrInitor>
      <SwrInitializer>
        <AppContextProvider>
          <EventEmitterContextProvider>
            <ProviderContextProvider>

@@ -26,7 +26,7 @@ const Layout = ({ children }: { children: ReactNode }) => {
            </ProviderContextProvider>
          </EventEmitterContextProvider>
        </AppContextProvider>
      </SwrInitor>
      </SwrInitializer>
    </>
  )
}

@@ -1,5 +1,6 @@
'use client'
import { useState } from 'react'
import useSWR from 'swr'
import { useTranslation } from 'react-i18next'
import {
  RiGraduationCapFill,

@@ -22,6 +23,8 @@ import PremiumBadge from '@/app/components/base/premium-badge'
import { useGlobalPublicStore } from '@/context/global-public-context'
import EmailChangeModal from './email-change-modal'
import { validPassword } from '@/config'
import { fetchAppList } from '@/service/apps'
import type { App } from '@/types/app'

const titleClassName = `
  system-sm-semibold text-text-secondary

@@ -33,7 +36,9 @@ const descriptionClassName = `
export default function AccountPage() {
  const { t } = useTranslation()
  const { systemFeatures } = useGlobalPublicStore()
  const { mutateUserProfile, userProfile, apps } = useAppContext()
  const { data: appList } = useSWR({ url: '/apps', params: { page: 1, limit: 100, name: '' } }, fetchAppList)
  const apps = appList?.data || []
  const { mutateUserProfile, userProfile } = useAppContext()
  const { isEducationAccount } = useProviderContext()
  const { notify } = useContext(ToastContext)
  const [editNameModalVisible, setEditNameModalVisible] = useState(false)

@@ -202,7 +207,7 @@ export default function AccountPage() {
        {!!apps.length && (
          <Collapse
            title={`${t('common.account.showAppLength', { length: apps.length })}`}
            items={apps.map(app => ({ ...app, key: app.id, name: app.name }))}
            items={apps.map((app: App) => ({ ...app, key: app.id, name: app.name }))}
            renderItem={renderAppItem}
            wrapperClassName='mt-2'
          />

@@ -1,7 +1,7 @@
import React from 'react'
import type { ReactNode } from 'react'
import Header from './header'
import SwrInitor from '@/app/components/swr-initor'
import SwrInitor from '@/app/components/swr-initializer'
import { AppContextProvider } from '@/context/app-context'
import GA, { GaType } from '@/app/components/base/ga'
import HeaderWrapper from '@/app/components/header/header-wrapper'

@@ -1,6 +1,6 @@
import { useTranslation } from 'react-i18next'
import { useRouter } from 'next/navigation'
import { useContext, useContextSelector } from 'use-context-selector'
import { useContext } from 'use-context-selector'
import React, { useCallback, useState } from 'react'
import {
  RiDeleteBinLine,

@@ -15,7 +15,7 @@ import AppIcon from '../base/app-icon'
import cn from '@/utils/classnames'
import { useStore as useAppStore } from '@/app/components/app/store'
import { ToastContext } from '@/app/components/base/toast'
import AppsContext, { useAppContext } from '@/context/app-context'
import { useAppContext } from '@/context/app-context'
import { useProviderContext } from '@/context/provider-context'
import { copyApp, deleteApp, exportAppConfig, updateAppInfo } from '@/service/apps'
import type { DuplicateAppModalProps } from '@/app/components/app/duplicate-modal'

@@ -73,11 +73,6 @@ const AppInfo = ({ expand, onlyShowDetail = false, openState = false, onDetailEx
  const [showImportDSLModal, setShowImportDSLModal] = useState<boolean>(false)
  const [secretEnvList, setSecretEnvList] = useState<EnvironmentVariable[]>([])

  const mutateApps = useContextSelector(
    AppsContext,
    state => state.mutateApps,
  )

  const onEdit: CreateAppModalProps['onConfirm'] = useCallback(async ({
    name,
    icon_type,

@@ -106,12 +101,11 @@ const AppInfo = ({ expand, onlyShowDetail = false, openState = false, onDetailEx
        message: t('app.editDone'),
      })
      setAppDetail(app)
      mutateApps()
    }
    catch {
      notify({ type: 'error', message: t('app.editFailed') })
    }
  }, [appDetail, mutateApps, notify, setAppDetail, t])
  }, [appDetail, notify, setAppDetail, t])

  const onCopy: DuplicateAppModalProps['onConfirm'] = async ({ name, icon_type, icon, icon_background }) => {
    if (!appDetail)

@@ -131,7 +125,6 @@ const AppInfo = ({ expand, onlyShowDetail = false, openState = false, onDetailEx
        message: t('app.newApp.appCreated'),
      })
      localStorage.setItem(NEED_REFRESH_APP_LIST_KEY, '1')
      mutateApps()
      onPlanInfoChanged()
      getRedirection(true, newApp, replace)
    }

@@ -186,7 +179,6 @@ const AppInfo = ({ expand, onlyShowDetail = false, openState = false, onDetailEx
    try {
      await deleteApp(appDetail.id)
      notify({ type: 'success', message: t('app.appDeleted') })
      mutateApps()
      onPlanInfoChanged()
      setAppDetail()
      replace('/apps')

@@ -198,7 +190,7 @@ const AppInfo = ({ expand, onlyShowDetail = false, openState = false, onDetailEx
      })
    }
    setShowConfirmDelete(false)
  }, [appDetail, mutateApps, notify, onPlanInfoChanged, replace, setAppDetail, t])
  }, [appDetail, notify, onPlanInfoChanged, replace, setAppDetail, t])

  const { isCurrentWorkspaceEditor } = useAppContext()

@@ -13,7 +13,6 @@ import Loading from '@/app/components/base/loading'
import Badge from '@/app/components/base/badge'
import { useKnowledge } from '@/hooks/use-knowledge'
import cn from '@/utils/classnames'
import { basePath } from '@/utils/var'
import AppIcon from '@/app/components/base/app-icon'

export type ISelectDataSetProps = {

@@ -113,7 +112,7 @@ const SelectDataSet: FC<ISelectDataSetProps> = ({
            }}
          >
            <span className='text-text-tertiary'>{t('appDebug.feature.dataSet.noDataSet')}</span>
            <Link href={`${basePath}/datasets/create`} className='font-normal text-text-accent'>{t('appDebug.feature.dataSet.toCreate')}</Link>
            <Link href='/datasets/create' className='font-normal text-text-accent'>{t('appDebug.feature.dataSet.toCreate')}</Link>
          </div>
        )}

@@ -4,7 +4,7 @@ import { useCallback, useRef, useState } from 'react'
import { useTranslation } from 'react-i18next'

import { useRouter } from 'next/navigation'
import { useContext, useContextSelector } from 'use-context-selector'
import { useContext } from 'use-context-selector'
import { RiArrowRightLine, RiArrowRightSLine, RiCommandLine, RiCornerDownLeftLine, RiExchange2Fill } from '@remixicon/react'
import Link from 'next/link'
import { useDebounceFn, useKeyPress } from 'ahooks'

@@ -15,7 +15,7 @@ import Button from '@/app/components/base/button'
import Divider from '@/app/components/base/divider'
import cn from '@/utils/classnames'
import { basePath } from '@/utils/var'
import AppsContext, { useAppContext } from '@/context/app-context'
import { useAppContext } from '@/context/app-context'
import { useProviderContext } from '@/context/provider-context'
import { ToastContext } from '@/app/components/base/toast'
import type { AppMode } from '@/types/app'

@@ -41,7 +41,6 @@ function CreateApp({ onClose, onSuccess, onCreateFromTemplate }: CreateAppProps)
  const { t } = useTranslation()
  const { push } = useRouter()
  const { notify } = useContext(ToastContext)
  const mutateApps = useContextSelector(AppsContext, state => state.mutateApps)

  const [appMode, setAppMode] = useState<AppMode>('advanced-chat')
  const [appIcon, setAppIcon] = useState<AppIconSelection>({ type: 'emoji', icon: '🤖', background: '#FFEAD5' })

@@ -80,7 +79,6 @@ function CreateApp({ onClose, onSuccess, onCreateFromTemplate }: CreateAppProps)
      notify({ type: 'success', message: t('app.newApp.appCreated') })
      onSuccess()
      onClose()
      mutateApps()
      localStorage.setItem(NEED_REFRESH_APP_LIST_KEY, '1')
      getRedirection(isCurrentWorkspaceEditor, app, push)
    }

@@ -88,7 +86,7 @@ function CreateApp({ onClose, onSuccess, onCreateFromTemplate }: CreateAppProps)
      notify({ type: 'error', message: t('app.newApp.appCreateFailed') })
    }
    isCreatingRef.current = false
  }, [name, notify, t, appMode, appIcon, description, onSuccess, onClose, mutateApps, push, isCurrentWorkspaceEditor])
  }, [name, notify, t, appMode, appIcon, description, onSuccess, onClose, push, isCurrentWorkspaceEditor])

  const { run: handleCreateApp } = useDebounceFn(onCreate, { wait: 300 })
  useKeyPress(['meta.enter', 'ctrl.enter'], () => {

@@ -298,7 +296,7 @@ function AppTypeCard({ icon, title, description, active, onClick }: AppTypeCardP
    >
      {icon}
      <div className='system-sm-semibold mb-0.5 mt-2 text-text-secondary'>{title}</div>
      <div className='system-xs-regular text-text-tertiary'>{description}</div>
      <div className='system-xs-regular line-clamp-2 text-text-tertiary' title={description}>{description}</div>
    </div>
}

@@ -90,10 +90,10 @@ const Embedded = ({ siteInfo, isShow, onClose, appBaseUrl, accessToken, classNam
  const [option, setOption] = useState<Option>('iframe')
  const [isCopied, setIsCopied] = useState<OptionStatus>({ iframe: false, scripts: false, chromePlugin: false })

  const { langeniusVersionInfo } = useAppContext()
  const { langGeniusVersionInfo } = useAppContext()
  const themeBuilder = useThemeContext()
  themeBuilder.buildTheme(siteInfo?.chat_color_theme ?? null, siteInfo?.chat_color_theme_inverted ?? false)
  const isTestEnv = langeniusVersionInfo.current_env === 'TESTING' || langeniusVersionInfo.current_env === 'DEVELOPMENT'
  const isTestEnv = langGeniusVersionInfo.current_env === 'TESTING' || langGeniusVersionInfo.current_env === 'DEVELOPMENT'
  const onClickCopy = () => {
    if (option === 'chromePlugin') {
      const splitUrl = OPTION_MAP[option].getContent(appBaseUrl, accessToken).split(': ')

@@ -1,7 +1,7 @@
'use client'

import React, { useCallback, useEffect, useMemo, useState } from 'react'
import { useContext, useContextSelector } from 'use-context-selector'
import { useContext } from 'use-context-selector'
import { useRouter } from 'next/navigation'
import { useTranslation } from 'react-i18next'
import { RiBuildingLine, RiGlobalLine, RiLockLine, RiMoreFill, RiVerifiedBadgeLine } from '@remixicon/react'

@@ -11,7 +11,7 @@ import Toast, { ToastContext } from '@/app/components/base/toast'
import { copyApp, deleteApp, exportAppConfig, updateAppInfo } from '@/service/apps'
import type { DuplicateAppModalProps } from '@/app/components/app/duplicate-modal'
import AppIcon from '@/app/components/base/app-icon'
import AppsContext, { useAppContext } from '@/context/app-context'
import { useAppContext } from '@/context/app-context'
import type { HtmlContentProps } from '@/app/components/base/popover'
import CustomPopover from '@/app/components/base/popover'
import Divider from '@/app/components/base/divider'

@@ -65,11 +65,6 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => {
  const { onPlanInfoChanged } = useProviderContext()
  const { push } = useRouter()

  const mutateApps = useContextSelector(
    AppsContext,
    state => state.mutateApps,
  )

  const [showEditModal, setShowEditModal] = useState(false)
  const [showDuplicateModal, setShowDuplicateModal] = useState(false)
  const [showSwitchModal, setShowSwitchModal] = useState<boolean>(false)

@@ -83,7 +78,6 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => {
      notify({ type: 'success', message: t('app.appDeleted') })
      if (onRefresh)
        onRefresh()
      mutateApps()
      onPlanInfoChanged()
    }
    catch (e: any) {

@@ -93,7 +87,7 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => {
      })
    }
    setShowConfirmDelete(false)
  }, [app.id, mutateApps, notify, onPlanInfoChanged, onRefresh, t])
  }, [app.id, notify, onPlanInfoChanged, onRefresh, t])

  const onEdit: CreateAppModalProps['onConfirm'] = useCallback(async ({
    name,

@@ -122,12 +116,11 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => {
      })
      if (onRefresh)
        onRefresh()
      mutateApps()
    }
    catch {
      notify({ type: 'error', message: t('app.editFailed') })
    }
  }, [app.id, mutateApps, notify, onRefresh, t])
  }, [app.id, notify, onRefresh, t])

  const onCopy: DuplicateAppModalProps['onConfirm'] = async ({ name, icon_type, icon, icon_background }) => {
    try {

@@ -147,7 +140,6 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => {
      localStorage.setItem(NEED_REFRESH_APP_LIST_KEY, '1')
      if (onRefresh)
        onRefresh()
      mutateApps()
      onPlanInfoChanged()
      getRedirection(isCurrentWorkspaceEditor, newApp, push)
    }

@@ -195,16 +187,14 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => {
  const onSwitch = () => {
    if (onRefresh)
      onRefresh()
    mutateApps()
    setShowSwitchModal(false)
  }

  const onUpdateAccessControl = useCallback(() => {
    if (onRefresh)
      onRefresh()
    mutateApps()
    setShowAccessControl(false)
  }, [onRefresh, mutateApps, setShowAccessControl])
  }, [onRefresh, setShowAccessControl])

  const Operations = (props: HtmlContentProps) => {
    const { data: userCanAccessApp, isLoading: isGettingUserCanAccessApp } = useGetUserCanAccessApp({ appId: app?.id, enabled: (!!props?.open && systemFeatures.webapp_auth.enabled) })

@@ -325,7 +315,6 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => {
      dateFormat: `${t('datasetDocuments.segment.dateTimeFormat')}`,
    })
    return `${t('datasetDocuments.segment.editedAt')} ${timeText}`
    // eslint-disable-next-line react-hooks/exhaustive-deps
  }, [app.updated_at, app.created_at])

  return (

@@ -45,7 +45,7 @@ const Avatar = ({
        className={cn(textClassName, 'scale-[0.4] text-center text-white')}
        style={style}
      >
        {name[0].toLocaleUpperCase()}
        {name && name[0].toLocaleUpperCase()}
      </div>
    </div>
  )

@@ -117,7 +117,7 @@ const Question: FC<QuestionProps> = ({
        </div>
        <div
          ref={contentRef}
          className='bg-background-gradient-bg-fill-chat-bubble-bg-3 w-full rounded-2xl px-4 py-3 text-sm text-text-primary'
          className='w-full rounded-2xl bg-background-gradient-bg-fill-chat-bubble-bg-3 px-4 py-3 text-sm text-text-primary'
          style={theme?.chatBubbleColorStyle ? CssTransform(theme.chatBubbleColorStyle) : {}}
        >
          {

@@ -21,7 +21,7 @@ const AppsFull: FC<{ loc: string; className?: string; }> = ({
}) => {
  const { t } = useTranslation()
  const { plan } = useProviderContext()
  const { userProfile, langeniusVersionInfo } = useAppContext()
  const { userProfile, langGeniusVersionInfo } = useAppContext()
  const isTeam = plan.type === Plan.team
  const usage = plan.usage.buildApps
  const total = plan.total.buildApps

@@ -62,7 +62,7 @@ const AppsFull: FC<{ loc: string; className?: string; }> = ({
      )}
      {plan.type !== Plan.sandbox && plan.type !== Plan.professional && (
        <Button variant='secondary-accent'>
          <a target='_blank' rel='noopener noreferrer' href={mailToSupport(userProfile.email, plan.type, langeniusVersionInfo.current_version)}>
          <a target='_blank' rel='noopener noreferrer' href={mailToSupport(userProfile.email, plan.type, langGeniusVersionInfo.current_version)}>
            {t('billing.apps.contactUs')}
          </a>
        </Button>

@@ -43,10 +43,10 @@ Object.defineProperty(globalThis, 'sessionStorage', {
  value: sessionStorage,
})

const BrowserInitor = ({
const BrowserInitializer = ({
  children,
}: { children: React.ReactNode }) => {
}: { children: React.ReactElement }) => {
  return children
}

export default BrowserInitor
export default BrowserInitializer

@@ -83,7 +83,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
      - <code>subchunk_segmentation</code> (object) Child chunk rules
        - <code>separator</code> Segmentation identifier. Currently only one delimiter is allowed. The default is <code>***</code>
        - <code>max_tokens</code> The maximum length (in tokens) must be validated as shorter than the length of the parent chunk
        - <code>chunk_overlap</code> Defines the duplication between adjacent chunks (optional)
        - <code>chunk_overlap</code> Defines the overlap between adjacent chunks (optional)
    </Property>
    <PropertyInstruction>If no parameters are set for the knowledge base, the first upload must supply the following parameters; if they are omitted, the default parameters are used.</PropertyInstruction>
    <Property name='retrieval_model' type='object' key='retrieval_model'>

@@ -218,7 +218,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
      - <code>subchunk_segmentation</code> (object) Child chunk rules
        - <code>separator</code> Segmentation identifier. Currently only one delimiter is allowed. The default is <code>***</code>
        - <code>max_tokens</code> The maximum length (in tokens) must be validated as shorter than the length of the parent chunk
        - <code>chunk_overlap</code> Defines the duplication between adjacent chunks (optional)
        - <code>chunk_overlap</code> Defines the overlap between adjacent chunks (optional)
    </Property>
    <Property name='file' type='multipart/form-data' key='file'>
      The file that needs to be uploaded.

@@ -555,7 +555,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
      - <code>subchunk_segmentation</code> (object) Child chunk rules
        - <code>separator</code> Segmentation identifier. Currently only one delimiter is allowed. The default is <code>***</code>
        - <code>max_tokens</code> The maximum length (in tokens) must be validated as shorter than the length of the parent chunk
        - <code>chunk_overlap</code> Defines the duplication between adjacent chunks (optional)
        - <code>chunk_overlap</code> Defines the overlap between adjacent chunks (optional)
    </Property>
  </Properties>
</Col>

@@ -657,7 +657,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
      - <code>subchunk_segmentation</code> (object) Child chunk rules
        - <code>separator</code> Segmentation identifier. Currently only one delimiter is allowed. The default is <code>***</code>
        - <code>max_tokens</code> The maximum length (in tokens) must be validated as shorter than the length of the parent chunk
        - <code>chunk_overlap</code> Defines the duplication between adjacent chunks (optional)
        - <code>chunk_overlap</code> Defines the overlap between adjacent chunks (optional)
    </Property>
  </Properties>
</Col>

@@ -333,7 +333,7 @@ Workflow applications have no conversation support; they suit translation, article writing, summarization AI, and similar uses
<Col>
  Fetches the current execution result of a workflow task by its workflow run ID
  ### Path
  - `workflow_run_id` (string) workflow_run_id, obtainable from the streamed response chunks
  - `workflow_run_id` (string) the workflow run ID, obtainable from the streamed response chunks
  ### Response
  - `id` (string) the workflow run ID
  - `workflow_id` (string) the associated Workflow ID
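
A hedged Python sketch of calling the endpoint this section documents; the base URL, key, and run ID are placeholders:

import requests

resp = requests.get(
    "https://api.dify.ai/v1/workflows/run/<workflow_run_id>",
    headers={"Authorization": "Bearer <app-api-key>"},
)
resp.raise_for_status()
run = resp.json()
print(run["id"], run["status"])  # run ID and current status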

@@ -12,16 +12,16 @@ import { noop } from 'lodash-es'
import { useGlobalPublicStore } from '@/context/global-public-context'

type IAccountSettingProps = {
  langeniusVersionInfo: LangGeniusVersionResponse
  langGeniusVersionInfo: LangGeniusVersionResponse
  onCancel: () => void
}

export default function AccountAbout({
  langeniusVersionInfo,
  langGeniusVersionInfo,
  onCancel,
}: IAccountSettingProps) {
  const { t } = useTranslation()
  const isLatest = langeniusVersionInfo.current_version === langeniusVersionInfo.latest_version
  const isLatest = langGeniusVersionInfo.current_version === langGeniusVersionInfo.latest_version
  const systemFeatures = useGlobalPublicStore(s => s.systemFeatures)

  return (

@@ -43,7 +43,7 @@ export default function AccountAbout({
          />
          : <DifyLogo size='large' className='mx-auto' />}

        <div className='text-center text-xs font-normal text-text-tertiary'>Version {langeniusVersionInfo?.current_version}</div>
        <div className='text-center text-xs font-normal text-text-tertiary'>Version {langGeniusVersionInfo?.current_version}</div>
        <div className='flex flex-col items-center gap-2 text-center text-xs font-normal text-text-secondary'>
          <div>© {dayjs().year()} LangGenius, Inc., Contributors.</div>
          <div className='text-text-accent'>

@@ -63,8 +63,8 @@ export default function AccountAbout({
          <div className='text-xs font-medium text-text-tertiary'>
            {
              isLatest
                ? t('common.about.latestAvailable', { version: langeniusVersionInfo.latest_version })
                : t('common.about.nowAvailable', { version: langeniusVersionInfo.latest_version })
                ? t('common.about.latestAvailable', { version: langGeniusVersionInfo.latest_version })
                : t('common.about.nowAvailable', { version: langGeniusVersionInfo.latest_version })
            }
          </div>
          <div className='flex items-center'>

@@ -80,7 +80,7 @@ export default function AccountAbout({
            !isLatest && !IS_CE_EDITION && (
              <Button variant='primary' size='small'>
                <Link
                  href={langeniusVersionInfo.release_notes}
                  href={langGeniusVersionInfo.release_notes}
                  target='_blank' rel='noopener noreferrer'
                >
                  {t('common.about.updateNow')}

@@ -45,7 +45,7 @@ export default function AppSelector() {

  const { t } = useTranslation()
  const docLink = useDocLink()
  const { userProfile, langeniusVersionInfo, isCurrentWorkspaceOwner } = useAppContext()
  const { userProfile, langGeniusVersionInfo, isCurrentWorkspaceOwner } = useAppContext()
  const { isEducationAccount } = useProviderContext()
  const { setShowAccountSettingModal } = useModalContext()

@@ -180,8 +180,8 @@ export default function AppSelector() {
            <RiInformation2Line className='size-4 shrink-0 text-text-tertiary' />
            <div className='system-md-regular grow px-1 text-text-secondary'>{t('common.userProfile.about')}</div>
            <div className='flex shrink-0 items-center'>
              <div className='system-xs-regular mr-2 text-text-tertiary'>{langeniusVersionInfo.current_version}</div>
              <Indicator color={langeniusVersionInfo.current_version === langeniusVersionInfo.latest_version ? 'green' : 'orange'} />
              <div className='system-xs-regular mr-2 text-text-tertiary'>{langGeniusVersionInfo.current_version}</div>
              <Indicator color={langGeniusVersionInfo.current_version === langGeniusVersionInfo.latest_version ? 'green' : 'orange'} />
            </div>
          </div>
        </MenuItem>

@@ -217,7 +217,7 @@ export default function AppSelector() {
        }
      </Menu>
      {
        aboutVisible && <AccountAbout onCancel={() => setAboutVisible(false)} langeniusVersionInfo={langeniusVersionInfo} />
        aboutVisible && <AccountAbout onCancel={() => setAboutVisible(false)} langGeniusVersionInfo={langGeniusVersionInfo} />
      }
    </div >
  )

@@ -16,7 +16,7 @@ export default function Support() {
  `
  const { t } = useTranslation()
  const { plan } = useProviderContext()
  const { userProfile, langeniusVersionInfo } = useAppContext()
  const { userProfile, langGeniusVersionInfo } = useAppContext()
  const canEmailSupport = plan.type === Plan.professional || plan.type === Plan.team || plan.type === Plan.enterprise

  return <Menu as="div" className="relative h-full w-full">

@@ -53,7 +53,7 @@ export default function Support() {
            className={cn(itemClassName, 'group justify-between',
              'data-[active]:bg-state-base-hover',
            )}
            href={mailToSupport(userProfile.email, plan.type, langeniusVersionInfo.current_version)}
            href={mailToSupport(userProfile.email, plan.type, langGeniusVersionInfo.current_version)}
            target='_blank' rel='noopener noreferrer'>
            <RiMailSendLine className='size-4 shrink-0 text-text-tertiary' />
            <div className='system-md-regular grow px-1 text-text-secondary'>{t('common.userProfile.emailSupport')}</div>

@@ -12,8 +12,8 @@ const headerEnvClassName: { [k: string]: string } = {

const EnvNav = () => {
  const { t } = useTranslation()
  const { langeniusVersionInfo } = useAppContext()
  const showEnvTag = langeniusVersionInfo.current_env === 'TESTING' || langeniusVersionInfo.current_env === 'DEVELOPMENT'
  const { langGeniusVersionInfo } = useAppContext()
  const showEnvTag = langGeniusVersionInfo.current_env === 'TESTING' || langGeniusVersionInfo.current_env === 'DEVELOPMENT'

  if (!showEnvTag)
    return null

@@ -21,10 +21,10 @@ const EnvNav = () => {
  return (
    <div className={`
      mr-1 flex h-[22px] items-center rounded-md border px-2 text-xs font-medium
      ${headerEnvClassName[langeniusVersionInfo.current_env]}
      ${headerEnvClassName[langGeniusVersionInfo.current_env]}
    `}>
      {
        langeniusVersionInfo.current_env === 'TESTING' && (
        langGeniusVersionInfo.current_env === 'TESTING' && (
          <>
            <Beaker02 className='h-3 w-3' />
            <div className='ml-1 max-[1280px]:hidden'>{t('common.environment.testing')}</div>

@@ -32,7 +32,7 @@ const EnvNav = () => {
        )
      }
      {
        langeniusVersionInfo.current_env === 'DEVELOPMENT' && (
        langGeniusVersionInfo.current_env === 'DEVELOPMENT' && (
          <>
            <TerminalSquare className='h-3 w-3' />
            <div className='ml-1 max-[1280px]:hidden'>{t('common.environment.development')}</div>

@@ -48,7 +48,6 @@ const Installed: FC<Props> = ({
  useEffect(() => {
    if (hasInstalled && uniqueIdentifier === installedInfoPayload.uniqueIdentifier)
      onInstalled()
    // eslint-disable-next-line react-hooks/exhaustive-deps
  }, [hasInstalled])

  const [isInstalling, setIsInstalling] = React.useState(false)

@@ -105,12 +104,12 @@ const Installed: FC<Props> = ({
    }
  }

  const { langeniusVersionInfo } = useAppContext()
  const { langGeniusVersionInfo } = useAppContext()
  const isDifyVersionCompatible = useMemo(() => {
    if (!langeniusVersionInfo.current_version)
    if (!langGeniusVersionInfo.current_version)
      return true
    return gte(langeniusVersionInfo.current_version, payload.meta.minimum_dify_version ?? '0.0.0')
  }, [langeniusVersionInfo.current_version, payload.meta.minimum_dify_version])
    return gte(langGeniusVersionInfo.current_version, payload.meta.minimum_dify_version ?? '0.0.0')
  }, [langGeniusVersionInfo.current_version, payload.meta.minimum_dify_version])

  return (
    <>

@@ -59,7 +59,6 @@ const Installed: FC<Props> = ({
  useEffect(() => {
    if (hasInstalled && uniqueIdentifier === installedInfoPayload.uniqueIdentifier)
      onInstalled()
    // eslint-disable-next-line react-hooks/exhaustive-deps
  }, [hasInstalled])

  const handleCancel = () => {

@@ -120,12 +119,12 @@ const Installed: FC<Props> = ({
    }
  }

  const { langeniusVersionInfo } = useAppContext()
  const { langGeniusVersionInfo } = useAppContext()
  const { data: pluginDeclaration } = usePluginDeclarationFromMarketPlace(uniqueIdentifier)
  const isDifyVersionCompatible = useMemo(() => {
    if (!pluginDeclaration || !langeniusVersionInfo.current_version) return true
    return gte(langeniusVersionInfo.current_version, pluginDeclaration?.manifest.meta.minimum_dify_version ?? '0.0.0')
  }, [langeniusVersionInfo.current_version, pluginDeclaration])
    if (!pluginDeclaration || !langGeniusVersionInfo.current_version) return true
    return gte(langGeniusVersionInfo.current_version, pluginDeclaration?.manifest.meta.minimum_dify_version ?? '0.0.0')
  }, [langGeniusVersionInfo.current_version, pluginDeclaration])

  const { canInstall } = useInstallPluginLimit({ ...payload, from: 'marketplace' })
  return (

@@ -265,7 +265,7 @@ const ToolSelector: FC<Props> = ({
          />
        )}
      </PortalToFollowElemTrigger>
      <PortalToFollowElemContent>
      <PortalToFollowElemContent className='z-10'>
        <div className={cn('relative max-h-[642px] min-h-20 w-[361px] rounded-xl border-[0.5px] border-components-panel-border bg-components-panel-bg-blur pb-4 shadow-lg backdrop-blur-sm', 'overflow-y-auto pb-2')}>
          <>
            <div className='system-xl-semibold px-4 pb-1 pt-3.5 text-text-primary'>{t(`plugin.detailPanel.toolSelector.${isEdit ? 'toolSetting' : 'title'}`)}</div>

@@ -309,15 +309,15 @@ const ToolSelector: FC<Props> = ({
            {currentProvider && currentProvider.type === CollectionType.builtIn && currentProvider.allow_delete && (
              <>
                <Divider className='my-1 w-full' />
                <div className='px-4 py-2'>
                <PluginAuthInAgent
                  pluginPayload={{
                    provider: currentProvider.name,
                    category: AuthCategory.tool,
                  }}
                  credentialId={value?.credential_id}
                  onAuthorizationItemClick={handleAuthorizationItemClick}
                />
                <div className='px-4 py-2'>
                  <PluginAuthInAgent
                    pluginPayload={{
                      provider: currentProvider.name,
                      category: AuthCategory.tool,
                    }}
                    credentialId={value?.credential_id}
                    onAuthorizationItemClick={handleAuthorizationItemClick}
                  />
                </div>
              </>
            )}

@@ -62,13 +62,13 @@ const PluginItem: FC<Props> = ({
    return [PluginSource.github, PluginSource.marketplace].includes(source) ? author : ''
  }, [source, author])

  const { langeniusVersionInfo } = useAppContext()
  const { langGeniusVersionInfo } = useAppContext()

  const isDifyVersionCompatible = useMemo(() => {
    if (!langeniusVersionInfo.current_version)
    if (!langGeniusVersionInfo.current_version)
      return true
    return gte(langeniusVersionInfo.current_version, declarationMeta.minimum_dify_version ?? '0.0.0')
  }, [declarationMeta.minimum_dify_version, langeniusVersionInfo.current_version])
    return gte(langGeniusVersionInfo.current_version, declarationMeta.minimum_dify_version ?? '0.0.0')
  }, [declarationMeta.minimum_dify_version, langGeniusVersionInfo.current_version])

  const handleDelete = () => {
    refreshPluginList({ category } as any)

@@ -5,9 +5,9 @@ import * as Sentry from '@sentry/react'

const isDevelopment = process.env.NODE_ENV === 'development'

const SentryInit = ({
const SentryInitializer = ({
  children,
}: { children: React.ReactNode }) => {
}: { children: React.ReactElement }) => {
  useEffect(() => {
    const SENTRY_DSN = document?.body?.getAttribute('data-public-sentry-dsn')
    if (!isDevelopment && SENTRY_DSN) {

@@ -26,4 +26,4 @@ const SentryInit = ({
  return children
}

export default SentryInit
export default SentryInitializer

@@ -10,12 +10,12 @@ import {
  EDUCATION_VERIFY_URL_SEARCHPARAMS_ACTION,
} from '@/app/education-apply/constants'

type SwrInitorProps = {
type SwrInitializerProps = {
  children: ReactNode
}
const SwrInitor = ({
const SwrInitializer = ({
  children,
}: SwrInitorProps) => {
}: SwrInitializerProps) => {
  const router = useRouter()
  const searchParams = useSearchParams()
  const consoleToken = decodeURIComponent(searchParams.get('access_token') || '')

@@ -86,4 +86,4 @@ const SwrInitor = ({
    : null
}

export default SwrInitor
export default SwrInitializer

@@ -697,7 +697,7 @@ const getIterationItemType = ({
    case VarType.arrayObject:
      return VarType.object
    case VarType.array:
      return VarType.any
      return VarType.arrayObject // Use more specific type instead of any
    case VarType.arrayFile:
      return VarType.file
    default:

@@ -13,7 +13,7 @@ const metaData = genNodeMetaData({
const nodeDefault: NodeDefault<AgentNodeType> = {
  metaData,
  defaultValue: {
    version: '2',
    tool_node_version: '2',
  },
  checkValid(payload, t, moreDataForCheckValid: {
    strategyProvider?: StrategyPluginDetail,

@@ -58,27 +58,29 @@ const nodeDefault: NodeDefault<AgentNodeType> = {
        const userSettings = toolValue.settings
        const reasoningConfig = toolValue.parameters
        const version = payload.version
        const toolNodeVersion = payload.tool_node_version
        const mergeVersion = version || toolNodeVersion
        schemas.forEach((schema: any) => {
          if (schema?.required) {
            if (schema.form === 'form' && !version && !userSettings[schema.name]?.value) {
            if (schema.form === 'form' && !mergeVersion && !userSettings[schema.name]?.value) {
              return {
                isValid: false,
                errorMessage: t('workflow.errorMsg.toolParameterRequired', { field: renderI18nObject(param.label, language), param: renderI18nObject(schema.label, language) }),
              }
            }
            if (schema.form === 'form' && version && !userSettings[schema.name]?.value.value) {
            if (schema.form === 'form' && mergeVersion && !userSettings[schema.name]?.value.value) {
              return {
                isValid: false,
                errorMessage: t('workflow.errorMsg.toolParameterRequired', { field: renderI18nObject(param.label, language), param: renderI18nObject(schema.label, language) }),
              }
            }
            if (schema.form === 'llm' && !version && reasoningConfig[schema.name].auto === 0 && !reasoningConfig[schema.name]?.value) {
            if (schema.form === 'llm' && !mergeVersion && reasoningConfig[schema.name].auto === 0 && !reasoningConfig[schema.name]?.value) {
              return {
                isValid: false,
                errorMessage: t('workflow.errorMsg.toolParameterRequired', { field: renderI18nObject(param.label, language), param: renderI18nObject(schema.label, language) }),
              }
            }
            if (schema.form === 'llm' && version && reasoningConfig[schema.name].auto === 0 && !reasoningConfig[schema.name]?.value.value) {
            if (schema.form === 'llm' && mergeVersion && reasoningConfig[schema.name].auto === 0 && !reasoningConfig[schema.name]?.value.value) {
              return {
                isValid: false,
                errorMessage: t('workflow.errorMsg.toolParameterRequired', { field: renderI18nObject(param.label, language), param: renderI18nObject(schema.label, language) }),

@@ -12,6 +12,7 @@ export type AgentNodeType = CommonNodeType & {
  plugin_unique_identifier?: string
  memory?: Memory
  version?: string
  tool_node_version?: string
}

export enum AgentFeature {

@@ -129,7 +129,7 @@ const useConfig = (id: string, payload: AgentNodeType) => {
  }

  const formattingLegacyData = () => {
    if (inputs.version)
    if (inputs.version || inputs.tool_node_version)
      return inputs
    const newData = produce(inputs, (draft) => {
      const schemas = currentStrategy?.parameters || []

@@ -140,7 +140,7 @@ const useConfig = (id: string, payload: AgentNodeType) => {
        if (targetSchema?.type === FormTypeEnum.multiToolSelector)
          draft.agent_parameters![key].value = draft.agent_parameters![key].value.map((tool: any) => formattingToolData(tool))
      })
      draft.version = '2'
      draft.tool_node_version = '2'
    })
    return newData
  }

@@ -151,7 +151,6 @@ const useConfig = (id: string, payload: AgentNodeType) => {
      return
    const newData = formattingLegacyData()
    setInputs(newData)
    // eslint-disable-next-line react-hooks/exhaustive-deps
  }, [currentStrategy])

  // vars

@@ -16,7 +16,7 @@ const nodeDefault: NodeDefault<ToolNodeType> = {
  defaultValue: {
    tool_parameters: {},
    tool_configurations: {},
    version: '2',
    tool_node_version: '2',
  },
  checkValid(payload: ToolNodeType, t: any, moreDataForCheckValid: any) {
    const { toolInputsSchema, toolSettingSchema, language, notAuthed } = moreDataForCheckValid

@@ -23,4 +23,5 @@ export type ToolNodeType = CommonNodeType & {
  output_schema: Record<string, any>
  paramSchemas?: Record<string, any>[]
  version?: string
  tool_node_version?: string
}

@@ -286,8 +286,8 @@ export const initialNodes = (originNodes: Node[], originEdges: Edge[]) => {
      }
    }

    if (node.data.type === BlockEnum.Tool && !(node as Node<ToolNodeType>).data.version) {
      (node as Node<ToolNodeType>).data.version = '2'
    if (node.data.type === BlockEnum.Tool && !(node as Node<ToolNodeType>).data.version && !(node as Node<ToolNodeType>).data.tool_node_version) {
      (node as Node<ToolNodeType>).data.tool_node_version = '2'

      const toolConfigurations = (node as Node<ToolNodeType>).data.tool_configurations
      if (toolConfigurations && Object.keys(toolConfigurations).length > 0) {

@@ -1,10 +1,10 @@
import RoutePrefixHandle from './routePrefixHandle'
import type { Viewport } from 'next'
import I18nServer from './components/i18n-server'
import BrowserInitor from './components/browser-initor'
import SentryInitor from './components/sentry-initor'
import BrowserInitializer from './components/browser-initializer'
import SentryInitializer from './components/sentry-initializer'
import { getLocaleOnServer } from '@/i18n/server'
import { TanstackQueryIniter } from '@/context/query-client'
import { TanstackQueryInitializer } from '@/context/query-client'
import { ThemeProvider } from 'next-themes'
import './styles/globals.css'
import './styles/markdown.scss'

@@ -62,9 +62,9 @@ const LocaleLayout = async ({
        className="color-scheme h-full select-auto"
        {...datasetMap}
      >
        <BrowserInitor>
          <SentryInitor>
            <TanstackQueryIniter>
        <BrowserInitializer>
          <SentryInitializer>
            <TanstackQueryInitializer>
              <ThemeProvider
                attribute='data-theme'
                defaultTheme='system'

@@ -77,9 +77,9 @@ const LocaleLayout = async ({
                </GlobalPublicStoreProvider>
              </I18nServer>
            </ThemeProvider>
            </TanstackQueryIniter>
          </SentryInitor>
        </BrowserInitor>
            </TanstackQueryInitializer>
          </SentryInitializer>
        </BrowserInitializer>
        <RoutePrefixHandle />
      </body>
    </html>
@ -1,20 +1,15 @@
|
|||
'use client'
|
||||
|
||||
import { createRef, useCallback, useEffect, useMemo, useRef, useState } from 'react'
|
||||
import { useCallback, useEffect, useMemo, useState } from 'react'
|
||||
import useSWR from 'swr'
|
||||
import { createContext, useContext, useContextSelector } from 'use-context-selector'
|
||||
import type { FC, ReactNode } from 'react'
|
||||
import { fetchAppList } from '@/service/apps'
|
||||
import Loading from '@/app/components/base/loading'
|
||||
import { fetchCurrentWorkspace, fetchLanggeniusVersion, fetchUserProfile } from '@/service/common'
|
||||
import type { App } from '@/types/app'
|
||||
import { fetchCurrentWorkspace, fetchLangGeniusVersion, fetchUserProfile } from '@/service/common'
|
||||
import type { ICurrentWorkspace, LangGeniusVersionResponse, UserProfileResponse } from '@/models/common'
|
||||
import MaintenanceNotice from '@/app/components/header/maintenance-notice'
|
||||
import { noop } from 'lodash-es'
|
||||
|
||||
export type AppContextValue = {
|
||||
apps: App[]
|
||||
mutateApps: VoidFunction
|
||||
userProfile: UserProfileResponse
|
||||
mutateUserProfile: VoidFunction
|
||||
currentWorkspace: ICurrentWorkspace
|
||||
|
|
@ -23,13 +18,21 @@ export type AppContextValue = {
|
|||
isCurrentWorkspaceEditor: boolean
|
||||
isCurrentWorkspaceDatasetOperator: boolean
|
||||
mutateCurrentWorkspace: VoidFunction
|
||||
pageContainerRef: React.RefObject<HTMLDivElement>
|
||||
langeniusVersionInfo: LangGeniusVersionResponse
|
||||
langGeniusVersionInfo: LangGeniusVersionResponse
|
||||
useSelector: typeof useSelector
|
||||
isLoadingCurrentWorkspace: boolean
|
||||
}
|
||||
|
||||
const initialLangeniusVersionInfo = {
|
||||
const userProfilePlaceholder = {
|
||||
id: '',
|
||||
name: '',
|
||||
email: '',
|
||||
avatar: '',
|
||||
avatar_url: '',
|
||||
is_password_set: false,
|
||||
}
|
||||
|
||||
const initialLangGeniusVersionInfo = {
|
||||
current_env: '',
|
||||
current_version: '',
|
||||
latest_version: '',
|
||||
|
|
@ -50,16 +53,7 @@ const initialWorkspaceInfo: ICurrentWorkspace = {
|
|||
}
|
||||
|
||||
const AppContext = createContext<AppContextValue>({
|
||||
apps: [],
|
||||
mutateApps: noop,
|
||||
userProfile: {
|
||||
id: '',
|
||||
name: '',
|
||||
email: '',
|
||||
avatar: '',
|
||||
avatar_url: '',
|
||||
is_password_set: false,
|
||||
},
|
||||
userProfile: userProfilePlaceholder,
|
||||
currentWorkspace: initialWorkspaceInfo,
|
||||
isCurrentWorkspaceManager: false,
|
||||
isCurrentWorkspaceOwner: false,
|
||||
|
|
@ -67,8 +61,7 @@ const AppContext = createContext<AppContextValue>({
|
|||
isCurrentWorkspaceDatasetOperator: false,
|
||||
mutateUserProfile: noop,
|
||||
mutateCurrentWorkspace: noop,
|
||||
pageContainerRef: createRef(),
|
||||
langeniusVersionInfo: initialLangeniusVersionInfo,
|
||||
langGeniusVersionInfo: initialLangGeniusVersionInfo,
|
||||
useSelector,
|
||||
isLoadingCurrentWorkspace: false,
|
||||
})
|
||||
|
|
@@ -82,14 +75,11 @@ export type AppContextProviderProps = {
 }

 export const AppContextProvider: FC<AppContextProviderProps> = ({ children }) => {
   const pageContainerRef = useRef<HTMLDivElement>(null)

   const { data: appList, mutate: mutateApps } = useSWR({ url: '/apps', params: { page: 1, limit: 30, name: '' } }, fetchAppList)
   const { data: userProfileResponse, mutate: mutateUserProfile } = useSWR({ url: '/account/profile', params: {} }, fetchUserProfile)
   const { data: currentWorkspaceResponse, mutate: mutateCurrentWorkspace, isLoading: isLoadingCurrentWorkspace } = useSWR({ url: '/workspaces/current', params: {} }, fetchCurrentWorkspace)

-  const [userProfile, setUserProfile] = useState<UserProfileResponse>()
-  const [langeniusVersionInfo, setLangeniusVersionInfo] = useState<LangGeniusVersionResponse>(initialLangeniusVersionInfo)
+  const [userProfile, setUserProfile] = useState<UserProfileResponse>(userProfilePlaceholder)
+  const [langGeniusVersionInfo, setLangGeniusVersionInfo] = useState<LangGeniusVersionResponse>(initialLangGeniusVersionInfo)
   const [currentWorkspace, setCurrentWorkspace] = useState<ICurrentWorkspace>(initialWorkspaceInfo)
   const isCurrentWorkspaceManager = useMemo(() => ['owner', 'admin'].includes(currentWorkspace.role), [currentWorkspace.role])
   const isCurrentWorkspaceOwner = useMemo(() => currentWorkspace.role === 'owner', [currentWorkspace.role])
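Editor's note: one behavioral nuance in the hunk above. Seeding the state with `userProfilePlaceholder` instead of leaving it `undefined` means `userProfile` is always truthy, so guards such as the `if (!appList || !userProfile)` check later in this file no longer distinguish "profile not loaded yet". A hedged sketch of the distinction, with hypothetical names:

import { useState } from 'react'

const placeholder = { id: '', name: '' }

export function useProfileState() {
  const [profile] = useState(placeholder)
  // `!profile` is always false once the state starts from an object;
  // "loaded" must be detected explicitly if it still matters.
  const isLoaded = profile.id !== ''
  return { profile, isLoaded }
}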
@@ -101,8 +91,8 @@ export const AppContextProvider: FC<AppContextProviderProps> = ({ children }) =>
     setUserProfile(result)
     const current_version = userProfileResponse.headers.get('x-version')
     const current_env = process.env.NODE_ENV === 'development' ? 'DEVELOPMENT' : userProfileResponse.headers.get('x-env')
-    const versionData = await fetchLanggeniusVersion({ url: '/version', params: { current_version } })
-    setLangeniusVersionInfo({ ...versionData, current_version, latest_version: versionData.version, current_env })
+    const versionData = await fetchLangGeniusVersion({ url: '/version', params: { current_version } })
+    setLangGeniusVersionInfo({ ...versionData, current_version, latest_version: versionData.version, current_env })
   }
 }, [userProfileResponse])
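Editor's note: the hunk above only renames the fetch/set pair, but the flow it sits in is worth spelling out: the profile response's `x-version` and `x-env` headers parameterize the `/version` lookup, whose result becomes the version banner state. A standalone sketch of that flow using plain `fetch` instead of the repo's service helpers (the wrapper function is hypothetical; the header names come from the diff):

// Hypothetical sketch: derive version info from a profile Response's headers,
// then ask the backend whether a newer release exists.
async function loadVersionInfo(profileResponse: Response) {
  const current_version = profileResponse.headers.get('x-version') ?? ''
  const current_env = process.env.NODE_ENV === 'development'
    ? 'DEVELOPMENT'
    : profileResponse.headers.get('x-env') ?? ''
  const res = await fetch(`/version?current_version=${encodeURIComponent(current_version)}`)
  const versionData = await res.json()
  return { ...versionData, current_version, current_env, latest_version: versionData.version }
}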
@@ -115,17 +105,11 @@
     setCurrentWorkspace(currentWorkspaceResponse)
   }, [currentWorkspaceResponse])

   if (!appList || !userProfile)
     return <Loading type='app' />

   return (
     <AppContext.Provider value={{
       apps: appList.data,
       mutateApps,
       userProfile,
       mutateUserProfile,
       pageContainerRef,
-      langeniusVersionInfo,
+      langGeniusVersionInfo,
       useSelector,
       currentWorkspace,
       isCurrentWorkspaceManager,
@@ -137,7 +121,7 @@ export const AppContextProvider: FC<AppContextProviderProps> = ({ children }) =>
     }}>
       <div className='flex h-full flex-col overflow-y-auto'>
         {globalThis.document?.body?.getAttribute('data-public-maintenance-notice') && <MaintenanceNotice />}
-        <div ref={pageContainerRef} className='relative flex grow flex-col overflow-y-auto overflow-x-hidden bg-background-body'>
+        <div className='relative flex grow flex-col overflow-y-auto overflow-x-hidden bg-background-body'>
           {children}
         </div>
       </div>
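Editor's note: with the rename applied across the type, the context default, the state hook, and the provider value, consumers read `langGeniusVersionInfo` off the context. A hypothetical call site (this assumes `AppContext` is exported or wrapped in a hook, which is outside the hunks shown):

import { useContext } from 'react'

function VersionBadge() {
  // `AppContext` as defined above; the old `langeniusVersionInfo` key is gone.
  const { langGeniusVersionInfo } = useContext(AppContext)
  return <span>v{langGeniusVersionInfo.current_version}</span>
}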
@@ -14,7 +14,7 @@ const client = new QueryClient({
   },
 })

-export const TanstackQueryIniter: FC<PropsWithChildren> = (props) => {
+export const TanstackQueryInitializer: FC<PropsWithChildren> = (props) => {
   const { children } = props
   return <QueryClientProvider client={client}>
     {children}
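Editor's note: the component is only renamed (`TanstackQueryIniter` → `TanstackQueryInitializer`); behavior is unchanged, but import sites must update. A hypothetical mount point showing the new name in use (module path assumed):

import type { ReactNode } from 'react'
import { TanstackQueryInitializer } from './tanstack-query-initializer' // path assumed

export default function RootLayout({ children }: { children: ReactNode }) {
  return <TanstackQueryInitializer>{children}</TanstackQueryInitializer>
}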
@@ -8,6 +8,7 @@ import storybook from 'eslint-plugin-storybook'
 import tailwind from 'eslint-plugin-tailwindcss'
 import reactHooks from 'eslint-plugin-react-hooks'
 import sonar from 'eslint-plugin-sonarjs'
+import oxlint from 'eslint-plugin-oxlint'

 // import reactRefresh from 'eslint-plugin-react-refresh'

@@ -245,4 +246,5 @@ export default combine(
     'tailwindcss/migration-from-tailwind-2': 'warn',
   },
 },
+  oxlint.configs['flat/recommended'],
 )
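Editor's note: `eslint-plugin-oxlint`'s `flat/recommended` preset turns off the ESLint rules that oxlint itself already enforces, which is presumably why the diff appends it as the last argument to `combine(...)`: in flat config, later entries override earlier ones. A minimal sketch of the same ordering in a plain flat config (the `no-unused-vars` entry is just an example rule):

import oxlint from 'eslint-plugin-oxlint'

export default [
  // Project rules first...
  { rules: { 'no-unused-vars': 'error' } },
  // ...then the oxlint preset, so its rule-disabling entries win.
  oxlint.configs['flat/recommended'],
]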
@@ -78,7 +78,7 @@ const translation = {
   optional: 'Wahlfrei',
   noTemplateFound: 'Keine Vorlagen gefunden',
   workflowUserDescription: 'Autonome KI-Arbeitsabläufe visuell per Drag-and-Drop erstellen.',
-  foundResults: '{{Anzahl}} Befund',
+  foundResults: '{{count}} Befund',
   chatbotShortDescription: 'LLM-basierter Chatbot mit einfacher Einrichtung',
   completionUserDescription: 'Erstellen Sie schnell einen KI-Assistenten für Textgenerierungsaufgaben mit einfacher Konfiguration.',
   noAppsFound: 'Keine Apps gefunden',

@@ -92,7 +92,7 @@ const translation = {
   noTemplateFoundTip: 'Versuchen Sie, mit verschiedenen Schlüsselwörtern zu suchen.',
   advancedUserDescription: 'Workflow mit Speicherfunktionen und Chatbot-Oberfläche.',
   chatbotUserDescription: 'Erstellen Sie schnell einen LLM-basierten Chatbot mit einfacher Konfiguration. Sie können später zu Chatflow wechseln.',
-  foundResult: '{{Anzahl}} Ergebnis',
+  foundResult: '{{count}} Ergebnis',
   agentUserDescription: 'Ein intelligenter Agent, der in der Lage ist, iteratives Denken zu führen und autonome Werkzeuge zu verwenden, um Aufgabenziele zu erreichen.',
   agentShortDescription: 'Intelligenter Agent mit logischem Denken und autonomer Werkzeugnutzung',
   dropDSLToCreateApp: 'Ziehen Sie die DSL-Datei hierher, um die App zu erstellen',
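Editor's note: the `{{Anzahl}}` → `{{count}}` fix matters because interpolation placeholders must match the option keys callers pass; a translated placeholder name is never substituted, so the raw `{{Anzahl}}` would have rendered verbatim. A minimal sketch, assuming these keys resolve through i18next (the repo's actual i18n wiring is outside this diff):

import i18next from 'i18next'

// Assumed setup: the placeholder inside the translation string must match
// the option key at the call site.
i18next
  .init({
    lng: 'de',
    resources: { de: { translation: { foundResult: '{{count}} Ergebnis' } } },
  })
  .then(() => {
    console.log(i18next.t('foundResult', { count: 1 })) // '1 Ergebnis'
    // With '{{Anzahl}} Ergebnis' the literal placeholder would appear instead.
  })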
Some files were not shown because too many files have changed in this diff.