Merge branch 'fix/chore-fix' into dev/plugin-deploy

This commit is contained in:
Yeuoly 2024-11-08 13:48:23 +08:00
commit 14a723a2a4
No known key found for this signature in database
GPG Key ID: A66E7E320FB19F61
53 changed files with 1132 additions and 157 deletions

View File

@ -6,6 +6,7 @@ on:
- main - main
paths: paths:
- api/migrations/** - api/migrations/**
- .github/workflows/db-migration-test.yml
concurrency: concurrency:
group: db-migration-test-${{ github.ref }} group: db-migration-test-${{ github.ref }}

View File

@ -285,8 +285,9 @@ UPLOAD_IMAGE_FILE_SIZE_LIMIT=10
UPLOAD_VIDEO_FILE_SIZE_LIMIT=100 UPLOAD_VIDEO_FILE_SIZE_LIMIT=100
UPLOAD_AUDIO_FILE_SIZE_LIMIT=50 UPLOAD_AUDIO_FILE_SIZE_LIMIT=50
# Model Configuration # Model configuration
MULTIMODAL_SEND_IMAGE_FORMAT=base64 MULTIMODAL_SEND_IMAGE_FORMAT=base64
MULTIMODAL_SEND_VIDEO_FORMAT=base64
PROMPT_GENERATION_MAX_TOKENS=512 PROMPT_GENERATION_MAX_TOKENS=512
CODE_GENERATION_MAX_TOKENS=1024 CODE_GENERATION_MAX_TOKENS=1024
@ -324,10 +325,10 @@ UNSTRUCTURED_API_KEY=
SSRF_PROXY_HTTP_URL= SSRF_PROXY_HTTP_URL=
SSRF_PROXY_HTTPS_URL= SSRF_PROXY_HTTPS_URL=
SSRF_DEFAULT_MAX_RETRIES=3 SSRF_DEFAULT_MAX_RETRIES=3
SSRF_DEFAULT_TIME_OUT= SSRF_DEFAULT_TIME_OUT=5
SSRF_DEFAULT_CONNECT_TIME_OUT= SSRF_DEFAULT_CONNECT_TIME_OUT=5
SSRF_DEFAULT_READ_TIME_OUT= SSRF_DEFAULT_READ_TIME_OUT=5
SSRF_DEFAULT_WRITE_TIME_OUT= SSRF_DEFAULT_WRITE_TIME_OUT=5
BATCH_UPLOAD_LIMIT=10 BATCH_UPLOAD_LIMIT=10
KEYWORD_DATA_SOURCE_TYPE=database KEYWORD_DATA_SOURCE_TYPE=database

View File

@ -2,7 +2,7 @@ import os
from configs import dify_config from configs import dify_config
if os.environ.get("DEBUG", "false").lower() != "true": if not dify_config.DEBUG:
from gevent import monkey from gevent import monkey
monkey.patch_all() monkey.patch_all()

View File

@ -1,6 +1,8 @@
import os import os
if os.environ.get("DEBUG", "false").lower() != "true": from configs import dify_config
if not dify_config.DEBUG:
from gevent import monkey from gevent import monkey
monkey.patch_all() monkey.patch_all()

View File

@ -329,6 +329,16 @@ class HttpConfig(BaseSettings):
default=1 * 1024 * 1024, default=1 * 1024 * 1024,
) )
SSRF_DEFAULT_MAX_RETRIES: PositiveInt = Field(
description="Maximum number of retries for network requests (SSRF)",
default=3,
)
SSRF_PROXY_ALL_URL: Optional[str] = Field(
description="Proxy URL for HTTP or HTTPS requests to prevent Server-Side Request Forgery (SSRF)",
default=None,
)
SSRF_PROXY_HTTP_URL: Optional[str] = Field( SSRF_PROXY_HTTP_URL: Optional[str] = Field(
description="Proxy URL for HTTP requests to prevent Server-Side Request Forgery (SSRF)", description="Proxy URL for HTTP requests to prevent Server-Side Request Forgery (SSRF)",
default=None, default=None,
@ -677,12 +687,17 @@ class IndexingConfig(BaseSettings):
) )
class ImageFormatConfig(BaseSettings): class VisionFormatConfig(BaseSettings):
MULTIMODAL_SEND_IMAGE_FORMAT: Literal["base64", "url"] = Field( MULTIMODAL_SEND_IMAGE_FORMAT: Literal["base64", "url"] = Field(
description="Format for sending images in multimodal contexts ('base64' or 'url'), default is base64", description="Format for sending images in multimodal contexts ('base64' or 'url'), default is base64",
default="base64", default="base64",
) )
MULTIMODAL_SEND_VIDEO_FORMAT: Literal["base64", "url"] = Field(
description="Format for sending videos in multimodal contexts ('base64' or 'url'), default is base64",
default="base64",
)
class CeleryBeatConfig(BaseSettings): class CeleryBeatConfig(BaseSettings):
CELERY_BEAT_SCHEDULER_TIME: int = Field( CELERY_BEAT_SCHEDULER_TIME: int = Field(
@ -787,7 +802,7 @@ class FeatureConfig(
FileAccessConfig, FileAccessConfig,
FileUploadConfig, FileUploadConfig,
HttpConfig, HttpConfig,
ImageFormatConfig, VisionFormatConfig,
InnerAPIConfig, InnerAPIConfig,
IndexingConfig, IndexingConfig,
LoggingConfig, LoggingConfig,

View File

@ -956,7 +956,7 @@ class DocumentRetryApi(DocumentResource):
raise DocumentAlreadyFinishedError() raise DocumentAlreadyFinishedError()
retry_documents.append(document) retry_documents.append(document)
except Exception as e: except Exception as e:
logging.error(f"Document {document_id} retry failed: {str(e)}") logging.exception(f"Document {document_id} retry failed: {str(e)}")
continue continue
# retry document # retry document
DocumentService.retry_document(dataset_id, retry_documents) DocumentService.retry_document(dataset_id, retry_documents)

View File

@ -7,7 +7,11 @@ from controllers.service_api import api
from controllers.service_api.app.error import NotChatAppError from controllers.service_api.app.error import NotChatAppError
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields from fields.conversation_fields import (
conversation_delete_fields,
conversation_infinite_scroll_pagination_fields,
simple_conversation_fields,
)
from libs.helper import uuid_value from libs.helper import uuid_value
from models.model import App, AppMode, EndUser from models.model import App, AppMode, EndUser
from services.conversation_service import ConversationService from services.conversation_service import ConversationService
@ -49,7 +53,7 @@ class ConversationApi(Resource):
class ConversationDetailApi(Resource): class ConversationDetailApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON)) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
@marshal_with(simple_conversation_fields) @marshal_with(conversation_delete_fields)
def delete(self, app_model: App, end_user: EndUser, c_id): def delete(self, app_model: App, end_user: EndUser, c_id):
app_mode = AppMode.value_of(app_model.mode) app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
@ -58,10 +62,9 @@ class ConversationDetailApi(Resource):
conversation_id = str(c_id) conversation_id = str(c_id)
try: try:
ConversationService.delete(app_model, conversation_id, end_user) return ConversationService.delete(app_model, conversation_id, end_user)
except services.errors.conversation.ConversationNotExistsError: except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.") raise NotFound("Conversation Not Exists.")
return {"result": "success"}, 200
class ConversationRenameApi(Resource): class ConversationRenameApi(Resource):

View File

@ -1,6 +1,5 @@
import contextvars import contextvars
import logging import logging
import os
import threading import threading
import uuid import uuid
from collections.abc import Generator from collections.abc import Generator
@ -10,6 +9,7 @@ from flask import Flask, current_app
from pydantic import ValidationError from pydantic import ValidationError
import contexts import contexts
from configs import dify_config
from constants import UUID_NIL from constants import UUID_NIL
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
@ -328,7 +328,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
logger.exception("Validation Error when generating") logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e: except (ValueError, InvokeError) as e:
if os.environ.get("DEBUG", "false").lower() == "true": if dify_config.DEBUG:
logger.exception("Error when generating") logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except Exception as e: except Exception as e:

View File

@ -242,7 +242,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
start_listener_time = time.time() start_listener_time = time.time()
yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id) yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id)
except Exception as e: except Exception as e:
logger.error(e) logger.exception(e)
break break
if tts_publisher: if tts_publisher:
yield MessageAudioEndStreamResponse(audio="", task_id=task_id) yield MessageAudioEndStreamResponse(audio="", task_id=task_id)

View File

@ -1,5 +1,4 @@
import logging import logging
import os
import threading import threading
import uuid import uuid
from collections.abc import Generator from collections.abc import Generator
@ -8,6 +7,7 @@ from typing import Any, Literal, Union, overload
from flask import Flask, current_app from flask import Flask, current_app
from pydantic import ValidationError from pydantic import ValidationError
from configs import dify_config
from constants import UUID_NIL from constants import UUID_NIL
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
@ -235,7 +235,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
logger.exception("Validation Error when generating") logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e: except (ValueError, InvokeError) as e:
if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == "true": if dify_config.DEBUG:
logger.exception("Error when generating") logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except Exception as e: except Exception as e:

View File

@ -1,5 +1,4 @@
import logging import logging
import os
import threading import threading
import uuid import uuid
from collections.abc import Generator from collections.abc import Generator
@ -8,6 +7,7 @@ from typing import Any, Literal, Union, overload
from flask import Flask, current_app from flask import Flask, current_app
from pydantic import ValidationError from pydantic import ValidationError
from configs import dify_config
from constants import UUID_NIL from constants import UUID_NIL
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
@ -237,7 +237,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
logger.exception("Validation Error when generating") logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e: except (ValueError, InvokeError) as e:
if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == "true": if dify_config.DEBUG:
logger.exception("Error when generating") logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except Exception as e: except Exception as e:

View File

@ -1,5 +1,4 @@
import logging import logging
import os
import threading import threading
import uuid import uuid
from collections.abc import Generator from collections.abc import Generator
@ -8,6 +7,7 @@ from typing import Any, Literal, Union, overload
from flask import Flask, current_app from flask import Flask, current_app
from pydantic import ValidationError from pydantic import ValidationError
from configs import dify_config
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
@ -213,7 +213,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
logger.exception("Validation Error when generating") logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e: except (ValueError, InvokeError) as e:
if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == "true": if dify_config.DEBUG:
logger.exception("Error when generating") logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except Exception as e: except Exception as e:

View File

@ -1,6 +1,5 @@
import contextvars import contextvars
import logging import logging
import os
import threading import threading
import uuid import uuid
from collections.abc import Generator, Mapping, Sequence from collections.abc import Generator, Mapping, Sequence
@ -10,6 +9,7 @@ from flask import Flask, current_app
from pydantic import ValidationError from pydantic import ValidationError
import contexts import contexts
from configs import dify_config
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.base_app_generator import BaseAppGenerator from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
@ -273,7 +273,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
logger.exception("Validation Error when generating") logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e: except (ValueError, InvokeError) as e:
if os.environ.get("DEBUG") and os.environ.get("DEBUG", "false").lower() == "true": if dify_config.DEBUG:
logger.exception("Error when generating") logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except Exception as e: except Exception as e:

View File

@ -216,7 +216,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
else: else:
yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id) yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id)
except Exception as e: except Exception as e:
logger.error(e) logger.exception(e)
break break
if tts_publisher: if tts_publisher:
yield MessageAudioEndStreamResponse(audio="", task_id=task_id) yield MessageAudioEndStreamResponse(audio="", task_id=task_id)

View File

@ -3,7 +3,7 @@ import base64
from configs import dify_config from configs import dify_config
from core.file import file_repository from core.file import file_repository
from core.helper import ssrf_proxy from core.helper import ssrf_proxy
from core.model_runtime.entities import AudioPromptMessageContent, ImagePromptMessageContent from core.model_runtime.entities import AudioPromptMessageContent, ImagePromptMessageContent, VideoPromptMessageContent
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_storage import storage from extensions.ext_storage import storage
@ -71,6 +71,12 @@ def to_prompt_message_content(f: File, /):
if f.extension is None: if f.extension is None:
raise ValueError("Missing file extension") raise ValueError("Missing file extension")
return AudioPromptMessageContent(data=encoded_string, format=f.extension.lstrip(".")) return AudioPromptMessageContent(data=encoded_string, format=f.extension.lstrip("."))
case FileType.VIDEO:
if dify_config.MULTIMODAL_SEND_VIDEO_FORMAT == "url":
data = _to_url(f)
else:
data = _to_base64_data_string(f)
return VideoPromptMessageContent(data=data, format=f.extension.lstrip("."))
case _: case _:
raise ValueError(f"file type {f.type} is not supported") raise ValueError(f"file type {f.type} is not supported")
@ -112,7 +118,7 @@ def _download_file_content(path: str, /):
def _get_encoded_string(f: File, /): def _get_encoded_string(f: File, /):
match f.transfer_method: match f.transfer_method:
case FileTransferMethod.REMOTE_URL: case FileTransferMethod.REMOTE_URL:
response = ssrf_proxy.get(f.remote_url) response = ssrf_proxy.get(f.remote_url, follow_redirects=True)
response.raise_for_status() response.raise_for_status()
content = response.content content = response.content
encoded_string = base64.b64encode(content).decode("utf-8") encoded_string = base64.b64encode(content).decode("utf-8")
@ -140,6 +146,8 @@ def _file_to_encoded_string(f: File, /):
match f.type: match f.type:
case FileType.IMAGE: case FileType.IMAGE:
return _to_base64_data_string(f) return _to_base64_data_string(f)
case FileType.VIDEO:
return _to_base64_data_string(f)
case FileType.AUDIO: case FileType.AUDIO:
return _get_encoded_string(f) return _get_encoded_string(f)
case _: case _:

View File

@ -3,26 +3,20 @@ Proxy requests to avoid SSRF
""" """
import logging import logging
import os
import time import time
import httpx import httpx
SSRF_PROXY_ALL_URL = os.getenv("SSRF_PROXY_ALL_URL", "") from configs import dify_config
SSRF_PROXY_HTTP_URL = os.getenv("SSRF_PROXY_HTTP_URL", "")
SSRF_PROXY_HTTPS_URL = os.getenv("SSRF_PROXY_HTTPS_URL", "") SSRF_DEFAULT_MAX_RETRIES = dify_config.SSRF_DEFAULT_MAX_RETRIES
SSRF_DEFAULT_MAX_RETRIES = int(os.getenv("SSRF_DEFAULT_MAX_RETRIES", "3"))
SSRF_DEFAULT_TIME_OUT = float(os.getenv("SSRF_DEFAULT_TIME_OUT", "5"))
SSRF_DEFAULT_CONNECT_TIME_OUT = float(os.getenv("SSRF_DEFAULT_CONNECT_TIME_OUT", "5"))
SSRF_DEFAULT_READ_TIME_OUT = float(os.getenv("SSRF_DEFAULT_READ_TIME_OUT", "5"))
SSRF_DEFAULT_WRITE_TIME_OUT = float(os.getenv("SSRF_DEFAULT_WRITE_TIME_OUT", "5"))
proxy_mounts = ( proxy_mounts = (
{ {
"http://": httpx.HTTPTransport(proxy=SSRF_PROXY_HTTP_URL), "http://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTP_URL),
"https://": httpx.HTTPTransport(proxy=SSRF_PROXY_HTTPS_URL), "https://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTPS_URL),
} }
if SSRF_PROXY_HTTP_URL and SSRF_PROXY_HTTPS_URL if dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL
else None else None
) )
@ -38,17 +32,17 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
if "timeout" not in kwargs: if "timeout" not in kwargs:
kwargs["timeout"] = httpx.Timeout( kwargs["timeout"] = httpx.Timeout(
SSRF_DEFAULT_TIME_OUT, timeout=dify_config.SSRF_DEFAULT_TIME_OUT,
connect=SSRF_DEFAULT_CONNECT_TIME_OUT, connect=dify_config.SSRF_DEFAULT_CONNECT_TIME_OUT,
read=SSRF_DEFAULT_READ_TIME_OUT, read=dify_config.SSRF_DEFAULT_READ_TIME_OUT,
write=SSRF_DEFAULT_WRITE_TIME_OUT, write=dify_config.SSRF_DEFAULT_WRITE_TIME_OUT,
) )
retries = 0 retries = 0
while retries <= max_retries: while retries <= max_retries:
try: try:
if SSRF_PROXY_ALL_URL: if dify_config.SSRF_PROXY_ALL_URL:
with httpx.Client(proxy=SSRF_PROXY_ALL_URL) as client: with httpx.Client(proxy=dify_config.SSRF_PROXY_ALL_URL) as client:
response = client.request(method=method, url=url, **kwargs) response = client.request(method=method, url=url, **kwargs)
elif proxy_mounts: elif proxy_mounts:
with httpx.Client(mounts=proxy_mounts) as client: with httpx.Client(mounts=proxy_mounts) as client:

View File

@ -1,8 +1,8 @@
import logging import logging
import os
from collections.abc import Callable, Generator, Iterable, Sequence from collections.abc import Callable, Generator, Iterable, Sequence
from typing import IO, Any, Literal, Optional, Union, cast, overload from typing import IO, Any, Literal, Optional, Union, cast, overload
from configs import dify_config
from core.entities.embedding_type import EmbeddingInputType from core.entities.embedding_type import EmbeddingInputType
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
from core.entities.provider_entities import ModelLoadBalancingConfiguration from core.entities.provider_entities import ModelLoadBalancingConfiguration
@ -509,7 +509,7 @@ class LBModelManager:
continue continue
if bool(os.environ.get("DEBUG", "False").lower() == "true"): if dify_config.DEBUG:
logger.info( logger.info(
f"Model LB\nid: {config.id}\nname:{config.name}\n" f"Model LB\nid: {config.id}\nname:{config.name}\n"
f"tenant_id: {self._tenant_id}\nprovider: {self._provider}\n" f"tenant_id: {self._tenant_id}\nprovider: {self._provider}\n"

View File

@ -12,11 +12,13 @@ from .message_entities import (
TextPromptMessageContent, TextPromptMessageContent,
ToolPromptMessage, ToolPromptMessage,
UserPromptMessage, UserPromptMessage,
VideoPromptMessageContent,
) )
from .model_entities import ModelPropertyKey from .model_entities import ModelPropertyKey
__all__ = [ __all__ = [
"ImagePromptMessageContent", "ImagePromptMessageContent",
"VideoPromptMessageContent",
"PromptMessage", "PromptMessage",
"PromptMessageRole", "PromptMessageRole",
"LLMUsage", "LLMUsage",

View File

@ -56,6 +56,7 @@ class PromptMessageContentType(Enum):
TEXT = "text" TEXT = "text"
IMAGE = "image" IMAGE = "image"
AUDIO = "audio" AUDIO = "audio"
VIDEO = "video"
class PromptMessageContent(BaseModel): class PromptMessageContent(BaseModel):
@ -75,6 +76,12 @@ class TextPromptMessageContent(PromptMessageContent):
type: PromptMessageContentType = PromptMessageContentType.TEXT type: PromptMessageContentType = PromptMessageContentType.TEXT
class VideoPromptMessageContent(PromptMessageContent):
type: PromptMessageContentType = PromptMessageContentType.VIDEO
data: str = Field(..., description="Base64 encoded video data")
format: str = Field(..., description="Video format")
class AudioPromptMessageContent(PromptMessageContent): class AudioPromptMessageContent(PromptMessageContent):
type: PromptMessageContentType = PromptMessageContentType.AUDIO type: PromptMessageContentType = PromptMessageContentType.AUDIO
data: str = Field(..., description="Base64 encoded audio data") data: str = Field(..., description="Base64 encoded audio data")

View File

@ -126,6 +126,6 @@ class OutputModeration(BaseModel):
result: ModerationOutputsResult = moderation_factory.moderation_for_outputs(moderation_buffer) result: ModerationOutputsResult = moderation_factory.moderation_for_outputs(moderation_buffer)
return result return result
except Exception as e: except Exception as e:
logger.error("Moderation Output error: %s", e) logger.exception("Moderation Output error: %s", e)
return None return None

View File

@ -708,7 +708,7 @@ class TraceQueueManager:
trace_task.app_id = self.app_id trace_task.app_id = self.app_id
trace_manager_queue.put(trace_task) trace_manager_queue.put(trace_task)
except Exception as e: except Exception as e:
logging.error(f"Error adding trace task: {e}") logging.exception(f"Error adding trace task: {e}")
finally: finally:
self.start_timer() self.start_timer()
@ -727,7 +727,7 @@ class TraceQueueManager:
if tasks: if tasks:
self.send_to_celery(tasks) self.send_to_celery(tasks)
except Exception as e: except Exception as e:
logging.error(f"Error processing trace tasks: {e}") logging.exception(f"Error processing trace tasks: {e}")
def start_timer(self): def start_timer(self):
global trace_manager_timer global trace_manager_timer

View File

@ -242,7 +242,7 @@ class CouchbaseVector(BaseVector):
try: try:
self._cluster.query(query, named_parameters={"doc_ids": ids}).execute() self._cluster.query(query, named_parameters={"doc_ids": ids}).execute()
except Exception as e: except Exception as e:
logger.error(e) logger.exception(e)
def delete_by_document_id(self, document_id: str): def delete_by_document_id(self, document_id: str):
query = f""" query = f"""

View File

@ -79,7 +79,7 @@ class LindormVectorStore(BaseVector):
existing_docs = self._client.mget(index=self._collection_name, body={"ids": batch_ids}, _source=False) existing_docs = self._client.mget(index=self._collection_name, body={"ids": batch_ids}, _source=False)
return {doc["_id"] for doc in existing_docs["docs"] if doc["found"]} return {doc["_id"] for doc in existing_docs["docs"] if doc["found"]}
except Exception as e: except Exception as e:
logger.error(f"Error fetching batch {batch_ids}: {e}") logger.exception(f"Error fetching batch {batch_ids}: {e}")
return set() return set()
@retry(stop=stop_after_attempt(3), wait=wait_fixed(60)) @retry(stop=stop_after_attempt(3), wait=wait_fixed(60))
@ -96,7 +96,7 @@ class LindormVectorStore(BaseVector):
) )
return {doc["_id"] for doc in existing_docs["docs"] if doc["found"]} return {doc["_id"] for doc in existing_docs["docs"] if doc["found"]}
except Exception as e: except Exception as e:
logger.error(f"Error fetching batch {batch_ids}: {e}") logger.exception(f"Error fetching batch {batch_ids}: {e}")
return set() return set()
if ids is None: if ids is None:
@ -177,7 +177,7 @@ class LindormVectorStore(BaseVector):
else: else:
logger.warning(f"Index '{self._collection_name}' does not exist. No deletion performed.") logger.warning(f"Index '{self._collection_name}' does not exist. No deletion performed.")
except Exception as e: except Exception as e:
logger.error(f"Error occurred while deleting the index: {e}") logger.exception(f"Error occurred while deleting the index: {e}")
raise e raise e
def text_exists(self, id: str) -> bool: def text_exists(self, id: str) -> bool:
@ -201,7 +201,7 @@ class LindormVectorStore(BaseVector):
try: try:
response = self._client.search(index=self._collection_name, body=query) response = self._client.search(index=self._collection_name, body=query)
except Exception as e: except Exception as e:
logger.error(f"Error executing search: {e}") logger.exception(f"Error executing search: {e}")
raise raise
docs_and_scores = [] docs_and_scores = []

View File

@ -86,7 +86,7 @@ class MilvusVector(BaseVector):
ids = self._client.insert(collection_name=self._collection_name, data=batch_insert_list) ids = self._client.insert(collection_name=self._collection_name, data=batch_insert_list)
pks.extend(ids) pks.extend(ids)
except MilvusException as e: except MilvusException as e:
logger.error("Failed to insert batch starting at entity: %s/%s", i, total_count) logger.exception("Failed to insert batch starting at entity: %s/%s", i, total_count)
raise e raise e
return pks return pks

View File

@ -142,7 +142,7 @@ class MyScaleVector(BaseVector):
for r in self._client.query(sql).named_results() for r in self._client.query(sql).named_results()
] ]
except Exception as e: except Exception as e:
logging.error(f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m") logging.exception(f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m")
return [] return []
def delete(self) -> None: def delete(self) -> None:

View File

@ -129,7 +129,7 @@ class OpenSearchVector(BaseVector):
if status == 404: if status == 404:
logger.warning(f"Document not found for deletion: {doc_id}") logger.warning(f"Document not found for deletion: {doc_id}")
else: else:
logger.error(f"Error deleting document: {error}") logger.exception(f"Error deleting document: {error}")
def delete(self) -> None: def delete(self) -> None:
self._client.indices.delete(index=self._collection_name.lower()) self._client.indices.delete(index=self._collection_name.lower())
@ -158,7 +158,7 @@ class OpenSearchVector(BaseVector):
try: try:
response = self._client.search(index=self._collection_name.lower(), body=query) response = self._client.search(index=self._collection_name.lower(), body=query)
except Exception as e: except Exception as e:
logger.error(f"Error executing search: {e}") logger.exception(f"Error executing search: {e}")
raise raise
docs = [] docs = []

View File

@ -89,7 +89,7 @@ class CacheEmbedding(Embeddings):
db.session.rollback() db.session.rollback()
except Exception as ex: except Exception as ex:
db.session.rollback() db.session.rollback()
logger.error("Failed to embed documents: %s", ex) logger.exception("Failed to embed documents: %s", ex)
raise ex raise ex
return text_embeddings return text_embeddings

View File

@ -28,7 +28,6 @@ logger = logging.getLogger(__name__)
class WordExtractor(BaseExtractor): class WordExtractor(BaseExtractor):
"""Load docx files. """Load docx files.
Args: Args:
file_path: Path to the file to load. file_path: Path to the file to load.
""" """
@ -51,9 +50,9 @@ class WordExtractor(BaseExtractor):
self.web_path = self.file_path self.web_path = self.file_path
# TODO: use a better way to handle the file # TODO: use a better way to handle the file
self.temp_file = tempfile.NamedTemporaryFile() # noqa: SIM115 with tempfile.NamedTemporaryFile(delete=False) as self.temp_file:
self.temp_file.write(r.content) self.temp_file.write(r.content)
self.file_path = self.temp_file.name self.file_path = self.temp_file.name
elif not os.path.isfile(self.file_path): elif not os.path.isfile(self.file_path):
raise ValueError(f"File path {self.file_path} is not a valid file or url") raise ValueError(f"File path {self.file_path} is not a valid file or url")
@ -230,7 +229,7 @@ class WordExtractor(BaseExtractor):
for i in url_pattern.findall(x.text): for i in url_pattern.findall(x.text):
hyperlinks_url = str(i) hyperlinks_url = str(i)
except Exception as e: except Exception as e:
logger.error(e) logger.exception(e)
def parse_paragraph(paragraph): def parse_paragraph(paragraph):
paragraph_content = [] paragraph_content = []

View File

@ -98,7 +98,7 @@ class ToolFileManager:
response.raise_for_status() response.raise_for_status()
blob = response.content blob = response.content
except Exception as e: except Exception as e:
logger.error(f"Failed to download file from {file_url}: {e}") logger.exception(f"Failed to download file from {file_url}: {e}")
raise raise
mimetype = guess_type(file_url)[0] or "octet/stream" mimetype = guess_type(file_url)[0] or "octet/stream"

View File

@ -526,7 +526,7 @@ class ToolManager:
yield provider yield provider
except Exception as e: except Exception as e:
logger.error(f"load builtin provider error: {e}") logger.exception(f"load builtin provider {provider} error: {e}")
continue continue
# set builtin providers loaded # set builtin providers loaded
cls._builtin_providers_loaded = True cls._builtin_providers_loaded = True

View File

@ -127,7 +127,9 @@ class FeishuRequest:
"folder_token": folder_token, "folder_token": folder_token,
} }
res = self._send_request(url, payload=payload) res = self._send_request(url, payload=payload)
return res.get("data") if "data" in res:
return res.get("data")
return res
def write_document(self, document_id: str, content: str, position: str = "end") -> dict: def write_document(self, document_id: str, content: str, position: str = "end") -> dict:
url = f"{self.API_BASE_URL}/document/write_document" url = f"{self.API_BASE_URL}/document/write_document"
@ -135,7 +137,7 @@ class FeishuRequest:
res = self._send_request(url, payload=payload) res = self._send_request(url, payload=payload)
return res return res
def get_document_content(self, document_id: str, mode: str = "markdown", lang: str = "0") -> dict: def get_document_content(self, document_id: str, mode: str = "markdown", lang: str = "0") -> str:
""" """
API url: https://open.larkoffice.com/document/server-docs/docs/docs/docx-v1/document/raw_content API url: https://open.larkoffice.com/document/server-docs/docs/docs/docx-v1/document/raw_content
Example Response: Example Response:
@ -154,7 +156,9 @@ class FeishuRequest:
} }
url = f"{self.API_BASE_URL}/document/get_document_content" url = f"{self.API_BASE_URL}/document/get_document_content"
res = self._send_request(url, method="GET", params=params) res = self._send_request(url, method="GET", params=params)
return res.get("data").get("content") if "data" in res:
return res.get("data").get("content")
return ""
def list_document_blocks( def list_document_blocks(
self, document_id: str, page_token: str, user_id_type: str = "open_id", page_size: int = 500 self, document_id: str, page_token: str, user_id_type: str = "open_id", page_size: int = 500
@ -170,7 +174,9 @@ class FeishuRequest:
} }
url = f"{self.API_BASE_URL}/document/list_document_blocks" url = f"{self.API_BASE_URL}/document/list_document_blocks"
res = self._send_request(url, method="GET", params=params) res = self._send_request(url, method="GET", params=params)
return res.get("data") if "data" in res:
return res.get("data")
return res
def send_bot_message(self, receive_id_type: str, receive_id: str, msg_type: str, content: str) -> dict: def send_bot_message(self, receive_id_type: str, receive_id: str, msg_type: str, content: str) -> dict:
""" """
@ -186,7 +192,9 @@ class FeishuRequest:
"content": content.strip('"').replace(r"\"", '"').replace(r"\\", "\\"), "content": content.strip('"').replace(r"\"", '"').replace(r"\\", "\\"),
} }
res = self._send_request(url, params=params, payload=payload) res = self._send_request(url, params=params, payload=payload)
return res.get("data") if "data" in res:
return res.get("data")
return res
def send_webhook_message(self, webhook: str, msg_type: str, content: str) -> dict: def send_webhook_message(self, webhook: str, msg_type: str, content: str) -> dict:
url = f"{self.API_BASE_URL}/message/send_webhook_message" url = f"{self.API_BASE_URL}/message/send_webhook_message"
@ -220,7 +228,9 @@ class FeishuRequest:
"page_size": page_size, "page_size": page_size,
} }
res = self._send_request(url, method="GET", params=params) res = self._send_request(url, method="GET", params=params)
return res.get("data") if "data" in res:
return res.get("data")
return res
def get_thread_messages( def get_thread_messages(
self, container_id: str, page_token: str, sort_type: str = "ByCreateTimeAsc", page_size: int = 20 self, container_id: str, page_token: str, sort_type: str = "ByCreateTimeAsc", page_size: int = 20
@ -236,7 +246,9 @@ class FeishuRequest:
"page_size": page_size, "page_size": page_size,
} }
res = self._send_request(url, method="GET", params=params) res = self._send_request(url, method="GET", params=params)
return res.get("data") if "data" in res:
return res.get("data")
return res
def create_task(self, summary: str, start_time: str, end_time: str, completed_time: str, description: str) -> dict: def create_task(self, summary: str, start_time: str, end_time: str, completed_time: str, description: str) -> dict:
# 创建任务 # 创建任务
@ -249,7 +261,9 @@ class FeishuRequest:
"description": description, "description": description,
} }
res = self._send_request(url, payload=payload) res = self._send_request(url, payload=payload)
return res.get("data") if "data" in res:
return res.get("data")
return res
def update_task( def update_task(
self, task_guid: str, summary: str, start_time: str, end_time: str, completed_time: str, description: str self, task_guid: str, summary: str, start_time: str, end_time: str, completed_time: str, description: str
@ -265,7 +279,9 @@ class FeishuRequest:
"description": description, "description": description,
} }
res = self._send_request(url, method="PATCH", payload=payload) res = self._send_request(url, method="PATCH", payload=payload)
return res.get("data") if "data" in res:
return res.get("data")
return res
def delete_task(self, task_guid: str) -> dict: def delete_task(self, task_guid: str) -> dict:
# 删除任务 # 删除任务
@ -297,7 +313,9 @@ class FeishuRequest:
"page_size": page_size, "page_size": page_size,
} }
res = self._send_request(url, payload=payload) res = self._send_request(url, payload=payload)
return res.get("data") if "data" in res:
return res.get("data")
return res
def get_primary_calendar(self, user_id_type: str = "open_id") -> dict: def get_primary_calendar(self, user_id_type: str = "open_id") -> dict:
url = f"{self.API_BASE_URL}/calendar/get_primary_calendar" url = f"{self.API_BASE_URL}/calendar/get_primary_calendar"
@ -305,7 +323,9 @@ class FeishuRequest:
"user_id_type": user_id_type, "user_id_type": user_id_type,
} }
res = self._send_request(url, method="GET", params=params) res = self._send_request(url, method="GET", params=params)
return res.get("data") if "data" in res:
return res.get("data")
return res
def create_event( def create_event(
self, self,
@ -328,7 +348,9 @@ class FeishuRequest:
"attendee_ability": attendee_ability, "attendee_ability": attendee_ability,
} }
res = self._send_request(url, payload=payload) res = self._send_request(url, payload=payload)
return res.get("data") if "data" in res:
return res.get("data")
return res
def update_event( def update_event(
self, self,
@ -374,7 +396,9 @@ class FeishuRequest:
"page_size": page_size, "page_size": page_size,
} }
res = self._send_request(url, method="GET", params=params) res = self._send_request(url, method="GET", params=params)
return res.get("data") if "data" in res:
return res.get("data")
return res
def search_events( def search_events(
self, self,
@ -395,7 +419,9 @@ class FeishuRequest:
"page_size": page_size, "page_size": page_size,
} }
res = self._send_request(url, payload=payload) res = self._send_request(url, payload=payload)
return res.get("data") if "data" in res:
return res.get("data")
return res
def add_event_attendees(self, event_id: str, attendee_phone_or_email: str, need_notification: bool = True) -> dict: def add_event_attendees(self, event_id: str, attendee_phone_or_email: str, need_notification: bool = True) -> dict:
# 参加日程参会人 # 参加日程参会人
@ -406,7 +432,9 @@ class FeishuRequest:
"need_notification": need_notification, "need_notification": need_notification,
} }
res = self._send_request(url, payload=payload) res = self._send_request(url, payload=payload)
return res.get("data") if "data" in res:
return res.get("data")
return res
def create_spreadsheet( def create_spreadsheet(
self, self,
@ -420,7 +448,9 @@ class FeishuRequest:
"folder_token": folder_token, "folder_token": folder_token,
} }
res = self._send_request(url, payload=payload) res = self._send_request(url, payload=payload)
return res.get("data") if "data" in res:
return res.get("data")
return res
def get_spreadsheet( def get_spreadsheet(
self, self,
@ -434,7 +464,9 @@ class FeishuRequest:
"user_id_type": user_id_type, "user_id_type": user_id_type,
} }
res = self._send_request(url, method="GET", params=params) res = self._send_request(url, method="GET", params=params)
return res.get("data") if "data" in res:
return res.get("data")
return res
def list_spreadsheet_sheets( def list_spreadsheet_sheets(
self, self,
@ -446,7 +478,9 @@ class FeishuRequest:
"spreadsheet_token": spreadsheet_token, "spreadsheet_token": spreadsheet_token,
} }
res = self._send_request(url, method="GET", params=params) res = self._send_request(url, method="GET", params=params)
return res.get("data") if "data" in res:
return res.get("data")
return res
def add_rows( def add_rows(
self, self,
@ -466,7 +500,9 @@ class FeishuRequest:
"values": values, "values": values,
} }
res = self._send_request(url, payload=payload) res = self._send_request(url, payload=payload)
return res.get("data") if "data" in res:
return res.get("data")
return res
def add_cols( def add_cols(
self, self,
@ -486,7 +522,9 @@ class FeishuRequest:
"values": values, "values": values,
} }
res = self._send_request(url, payload=payload) res = self._send_request(url, payload=payload)
return res.get("data") if "data" in res:
return res.get("data")
return res
def read_rows( def read_rows(
self, self,
@ -508,7 +546,9 @@ class FeishuRequest:
"user_id_type": user_id_type, "user_id_type": user_id_type,
} }
res = self._send_request(url, method="GET", params=params) res = self._send_request(url, method="GET", params=params)
return res.get("data") if "data" in res:
return res.get("data")
return res
def read_cols( def read_cols(
self, self,
@ -530,7 +570,9 @@ class FeishuRequest:
"user_id_type": user_id_type, "user_id_type": user_id_type,
} }
res = self._send_request(url, method="GET", params=params) res = self._send_request(url, method="GET", params=params)
return res.get("data") if "data" in res:
return res.get("data")
return res
def read_table( def read_table(
self, self,
@ -552,7 +594,9 @@ class FeishuRequest:
"user_id_type": user_id_type, "user_id_type": user_id_type,
} }
res = self._send_request(url, method="GET", params=params) res = self._send_request(url, method="GET", params=params)
return res.get("data") if "data" in res:
return res.get("data")
return res
def create_base( def create_base(
self, self,
@ -566,7 +610,9 @@ class FeishuRequest:
"folder_token": folder_token, "folder_token": folder_token,
} }
res = self._send_request(url, payload=payload) res = self._send_request(url, payload=payload)
return res.get("data") if "data" in res:
return res.get("data")
return res
def add_records( def add_records(
self, self,
@ -588,7 +634,9 @@ class FeishuRequest:
"records": convert_add_records(records), "records": convert_add_records(records),
} }
res = self._send_request(url, params=params, payload=payload) res = self._send_request(url, params=params, payload=payload)
return res.get("data") if "data" in res:
return res.get("data")
return res
def update_records( def update_records(
self, self,
@ -610,7 +658,9 @@ class FeishuRequest:
"records": convert_update_records(records), "records": convert_update_records(records),
} }
res = self._send_request(url, params=params, payload=payload) res = self._send_request(url, params=params, payload=payload)
return res.get("data") if "data" in res:
return res.get("data")
return res
def delete_records( def delete_records(
self, self,
@ -637,7 +687,9 @@ class FeishuRequest:
"records": record_id_list, "records": record_id_list,
} }
res = self._send_request(url, params=params, payload=payload) res = self._send_request(url, params=params, payload=payload)
return res.get("data") if "data" in res:
return res.get("data")
return res
def search_record( def search_record(
self, self,
@ -701,7 +753,10 @@ class FeishuRequest:
if automatic_fields: if automatic_fields:
payload["automatic_fields"] = automatic_fields payload["automatic_fields"] = automatic_fields
res = self._send_request(url, params=params, payload=payload) res = self._send_request(url, params=params, payload=payload)
return res.get("data")
if "data" in res:
return res.get("data")
return res
def get_base_info( def get_base_info(
self, self,
@ -713,7 +768,9 @@ class FeishuRequest:
"app_token": app_token, "app_token": app_token,
} }
res = self._send_request(url, method="GET", params=params) res = self._send_request(url, method="GET", params=params)
return res.get("data") if "data" in res:
return res.get("data")
return res
def create_table( def create_table(
self, self,
@ -741,7 +798,9 @@ class FeishuRequest:
if default_view_name: if default_view_name:
payload["default_view_name"] = default_view_name payload["default_view_name"] = default_view_name
res = self._send_request(url, params=params, payload=payload) res = self._send_request(url, params=params, payload=payload)
return res.get("data") if "data" in res:
return res.get("data")
return res
def delete_tables( def delete_tables(
self, self,
@ -774,8 +833,11 @@ class FeishuRequest:
"table_ids": table_id_list, "table_ids": table_id_list,
"table_names": table_name_list, "table_names": table_name_list,
} }
res = self._send_request(url, params=params, payload=payload) res = self._send_request(url, params=params, payload=payload)
return res.get("data") if "data" in res:
return res.get("data")
return res
def list_tables( def list_tables(
self, self,
@ -791,7 +853,9 @@ class FeishuRequest:
"page_size": page_size, "page_size": page_size,
} }
res = self._send_request(url, method="GET", params=params) res = self._send_request(url, method="GET", params=params)
return res.get("data") if "data" in res:
return res.get("data")
return res
def read_records( def read_records(
self, self,
@ -819,4 +883,6 @@ class FeishuRequest:
"user_id_type": user_id_type, "user_id_type": user_id_type,
} }
res = self._send_request(url, method="GET", params=params, payload=payload) res = self._send_request(url, method="GET", params=params, payload=payload)
return res.get("data") if "data" in res:
return res.get("data")
return res

View File

@ -0,0 +1,820 @@
import json
from typing import Optional
import httpx
from core.tools.errors import ToolProviderCredentialValidationError
from extensions.ext_redis import redis_client
def lark_auth(credentials):
app_id = credentials.get("app_id")
app_secret = credentials.get("app_secret")
if not app_id or not app_secret:
raise ToolProviderCredentialValidationError("app_id and app_secret is required")
try:
assert LarkRequest(app_id, app_secret).tenant_access_token is not None
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))
class LarkRequest:
API_BASE_URL = "https://lark-plugin-api.solutionsuite.ai/lark-plugin"
def __init__(self, app_id: str, app_secret: str):
self.app_id = app_id
self.app_secret = app_secret
def convert_add_records(self, json_str):
try:
data = json.loads(json_str)
if not isinstance(data, list):
raise ValueError("Parsed data must be a list")
converted_data = [{"fields": json.dumps(item, ensure_ascii=False)} for item in data]
return converted_data
except json.JSONDecodeError:
raise ValueError("The input string is not valid JSON")
except Exception as e:
raise ValueError(f"An error occurred while processing the data: {e}")
def convert_update_records(self, json_str):
try:
data = json.loads(json_str)
if not isinstance(data, list):
raise ValueError("Parsed data must be a list")
converted_data = [
{"fields": json.dumps(record["fields"], ensure_ascii=False), "record_id": record["record_id"]}
for record in data
if "fields" in record and "record_id" in record
]
if len(converted_data) != len(data):
raise ValueError("Each record must contain 'fields' and 'record_id'")
return converted_data
except json.JSONDecodeError:
raise ValueError("The input string is not valid JSON")
except Exception as e:
raise ValueError(f"An error occurred while processing the data: {e}")
@property
def tenant_access_token(self) -> str:
feishu_tenant_access_token = f"tools:{self.app_id}:feishu_tenant_access_token"
if redis_client.exists(feishu_tenant_access_token):
return redis_client.get(feishu_tenant_access_token).decode()
res = self.get_tenant_access_token(self.app_id, self.app_secret)
redis_client.setex(feishu_tenant_access_token, res.get("expire"), res.get("tenant_access_token"))
if "tenant_access_token" in res:
return res.get("tenant_access_token")
return ""
def _send_request(
self,
url: str,
method: str = "post",
require_token: bool = True,
payload: Optional[dict] = None,
params: Optional[dict] = None,
):
headers = {
"Content-Type": "application/json",
"user-agent": "Dify",
}
if require_token:
headers["tenant-access-token"] = f"{self.tenant_access_token}"
res = httpx.request(method=method, url=url, headers=headers, json=payload, params=params, timeout=30).json()
if res.get("code") != 0:
raise Exception(res)
return res
def get_tenant_access_token(self, app_id: str, app_secret: str) -> dict:
url = f"{self.API_BASE_URL}/access_token/get_tenant_access_token"
payload = {"app_id": app_id, "app_secret": app_secret}
res = self._send_request(url, require_token=False, payload=payload)
return res
def create_document(self, title: str, content: str, folder_token: str) -> dict:
url = f"{self.API_BASE_URL}/document/create_document"
payload = {
"title": title,
"content": content,
"folder_token": folder_token,
}
res = self._send_request(url, payload=payload)
if "data" in res:
return res.get("data")
return res
def write_document(self, document_id: str, content: str, position: str = "end") -> dict:
url = f"{self.API_BASE_URL}/document/write_document"
payload = {"document_id": document_id, "content": content, "position": position}
res = self._send_request(url, payload=payload)
return res
def get_document_content(self, document_id: str, mode: str = "markdown", lang: str = "0") -> str | dict:
params = {
"document_id": document_id,
"mode": mode,
"lang": lang,
}
url = f"{self.API_BASE_URL}/document/get_document_content"
res = self._send_request(url, method="GET", params=params)
if "data" in res:
return res.get("data").get("content")
return ""
def list_document_blocks(
self, document_id: str, page_token: str, user_id_type: str = "open_id", page_size: int = 500
) -> dict:
params = {
"user_id_type": user_id_type,
"document_id": document_id,
"page_size": page_size,
"page_token": page_token,
}
url = f"{self.API_BASE_URL}/document/list_document_blocks"
res = self._send_request(url, method="GET", params=params)
if "data" in res:
return res.get("data")
return res
def send_bot_message(self, receive_id_type: str, receive_id: str, msg_type: str, content: str) -> dict:
url = f"{self.API_BASE_URL}/message/send_bot_message"
params = {
"receive_id_type": receive_id_type,
}
payload = {
"receive_id": receive_id,
"msg_type": msg_type,
"content": content.strip('"').replace(r"\"", '"').replace(r"\\", "\\"),
}
res = self._send_request(url, params=params, payload=payload)
if "data" in res:
return res.get("data")
return res
def send_webhook_message(self, webhook: str, msg_type: str, content: str) -> dict:
url = f"{self.API_BASE_URL}/message/send_webhook_message"
payload = {
"webhook": webhook,
"msg_type": msg_type,
"content": content.strip('"').replace(r"\"", '"').replace(r"\\", "\\"),
}
res = self._send_request(url, require_token=False, payload=payload)
return res
def get_chat_messages(
self,
container_id: str,
start_time: str,
end_time: str,
page_token: str,
sort_type: str = "ByCreateTimeAsc",
page_size: int = 20,
) -> dict:
url = f"{self.API_BASE_URL}/message/get_chat_messages"
params = {
"container_id": container_id,
"start_time": start_time,
"end_time": end_time,
"sort_type": sort_type,
"page_token": page_token,
"page_size": page_size,
}
res = self._send_request(url, method="GET", params=params)
if "data" in res:
return res.get("data")
return res
def get_thread_messages(
self, container_id: str, page_token: str, sort_type: str = "ByCreateTimeAsc", page_size: int = 20
) -> dict:
url = f"{self.API_BASE_URL}/message/get_thread_messages"
params = {
"container_id": container_id,
"sort_type": sort_type,
"page_token": page_token,
"page_size": page_size,
}
res = self._send_request(url, method="GET", params=params)
if "data" in res:
return res.get("data")
return res
def create_task(self, summary: str, start_time: str, end_time: str, completed_time: str, description: str) -> dict:
url = f"{self.API_BASE_URL}/task/create_task"
payload = {
"summary": summary,
"start_time": start_time,
"end_time": end_time,
"completed_at": completed_time,
"description": description,
}
res = self._send_request(url, payload=payload)
if "data" in res:
return res.get("data")
return res
def update_task(
self, task_guid: str, summary: str, start_time: str, end_time: str, completed_time: str, description: str
) -> dict:
url = f"{self.API_BASE_URL}/task/update_task"
payload = {
"task_guid": task_guid,
"summary": summary,
"start_time": start_time,
"end_time": end_time,
"completed_time": completed_time,
"description": description,
}
res = self._send_request(url, method="PATCH", payload=payload)
if "data" in res:
return res.get("data")
return res
def delete_task(self, task_guid: str) -> dict:
url = f"{self.API_BASE_URL}/task/delete_task"
payload = {
"task_guid": task_guid,
}
res = self._send_request(url, method="DELETE", payload=payload)
if "data" in res:
return res.get("data")
return res
def add_members(self, task_guid: str, member_phone_or_email: str, member_role: str) -> dict:
url = f"{self.API_BASE_URL}/task/add_members"
payload = {
"task_guid": task_guid,
"member_phone_or_email": member_phone_or_email,
"member_role": member_role,
}
res = self._send_request(url, payload=payload)
if "data" in res:
return res.get("data")
return res
def get_wiki_nodes(self, space_id: str, parent_node_token: str, page_token: str, page_size: int = 20) -> dict:
url = f"{self.API_BASE_URL}/wiki/get_wiki_nodes"
payload = {
"space_id": space_id,
"parent_node_token": parent_node_token,
"page_token": page_token,
"page_size": page_size,
}
res = self._send_request(url, payload=payload)
if "data" in res:
return res.get("data")
return res
def get_primary_calendar(self, user_id_type: str = "open_id") -> dict:
url = f"{self.API_BASE_URL}/calendar/get_primary_calendar"
params = {
"user_id_type": user_id_type,
}
res = self._send_request(url, method="GET", params=params)
if "data" in res:
return res.get("data")
return res
def create_event(
self,
summary: str,
description: str,
start_time: str,
end_time: str,
attendee_ability: str,
need_notification: bool = True,
auto_record: bool = False,
) -> dict:
url = f"{self.API_BASE_URL}/calendar/create_event"
payload = {
"summary": summary,
"description": description,
"need_notification": need_notification,
"start_time": start_time,
"end_time": end_time,
"auto_record": auto_record,
"attendee_ability": attendee_ability,
}
res = self._send_request(url, payload=payload)
if "data" in res:
return res.get("data")
return res
def update_event(
self,
event_id: str,
summary: str,
description: str,
need_notification: bool,
start_time: str,
end_time: str,
auto_record: bool,
) -> dict:
url = f"{self.API_BASE_URL}/calendar/update_event/{event_id}"
payload = {}
if summary:
payload["summary"] = summary
if description:
payload["description"] = description
if start_time:
payload["start_time"] = start_time
if end_time:
payload["end_time"] = end_time
if need_notification:
payload["need_notification"] = need_notification
if auto_record:
payload["auto_record"] = auto_record
res = self._send_request(url, method="PATCH", payload=payload)
return res
def delete_event(self, event_id: str, need_notification: bool = True) -> dict:
url = f"{self.API_BASE_URL}/calendar/delete_event/{event_id}"
params = {
"need_notification": need_notification,
}
res = self._send_request(url, method="DELETE", params=params)
return res
def list_events(self, start_time: str, end_time: str, page_token: str, page_size: int = 50) -> dict:
url = f"{self.API_BASE_URL}/calendar/list_events"
params = {
"start_time": start_time,
"end_time": end_time,
"page_token": page_token,
"page_size": page_size,
}
res = self._send_request(url, method="GET", params=params)
if "data" in res:
return res.get("data")
return res
def search_events(
self,
query: str,
start_time: str,
end_time: str,
page_token: str,
user_id_type: str = "open_id",
page_size: int = 20,
) -> dict:
url = f"{self.API_BASE_URL}/calendar/search_events"
payload = {
"query": query,
"start_time": start_time,
"end_time": end_time,
"page_token": page_token,
"user_id_type": user_id_type,
"page_size": page_size,
}
res = self._send_request(url, payload=payload)
if "data" in res:
return res.get("data")
return res
def add_event_attendees(self, event_id: str, attendee_phone_or_email: str, need_notification: bool = True) -> dict:
url = f"{self.API_BASE_URL}/calendar/add_event_attendees"
payload = {
"event_id": event_id,
"attendee_phone_or_email": attendee_phone_or_email,
"need_notification": need_notification,
}
res = self._send_request(url, payload=payload)
if "data" in res:
return res.get("data")
return res
def create_spreadsheet(
self,
title: str,
folder_token: str,
) -> dict:
url = f"{self.API_BASE_URL}/spreadsheet/create_spreadsheet"
payload = {
"title": title,
"folder_token": folder_token,
}
res = self._send_request(url, payload=payload)
if "data" in res:
return res.get("data")
return res
def get_spreadsheet(
self,
spreadsheet_token: str,
user_id_type: str = "open_id",
) -> dict:
url = f"{self.API_BASE_URL}/spreadsheet/get_spreadsheet"
params = {
"spreadsheet_token": spreadsheet_token,
"user_id_type": user_id_type,
}
res = self._send_request(url, method="GET", params=params)
if "data" in res:
return res.get("data")
return res
def list_spreadsheet_sheets(
self,
spreadsheet_token: str,
) -> dict:
url = f"{self.API_BASE_URL}/spreadsheet/list_spreadsheet_sheets"
params = {
"spreadsheet_token": spreadsheet_token,
}
res = self._send_request(url, method="GET", params=params)
if "data" in res:
return res.get("data")
return res
def add_rows(
self,
spreadsheet_token: str,
sheet_id: str,
sheet_name: str,
length: int,
values: str,
) -> dict:
url = f"{self.API_BASE_URL}/spreadsheet/add_rows"
payload = {
"spreadsheet_token": spreadsheet_token,
"sheet_id": sheet_id,
"sheet_name": sheet_name,
"length": length,
"values": values,
}
res = self._send_request(url, payload=payload)
if "data" in res:
return res.get("data")
return res
def add_cols(
self,
spreadsheet_token: str,
sheet_id: str,
sheet_name: str,
length: int,
values: str,
) -> dict:
url = f"{self.API_BASE_URL}/spreadsheet/add_cols"
payload = {
"spreadsheet_token": spreadsheet_token,
"sheet_id": sheet_id,
"sheet_name": sheet_name,
"length": length,
"values": values,
}
res = self._send_request(url, payload=payload)
if "data" in res:
return res.get("data")
return res
def read_rows(
self,
spreadsheet_token: str,
sheet_id: str,
sheet_name: str,
start_row: int,
num_rows: int,
user_id_type: str = "open_id",
) -> dict:
url = f"{self.API_BASE_URL}/spreadsheet/read_rows"
params = {
"spreadsheet_token": spreadsheet_token,
"sheet_id": sheet_id,
"sheet_name": sheet_name,
"start_row": start_row,
"num_rows": num_rows,
"user_id_type": user_id_type,
}
res = self._send_request(url, method="GET", params=params)
if "data" in res:
return res.get("data")
return res
def read_cols(
self,
spreadsheet_token: str,
sheet_id: str,
sheet_name: str,
start_col: int,
num_cols: int,
user_id_type: str = "open_id",
) -> dict:
url = f"{self.API_BASE_URL}/spreadsheet/read_cols"
params = {
"spreadsheet_token": spreadsheet_token,
"sheet_id": sheet_id,
"sheet_name": sheet_name,
"start_col": start_col,
"num_cols": num_cols,
"user_id_type": user_id_type,
}
res = self._send_request(url, method="GET", params=params)
if "data" in res:
return res.get("data")
return res
def read_table(
self,
spreadsheet_token: str,
sheet_id: str,
sheet_name: str,
num_range: str,
query: str,
user_id_type: str = "open_id",
) -> dict:
url = f"{self.API_BASE_URL}/spreadsheet/read_table"
params = {
"spreadsheet_token": spreadsheet_token,
"sheet_id": sheet_id,
"sheet_name": sheet_name,
"range": num_range,
"query": query,
"user_id_type": user_id_type,
}
res = self._send_request(url, method="GET", params=params)
if "data" in res:
return res.get("data")
return res
def create_base(
self,
name: str,
folder_token: str,
) -> dict:
url = f"{self.API_BASE_URL}/base/create_base"
payload = {
"name": name,
"folder_token": folder_token,
}
res = self._send_request(url, payload=payload)
if "data" in res:
return res.get("data")
return res
def add_records(
self,
app_token: str,
table_id: str,
table_name: str,
records: str,
user_id_type: str = "open_id",
) -> dict:
url = f"{self.API_BASE_URL}/base/add_records"
params = {
"app_token": app_token,
"table_id": table_id,
"table_name": table_name,
"user_id_type": user_id_type,
}
payload = {
"records": self.convert_add_records(records),
}
res = self._send_request(url, params=params, payload=payload)
if "data" in res:
return res.get("data")
return res
def update_records(
self,
app_token: str,
table_id: str,
table_name: str,
records: str,
user_id_type: str,
) -> dict:
url = f"{self.API_BASE_URL}/base/update_records"
params = {
"app_token": app_token,
"table_id": table_id,
"table_name": table_name,
"user_id_type": user_id_type,
}
payload = {
"records": self.convert_update_records(records),
}
res = self._send_request(url, params=params, payload=payload)
if "data" in res:
return res.get("data")
return res
def delete_records(
self,
app_token: str,
table_id: str,
table_name: str,
record_ids: str,
) -> dict:
url = f"{self.API_BASE_URL}/base/delete_records"
params = {
"app_token": app_token,
"table_id": table_id,
"table_name": table_name,
}
if not record_ids:
record_id_list = []
else:
try:
record_id_list = json.loads(record_ids)
except json.JSONDecodeError:
raise ValueError("The input string is not valid JSON")
payload = {
"records": record_id_list,
}
res = self._send_request(url, params=params, payload=payload)
if "data" in res:
return res.get("data")
return res
def search_record(
self,
app_token: str,
table_id: str,
table_name: str,
view_id: str,
field_names: str,
sort: str,
filters: str,
page_token: str,
automatic_fields: bool = False,
user_id_type: str = "open_id",
page_size: int = 20,
) -> dict:
url = f"{self.API_BASE_URL}/base/search_record"
params = {
"app_token": app_token,
"table_id": table_id,
"table_name": table_name,
"user_id_type": user_id_type,
"page_token": page_token,
"page_size": page_size,
}
if not field_names:
field_name_list = []
else:
try:
field_name_list = json.loads(field_names)
except json.JSONDecodeError:
raise ValueError("The input string is not valid JSON")
if not sort:
sort_list = []
else:
try:
sort_list = json.loads(sort)
except json.JSONDecodeError:
raise ValueError("The input string is not valid JSON")
if not filters:
filter_dict = {}
else:
try:
filter_dict = json.loads(filters)
except json.JSONDecodeError:
raise ValueError("The input string is not valid JSON")
payload = {}
if view_id:
payload["view_id"] = view_id
if field_names:
payload["field_names"] = field_name_list
if sort:
payload["sort"] = sort_list
if filters:
payload["filter"] = filter_dict
if automatic_fields:
payload["automatic_fields"] = automatic_fields
res = self._send_request(url, params=params, payload=payload)
if "data" in res:
return res.get("data")
return res
def get_base_info(
self,
app_token: str,
) -> dict:
url = f"{self.API_BASE_URL}/base/get_base_info"
params = {
"app_token": app_token,
}
res = self._send_request(url, method="GET", params=params)
if "data" in res:
return res.get("data")
return res
def create_table(
self,
app_token: str,
table_name: str,
default_view_name: str,
fields: str,
) -> dict:
url = f"{self.API_BASE_URL}/base/create_table"
params = {
"app_token": app_token,
}
if not fields:
fields_list = []
else:
try:
fields_list = json.loads(fields)
except json.JSONDecodeError:
raise ValueError("The input string is not valid JSON")
payload = {
"name": table_name,
"fields": fields_list,
}
if default_view_name:
payload["default_view_name"] = default_view_name
res = self._send_request(url, params=params, payload=payload)
if "data" in res:
return res.get("data")
return res
def delete_tables(
self,
app_token: str,
table_ids: str,
table_names: str,
) -> dict:
url = f"{self.API_BASE_URL}/base/delete_tables"
params = {
"app_token": app_token,
}
if not table_ids:
table_id_list = []
else:
try:
table_id_list = json.loads(table_ids)
except json.JSONDecodeError:
raise ValueError("The input string is not valid JSON")
if not table_names:
table_name_list = []
else:
try:
table_name_list = json.loads(table_names)
except json.JSONDecodeError:
raise ValueError("The input string is not valid JSON")
payload = {
"table_ids": table_id_list,
"table_names": table_name_list,
}
res = self._send_request(url, params=params, payload=payload)
if "data" in res:
return res.get("data")
return res
def list_tables(
self,
app_token: str,
page_token: str,
page_size: int = 20,
) -> dict:
url = f"{self.API_BASE_URL}/base/list_tables"
params = {
"app_token": app_token,
"page_token": page_token,
"page_size": page_size,
}
res = self._send_request(url, method="GET", params=params)
if "data" in res:
return res.get("data")
return res
def read_records(
self,
app_token: str,
table_id: str,
table_name: str,
record_ids: str,
user_id_type: str = "open_id",
) -> dict:
url = f"{self.API_BASE_URL}/base/read_records"
params = {
"app_token": app_token,
"table_id": table_id,
"table_name": table_name,
}
if not record_ids:
record_id_list = []
else:
try:
record_id_list = json.loads(record_ids)
except json.JSONDecodeError:
raise ValueError("The input string is not valid JSON")
payload = {
"record_ids": record_id_list,
"user_id_type": user_id_type,
}
res = self._send_request(url, method="POST", params=params, payload=payload)
if "data" in res:
return res.get("data")
return res

View File

@ -69,7 +69,7 @@ class BaseNode(Generic[GenericNodeData]):
try: try:
result = self._run() result = self._run()
except Exception as e: except Exception as e:
logger.error(f"Node {self.node_id} failed to run: {e}") logger.exception(f"Node {self.node_id} failed to run: {e}")
result = NodeRunResult( result = NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, status=WorkflowNodeExecutionStatus.FAILED,
error=str(e), error=str(e),

View File

@ -97,15 +97,6 @@ class Executor:
headers = self.variable_pool.convert_template(self.node_data.headers).text headers = self.variable_pool.convert_template(self.node_data.headers).text
self.headers = _plain_text_to_dict(headers) self.headers = _plain_text_to_dict(headers)
body = self.node_data.body
if body is None:
return
if "content-type" not in (k.lower() for k in self.headers) and body.type in BODY_TYPE_TO_CONTENT_TYPE:
self.headers["Content-Type"] = BODY_TYPE_TO_CONTENT_TYPE[body.type]
if body.type == "form-data":
self.boundary = f"----WebKitFormBoundary{_generate_random_string(16)}"
self.headers["Content-Type"] = f"multipart/form-data; boundary={self.boundary}"
def _init_body(self): def _init_body(self):
body = self.node_data.body body = self.node_data.body
if body is not None: if body is not None:
@ -154,9 +145,8 @@ class Executor:
for k, v in files.items() for k, v in files.items()
if v.related_id is not None if v.related_id is not None
} }
self.data = form_data self.data = form_data
self.files = files self.files = files or None
def _assembling_headers(self) -> dict[str, Any]: def _assembling_headers(self) -> dict[str, Any]:
authorization = deepcopy(self.auth) authorization = deepcopy(self.auth)
@ -217,6 +207,7 @@ class Executor:
"timeout": (self.timeout.connect, self.timeout.read, self.timeout.write), "timeout": (self.timeout.connect, self.timeout.read, self.timeout.write),
"follow_redirects": True, "follow_redirects": True,
} }
# request_args = {k: v for k, v in request_args.items() if v is not None}
response = getattr(ssrf_proxy, self.method)(**request_args) response = getattr(ssrf_proxy, self.method)(**request_args)
return response return response
@ -244,6 +235,13 @@ class Executor:
raw += f"Host: {url_parts.netloc}\r\n" raw += f"Host: {url_parts.netloc}\r\n"
headers = self._assembling_headers() headers = self._assembling_headers()
body = self.node_data.body
boundary = f"----WebKitFormBoundary{_generate_random_string(16)}"
if body:
if "content-type" not in (k.lower() for k in self.headers) and body.type in BODY_TYPE_TO_CONTENT_TYPE:
headers["Content-Type"] = BODY_TYPE_TO_CONTENT_TYPE[body.type]
if body.type == "form-data":
headers["Content-Type"] = f"multipart/form-data; boundary={boundary}"
for k, v in headers.items(): for k, v in headers.items():
if self.auth.type == "api-key": if self.auth.type == "api-key":
authorization_header = "Authorization" authorization_header = "Authorization"
@ -256,7 +254,6 @@ class Executor:
body = "" body = ""
if self.files: if self.files:
boundary = self.boundary
for k, v in self.files.items(): for k, v in self.files.items():
body += f"--{boundary}\r\n" body += f"--{boundary}\r\n"
body += f'Content-Disposition: form-data; name="{k}"\r\n\r\n' body += f'Content-Disposition: form-data; name="{k}"\r\n\r\n'
@ -271,7 +268,6 @@ class Executor:
elif self.data and self.node_data.body.type == "x-www-form-urlencoded": elif self.data and self.node_data.body.type == "x-www-form-urlencoded":
body = urlencode(self.data) body = urlencode(self.data)
elif self.data and self.node_data.body.type == "form-data": elif self.data and self.node_data.body.type == "form-data":
boundary = self.boundary
for key, value in self.data.items(): for key, value in self.data.items():
body += f"--{boundary}\r\n" body += f"--{boundary}\r\n"
body += f'Content-Disposition: form-data; name="{key}"\r\n\r\n' body += f'Content-Disposition: form-data; name="{key}"\r\n\r\n'

View File

@ -14,6 +14,7 @@ from core.model_runtime.entities import (
PromptMessage, PromptMessage,
PromptMessageContentType, PromptMessageContentType,
TextPromptMessageContent, TextPromptMessageContent,
VideoPromptMessageContent,
) )
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
@ -560,7 +561,9 @@ class LLMNode(BaseNode[LLMNodeData]):
# cuz vision detail is related to the configuration from FileUpload feature. # cuz vision detail is related to the configuration from FileUpload feature.
content_item.detail = vision_detail content_item.detail = vision_detail
prompt_message_content.append(content_item) prompt_message_content.append(content_item)
elif isinstance(content_item, TextPromptMessageContent | AudioPromptMessageContent): elif isinstance(
content_item, TextPromptMessageContent | AudioPromptMessageContent | VideoPromptMessageContent
):
prompt_message_content.append(content_item) prompt_message_content.append(content_item)
if len(prompt_message_content) > 1: if len(prompt_message_content) > 1:

View File

@ -127,7 +127,7 @@ class QuestionClassifierNode(LLMNode):
category_id = category_id_result category_id = category_id_result
except OutputParserError: except OutputParserError:
logging.error(f"Failed to parse result text: {result_text}") logging.exception(f"Failed to parse result text: {result_text}")
try: try:
process_data = { process_data = {
"model_mode": model_config.mode, "model_mode": model_config.mode,

View File

@ -1,3 +1,4 @@
import posixpath
from collections.abc import Generator from collections.abc import Generator
import oss2 as aliyun_s3 import oss2 as aliyun_s3
@ -50,9 +51,4 @@ class AliyunOssStorage(BaseStorage):
self.client.delete_object(self.__wrapper_folder_filename(filename)) self.client.delete_object(self.__wrapper_folder_filename(filename))
def __wrapper_folder_filename(self, filename) -> str: def __wrapper_folder_filename(self, filename) -> str:
if self.folder: return posixpath.join(self.folder, filename) if self.folder else filename
if self.folder.endswith("/"):
filename = self.folder + filename
else:
filename = self.folder + "/" + filename
return filename

View File

@ -202,6 +202,10 @@ simple_conversation_fields = {
"updated_at": TimestampField, "updated_at": TimestampField,
} }
conversation_delete_fields = {
"result": fields.String,
}
conversation_infinite_scroll_pagination_fields = { conversation_infinite_scroll_pagination_fields = {
"limit": fields.Integer, "limit": fields.Integer,
"has_more": fields.Boolean, "has_more": fields.Boolean,

View File

@ -39,13 +39,13 @@ class SMTPClient:
smtp.sendmail(self._from, mail["to"], msg.as_string()) smtp.sendmail(self._from, mail["to"], msg.as_string())
except smtplib.SMTPException as e: except smtplib.SMTPException as e:
logging.error(f"SMTP error occurred: {str(e)}") logging.exception(f"SMTP error occurred: {str(e)}")
raise raise
except TimeoutError as e: except TimeoutError as e:
logging.error(f"Timeout occurred while sending email: {str(e)}") logging.exception(f"Timeout occurred while sending email: {str(e)}")
raise raise
except Exception as e: except Exception as e:
logging.error(f"Unexpected error occurred while sending email: {str(e)}") logging.exception(f"Unexpected error occurred while sending email: {str(e)}")
raise raise
finally: finally:
if smtp: if smtp:

View File

@ -34,6 +34,7 @@ select = [
"RUF101", # redirected-noqa "RUF101", # redirected-noqa
"S506", # unsafe-yaml-load "S506", # unsafe-yaml-load
"SIM", # flake8-simplify rules "SIM", # flake8-simplify rules
"TRY400", # error-instead-of-exception
"UP", # pyupgrade rules "UP", # pyupgrade rules
"W191", # tab-indentation "W191", # tab-indentation
"W605", # invalid-escape-sequence "W605", # invalid-escape-sequence

View File

@ -821,7 +821,7 @@ class RegisterService:
db.session.rollback() db.session.rollback()
except Exception as e: except Exception as e:
db.session.rollback() db.session.rollback()
logging.error(f"Register failed: {e}") logging.exception(f"Register failed: {e}")
raise AccountRegisterError(f"Registration failed: {e}") from e raise AccountRegisterError(f"Registration failed: {e}") from e
return account return account

View File

@ -160,4 +160,5 @@ class ConversationService:
conversation = cls.get_conversation(app_model, conversation_id, user) conversation = cls.get_conversation(app_model, conversation_id, user)
conversation.is_deleted = True conversation.is_deleted = True
conversation.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
db.session.commit() db.session.commit()

View File

@ -195,7 +195,7 @@ class ApiToolManageService:
# try to parse schema, avoid SSRF attack # try to parse schema, avoid SSRF attack
ApiToolManageService.parser_api_schema(schema) ApiToolManageService.parser_api_schema(schema)
except Exception as e: except Exception as e:
logger.error(f"parse api schema error: {str(e)}") logger.exception(f"parse api schema error: {str(e)}")
raise ValueError("invalid schema, please check the url you provided") raise ValueError("invalid schema, please check the url you provided")
return {"schema": schema} return {"schema": schema}

View File

@ -196,8 +196,7 @@ class ToolTransformService:
username = user.name username = user.name
except Exception as e: except Exception as e:
logger.error(f"failed to get user name for api provider {db_provider.id}: {str(e)}") logger.exception(f"failed to get user name for api provider {db_provider.id}: {str(e)}")
# add provider into providers # add provider into providers
credentials = db_provider.credentials credentials = db_provider.credentials
result = ToolProviderApiEntity( result = ToolProviderApiEntity(

View File

@ -196,3 +196,72 @@ def test_extract_selectors_from_template_with_newline():
) )
assert executor.params == {"test": "line1\nline2"} assert executor.params == {"test": "line1\nline2"}
def test_executor_with_form_data():
# Prepare the variable pool
variable_pool = VariablePool(
system_variables={},
user_inputs={},
)
variable_pool.add(["pre_node_id", "text_field"], "Hello, World!")
variable_pool.add(["pre_node_id", "number_field"], 42)
# Prepare the node data
node_data = HttpRequestNodeData(
title="Test Form Data",
method="post",
url="https://api.example.com/upload",
authorization=HttpRequestNodeAuthorization(type="no-auth"),
headers="Content-Type: multipart/form-data",
params="",
body=HttpRequestNodeBody(
type="form-data",
data=[
BodyData(
key="text_field",
type="text",
value="{{#pre_node_id.text_field#}}",
),
BodyData(
key="number_field",
type="text",
value="{{#pre_node_id.number_field#}}",
),
],
),
)
# Initialize the Executor
executor = Executor(
node_data=node_data,
timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30),
variable_pool=variable_pool,
)
# Check the executor's data
assert executor.method == "post"
assert executor.url == "https://api.example.com/upload"
assert "Content-Type" in executor.headers
assert "multipart/form-data" in executor.headers["Content-Type"]
assert executor.params == {}
assert executor.json is None
assert executor.files is None
assert executor.content is None
# Check that the form data is correctly loaded in executor.data
assert isinstance(executor.data, dict)
assert "text_field" in executor.data
assert executor.data["text_field"] == "Hello, World!"
assert "number_field" in executor.data
assert executor.data["number_field"] == "42"
# Check the raw request (to_log method)
raw_request = executor.to_log()
assert "POST /upload HTTP/1.1" in raw_request
assert "Host: api.example.com" in raw_request
assert "Content-Type: multipart/form-data" in raw_request
assert "text_field" in raw_request
assert "Hello, World!" in raw_request
assert "number_field" in raw_request
assert "42" in raw_request

View File

@ -1115,7 +1115,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
title="Request" title="Request"
tag="POST" tag="POST"
label="/datasets/{dataset_id}/retrieve" label="/datasets/{dataset_id}/retrieve"
targetCode={`curl --location --request GET '${props.apiBaseUrl}/datasets/{dataset_id}/retrieve' \\\n--header 'Authorization: Bearer {api_key}'\\\n--header 'Content-Type: application/json'\\\n--data-raw '{ targetCode={`curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/retrieve' \\\n--header 'Authorization: Bearer {api_key}'\\\n--header 'Content-Type: application/json'\\\n--data-raw '{
"query": "test", "query": "test",
"retrieval_model": { "retrieval_model": {
"search_method": "keyword_search", "search_method": "keyword_search",

View File

@ -1116,7 +1116,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
title="Request" title="Request"
tag="POST" tag="POST"
label="/datasets/{dataset_id}/retrieve" label="/datasets/{dataset_id}/retrieve"
targetCode={`curl --location --request GET '${props.apiBaseUrl}/datasets/{dataset_id}/retrieve' \\\n--header 'Authorization: Bearer {api_key}'\\\n--header 'Content-Type: application/json'\\\n--data-raw '{ targetCode={`curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/retrieve' \\\n--header 'Authorization: Bearer {api_key}'\\\n--header 'Content-Type: application/json'\\\n--data-raw '{
"query": "test", "query": "test",
"retrieval_model": { "retrieval_model": {
"search_method": "keyword_search", "search_method": "keyword_search",

View File

@ -466,8 +466,8 @@ const Configuration: FC = () => {
transfer_methods: modelConfig.file_upload?.image?.transfer_methods || ['local_file', 'remote_url'], transfer_methods: modelConfig.file_upload?.image?.transfer_methods || ['local_file', 'remote_url'],
}, },
enabled: !!(modelConfig.file_upload?.enabled || modelConfig.file_upload?.image?.enabled), enabled: !!(modelConfig.file_upload?.enabled || modelConfig.file_upload?.image?.enabled),
allowed_file_types: modelConfig.file_upload?.allowed_file_types || [SupportUploadFileTypes.image], allowed_file_types: modelConfig.file_upload?.allowed_file_types || [SupportUploadFileTypes.image, SupportUploadFileTypes.video],
allowed_file_extensions: modelConfig.file_upload?.allowed_file_extensions || FILE_EXTS[SupportUploadFileTypes.image].map(ext => `.${ext}`), allowed_file_extensions: modelConfig.file_upload?.allowed_file_extensions || [...FILE_EXTS[SupportUploadFileTypes.image], ...FILE_EXTS[SupportUploadFileTypes.video]].map(ext => `.${ext}`),
allowed_file_upload_methods: modelConfig.file_upload?.allowed_file_upload_methods || modelConfig.file_upload?.image?.transfer_methods || ['local_file', 'remote_url'], allowed_file_upload_methods: modelConfig.file_upload?.allowed_file_upload_methods || modelConfig.file_upload?.image?.transfer_methods || ['local_file', 'remote_url'],
number_limits: modelConfig.file_upload?.number_limits || modelConfig.file_upload?.image?.number_limits || 3, number_limits: modelConfig.file_upload?.number_limits || modelConfig.file_upload?.image?.number_limits || 3,
fileUploadConfig: fileUploadConfigResponse, fileUploadConfig: fileUploadConfigResponse,

View File

@ -1,6 +1,5 @@
import { import {
useCallback, useCallback,
useRef,
useState, useState,
} from 'react' } from 'react'
import Textarea from 'rc-textarea' import Textarea from 'rc-textarea'
@ -63,7 +62,6 @@ const ChatInputArea = ({
isMultipleLine, isMultipleLine,
} = useTextAreaHeight() } = useTextAreaHeight()
const [query, setQuery] = useState('') const [query, setQuery] = useState('')
const isUseInputMethod = useRef(false)
const [showVoiceInput, setShowVoiceInput] = useState(false) const [showVoiceInput, setShowVoiceInput] = useState(false)
const filesStore = useFileStore() const filesStore = useFileStore()
const { const {
@ -95,20 +93,11 @@ const ChatInputArea = ({
} }
} }
const handleKeyUp = (e: React.KeyboardEvent<HTMLTextAreaElement>) => {
if (e.key === 'Enter') {
e.preventDefault()
// prevent send message when using input method enter
if (!e.shiftKey && !isUseInputMethod.current)
handleSend()
}
}
const handleKeyDown = (e: React.KeyboardEvent<HTMLTextAreaElement>) => { const handleKeyDown = (e: React.KeyboardEvent<HTMLTextAreaElement>) => {
isUseInputMethod.current = e.nativeEvent.isComposing if (e.key === 'Enter' && !e.shiftKey && !e.nativeEvent.isComposing) {
if (e.key === 'Enter' && !e.shiftKey) {
setQuery(query.replace(/\n$/, ''))
e.preventDefault() e.preventDefault()
setQuery(query.replace(/\n$/, ''))
handleSend()
} }
} }
@ -165,7 +154,6 @@ const ChatInputArea = ({
setQuery(e.target.value) setQuery(e.target.value)
handleTextareaResize() handleTextareaResize()
}} }}
onKeyUp={handleKeyUp}
onKeyDown={handleKeyDown} onKeyDown={handleKeyDown}
onPaste={handleClipboardPasteFile} onPaste={handleClipboardPasteFile}
onDragEnter={handleDragFileEnter} onDragEnter={handleDragFileEnter}

View File

@ -120,7 +120,7 @@ const ConfigCredential: FC<Props> = ({
<input <input
value={tempCredential.api_key_header} value={tempCredential.api_key_header}
onChange={e => setTempCredential({ ...tempCredential, api_key_header: e.target.value })} onChange={e => setTempCredential({ ...tempCredential, api_key_header: e.target.value })}
className='w-full h-10 px-3 text-sm font-normal bg-gray-100 rounded-lg grow' className='w-full h-10 px-3 text-sm font-normal border border-transparent bg-gray-100 rounded-lg grow outline-none focus:bg-components-input-bg-active focus:border-components-input-border-active focus:shadow-xs'
placeholder={t('tools.createTool.authMethod.types.apiKeyPlaceholder')!} placeholder={t('tools.createTool.authMethod.types.apiKeyPlaceholder')!}
/> />
</div> </div>
@ -129,7 +129,7 @@ const ConfigCredential: FC<Props> = ({
<input <input
value={tempCredential.api_key_value} value={tempCredential.api_key_value}
onChange={e => setTempCredential({ ...tempCredential, api_key_value: e.target.value })} onChange={e => setTempCredential({ ...tempCredential, api_key_value: e.target.value })}
className='w-full h-10 px-3 text-sm font-normal bg-gray-100 rounded-lg grow' className='w-full h-10 px-3 text-sm font-normal border border-transparent bg-gray-100 rounded-lg grow outline-none focus:bg-components-input-bg-active focus:border-components-input-border-active focus:shadow-xs'
placeholder={t('tools.createTool.authMethod.types.apiValuePlaceholder')!} placeholder={t('tools.createTool.authMethod.types.apiValuePlaceholder')!}
/> />
</div> </div>

View File

@ -70,7 +70,7 @@ const GetSchema: FC<Props> = ({
<div className='relative'> <div className='relative'>
<input <input
type='text' type='text'
className='w-[244px] h-8 pl-1.5 pr-[44px] overflow-x-auto border border-gray-200 rounded-lg text-[13px]' className='w-[244px] h-8 pl-1.5 pr-[44px] overflow-x-auto border border-gray-200 rounded-lg text-[13px] focus:outline-none focus:border-components-input-border-active'
placeholder={t('tools.createTool.importFromUrlPlaceHolder')!} placeholder={t('tools.createTool.importFromUrlPlaceHolder')!}
value={importUrl} value={importUrl}
onChange={e => setImportUrl(e.target.value)} onChange={e => setImportUrl(e.target.value)}
@ -89,7 +89,7 @@ const GetSchema: FC<Props> = ({
</div> </div>
)} )}
</div> </div>
<div className='relative' ref={showExamplesRef}> <div className='relative -mt-0.5' ref={showExamplesRef}>
<Button <Button
size='small' size='small'
className='space-x-1' className='space-x-1'

View File

@ -186,8 +186,8 @@ const EditCustomCollectionModal: FC<Props> = ({
positionCenter={isAdd && !positionLeft} positionCenter={isAdd && !positionLeft}
onHide={onHide} onHide={onHide}
title={t(`tools.createTool.${isAdd ? 'title' : 'editTitle'}`)!} title={t(`tools.createTool.${isAdd ? 'title' : 'editTitle'}`)!}
panelClassName='mt-2 !w-[630px]' panelClassName='mt-2 !w-[640px]'
maxWidthClassName='!max-w-[630px]' maxWidthClassName='!max-w-[640px]'
height='calc(100vh - 16px)' height='calc(100vh - 16px)'
headerClassName='!border-b-black/5' headerClassName='!border-b-black/5'
body={ body={

View File

@ -27,8 +27,8 @@ const Contribute = ({ onRefreshData }: Props) => {
const linkUrl = useMemo(() => { const linkUrl = useMemo(() => {
if (language.startsWith('zh_')) if (language.startsWith('zh_'))
return 'https://docs.dify.ai/v/zh-hans/guides/gong-ju/quick-tool-integration' return 'https://docs.dify.ai/zh-hans/guides/tools#ru-he-chuang-jian-zi-ding-yi-gong-ju'
return 'https://docs.dify.ai/tutorials/quick-tool-integration' return 'https://docs.dify.ai/guides/tools#how-to-create-custom-tools'
}, [language]) }, [language])
const [isShowEditCollectionToolModal, setIsShowEditCustomCollectionModal] = useState(false) const [isShowEditCollectionToolModal, setIsShowEditCustomCollectionModal] = useState(false)