merge main

This commit is contained in:
Joel 2024-11-20 18:24:03 +08:00
commit 99ffe43e91
241 changed files with 4972 additions and 1950 deletions

View File

@ -42,6 +42,11 @@ REDIS_SENTINEL_USERNAME=
REDIS_SENTINEL_PASSWORD=
REDIS_SENTINEL_SOCKET_TIMEOUT=0.1
# redis Cluster configuration.
REDIS_USE_CLUSTERS=false
REDIS_CLUSTERS=
REDIS_CLUSTERS_PASSWORD=
# PostgreSQL database configuration
DB_USERNAME=postgres
DB_PASSWORD=difyai123456
@ -234,6 +239,10 @@ ANALYTICDB_ACCOUNT=testaccount
ANALYTICDB_PASSWORD=testpassword
ANALYTICDB_NAMESPACE=dify
ANALYTICDB_NAMESPACE_PASSWORD=difypassword
ANALYTICDB_HOST=gp-test.aliyuncs.com
ANALYTICDB_PORT=5432
ANALYTICDB_MIN_CONNECTION=1
ANALYTICDB_MAX_CONNECTION=5
# OpenSearch configuration
OPENSEARCH_HOST=127.0.0.1

View File

@ -589,7 +589,7 @@ def upgrade_db():
click.echo(click.style("Database migration successful!", fg="green"))
except Exception as e:
logging.exception(f"Database migration failed: {e}")
logging.exception("Failed to execute database migration")
finally:
lock.release()
else:
@ -633,7 +633,7 @@ where sites.id is null limit 1000"""
except Exception as e:
failed_app_ids.append(app_id)
click.echo(click.style("Failed to fix missing site for app {}".format(app_id), fg="red"))
logging.exception(f"Fix app related site missing issue failed, error: {e}")
logging.exception(f"Failed to fix app related site missing issue, app_id: {app_id}")
continue
if not processed_count:

View File

@ -616,6 +616,11 @@ class DataSetConfig(BaseSettings):
default=False,
)
PLAN_SANDBOX_CLEAN_MESSAGE_DAY_SETTING: PositiveInt = Field(
description="Interval in days for message cleanup operations - plan: sandbox",
default=30,
)
class WorkspaceConfig(BaseSettings):
"""

View File

@ -68,3 +68,18 @@ class RedisConfig(BaseSettings):
description="Socket timeout in seconds for Redis Sentinel connections",
default=0.1,
)
REDIS_USE_CLUSTERS: bool = Field(
description="Enable Redis Clusters mode for high availability",
default=False,
)
REDIS_CLUSTERS: Optional[str] = Field(
description="Comma-separated list of Redis Clusters nodes (host:port)",
default=None,
)
REDIS_CLUSTERS_PASSWORD: Optional[str] = Field(
description="Password for Redis Clusters authentication (if required)",
default=None,
)

View File

@ -1,6 +1,6 @@
from typing import Optional
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, PositiveInt
class AnalyticdbConfig(BaseModel):
@ -40,3 +40,11 @@ class AnalyticdbConfig(BaseModel):
description="The password for accessing the specified namespace within the AnalyticDB instance"
" (if namespace feature is enabled).",
)
ANALYTICDB_HOST: Optional[str] = Field(
default=None, description="The host of the AnalyticDB instance you want to connect to."
)
ANALYTICDB_PORT: PositiveInt = Field(
default=5432, description="The port of the AnalyticDB instance you want to connect to."
)
ANALYTICDB_MIN_CONNECTION: PositiveInt = Field(default=1, description="Min connection of the AnalyticDB database.")
ANALYTICDB_MAX_CONNECTION: PositiveInt = Field(default=5, description="Max connection of the AnalyticDB database.")

View File

@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
CURRENT_VERSION: str = Field(
description="Dify version",
default="0.11.1",
default="0.11.2",
)
COMMIT_SHA: str = Field(

View File

@ -9,6 +9,7 @@ from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_resource_check,
enterprise_license_required,
setup_required,
)
from core.ops.ops_trace_manager import OpsTraceManager
@ -28,6 +29,7 @@ class AppListApi(Resource):
@setup_required
@login_required
@account_initialization_required
@enterprise_license_required
def get(self):
"""Get app list"""
@ -149,6 +151,7 @@ class AppApi(Resource):
@setup_required
@login_required
@account_initialization_required
@enterprise_license_required
@get_app_model
@marshal_with(app_detail_fields_with_site)
def get(self, app_model):

View File

@ -70,7 +70,7 @@ class ChatMessageAudioApi(Resource):
except ValueError as e:
raise e
except Exception as e:
logging.exception(f"internal server error, {str(e)}.")
logging.exception("Failed to handle post request to ChatMessageAudioApi")
raise InternalServerError()
@ -128,7 +128,7 @@ class ChatMessageTextApi(Resource):
except ValueError as e:
raise e
except Exception as e:
logging.exception(f"internal server error, {str(e)}.")
logging.exception("Failed to handle post request to ChatMessageTextApi")
raise InternalServerError()
@ -170,7 +170,7 @@ class TextModesApi(Resource):
except ValueError as e:
raise e
except Exception as e:
logging.exception(f"internal server error, {str(e)}.")
logging.exception("Failed to handle get request to TextModesApi")
raise InternalServerError()

View File

@ -12,7 +12,7 @@ from controllers.console.auth.error import (
InvalidTokenError,
PasswordMismatchError,
)
from controllers.console.error import EmailSendIpLimitError, NotAllowedRegister
from controllers.console.error import AccountNotFound, EmailSendIpLimitError
from controllers.console.wraps import setup_required
from events.tenant_event import tenant_was_created
from extensions.ext_database import db
@ -48,7 +48,7 @@ class ForgotPasswordSendEmailApi(Resource):
token = AccountService.send_reset_password_email(email=args["email"], language=language)
return {"result": "fail", "data": token, "code": "account_not_found"}
else:
raise NotAllowedRegister()
raise AccountNotFound()
else:
token = AccountService.send_reset_password_email(account=account, email=args["email"], language=language)

View File

@ -16,9 +16,9 @@ from controllers.console.auth.error import (
)
from controllers.console.error import (
AccountBannedError,
AccountNotFound,
EmailSendIpLimitError,
NotAllowedCreateWorkspace,
NotAllowedRegister,
)
from controllers.console.wraps import setup_required
from events.tenant_event import tenant_was_created
@ -76,7 +76,7 @@ class LoginApi(Resource):
token = AccountService.send_reset_password_email(email=args["email"], language=language)
return {"result": "fail", "data": token, "code": "account_not_found"}
else:
raise NotAllowedRegister()
raise AccountNotFound()
# SELF_HOSTED only have one workspace
tenants = TenantService.get_join_tenants(account)
if len(tenants) == 0:
@ -119,7 +119,7 @@ class ResetPasswordSendEmailApi(Resource):
if FeatureService.get_system_features().is_allow_register:
token = AccountService.send_reset_password_email(email=args["email"], language=language)
else:
raise NotAllowedRegister()
raise AccountNotFound()
else:
token = AccountService.send_reset_password_email(account=account, language=language)
@ -148,7 +148,7 @@ class EmailCodeLoginSendEmailApi(Resource):
if FeatureService.get_system_features().is_allow_register:
token = AccountService.send_email_code_login_email(email=args["email"], language=language)
else:
raise NotAllowedRegister()
raise AccountNotFound()
else:
token = AccountService.send_email_code_login_email(account=account, language=language)

View File

@ -10,7 +10,7 @@ from controllers.console import api
from controllers.console.apikey import api_key_fields, api_key_list
from controllers.console.app.error import ProviderNotInitializeError
from controllers.console.datasets.error import DatasetInUseError, DatasetNameDuplicateError, IndexingEstimateError
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.indexing_runner import IndexingRunner
from core.model_runtime.entities.model_entities import ModelType
@ -44,6 +44,7 @@ class DatasetListApi(Resource):
@setup_required
@login_required
@account_initialization_required
@enterprise_license_required
def get(self):
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)

View File

@ -948,7 +948,7 @@ class DocumentRetryApi(DocumentResource):
raise DocumentAlreadyFinishedError()
retry_documents.append(document)
except Exception as e:
logging.exception(f"Document {document_id} retry failed: {str(e)}")
logging.exception(f"Failed to retry document, document id: {document_id}")
continue
# retry document
DocumentService.retry_document(dataset_id, retry_documents)

View File

@ -52,8 +52,8 @@ class AccountBannedError(BaseHTTPException):
code = 400
class NotAllowedRegister(BaseHTTPException):
error_code = "unauthorized"
class AccountNotFound(BaseHTTPException):
error_code = "account_not_found"
description = "Account not found."
code = 400
@ -86,3 +86,9 @@ class NoFileUploadedError(BaseHTTPException):
error_code = "no_file_uploaded"
description = "Please upload your file."
code = 400
class UnauthorizedAndForceLogout(BaseHTTPException):
error_code = "unauthorized_and_force_logout"
description = "Unauthorized and force logout."
code = 401

View File

@ -45,7 +45,7 @@ class RemoteFileUploadApi(Resource):
resp = ssrf_proxy.head(url=url)
if resp.status_code != httpx.codes.OK:
resp = ssrf_proxy.get(url=url, timeout=3)
resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True)
resp.raise_for_status()
file_info = helpers.guess_file_info_from_response(resp)

View File

@ -14,7 +14,7 @@ from controllers.console.workspace.error import (
InvalidInvitationCodeError,
RepeatPasswordNotMatchError,
)
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
from extensions.ext_database import db
from fields.member_fields import account_fields
from libs.helper import TimestampField, timezone
@ -79,6 +79,7 @@ class AccountProfileApi(Resource):
@login_required
@account_initialization_required
@marshal_with(account_fields)
@enterprise_license_required
def get(self):
return current_user

View File

@ -1,3 +1,5 @@
from urllib import parse
from flask_login import current_user
from flask_restful import Resource, abort, marshal_with, reqparse
@ -57,11 +59,12 @@ class MemberInviteEmailApi(Resource):
token = RegisterService.invite_new_member(
inviter.current_tenant, invitee_email, interface_language, role=invitee_role, inviter=inviter
)
encoded_invitee_email = parse.quote(invitee_email)
invitation_results.append(
{
"status": "success",
"email": invitee_email,
"url": f"{console_web_url}/activate?email={invitee_email}&token={token}",
"url": f"{console_web_url}/activate?email={encoded_invitee_email}&token={token}",
}
)
except AccountAlreadyInTenantError:

View File

@ -72,7 +72,10 @@ class DefaultModelApi(Resource):
model=model_setting["model"],
)
except Exception as ex:
logging.exception(f"{model_setting['model_type']} save error: {ex}")
logging.exception(
f"Failed to update default model, model type: {model_setting['model_type']},"
f" model:{model_setting.get('model')}"
)
raise ex
return {"result": "success"}
@ -156,7 +159,10 @@ class ModelProviderModelApi(Resource):
credentials=args["credentials"],
)
except CredentialsValidateFailedError as ex:
logging.exception(f"save model credentials error: {ex}")
logging.exception(
f"Failed to save model credentials, tenant_id: {tenant_id},"
f" model: {args.get('model')}, model_type: {args.get('model_type')}"
)
raise ValueError(str(ex))
return {"result": "success"}, 200

View File

@ -7,7 +7,7 @@ from werkzeug.exceptions import Forbidden
from configs import dify_config
from controllers.console import api
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder
from libs.helper import alphanumeric, uuid_value
from libs.login import login_required
@ -549,6 +549,7 @@ class ToolLabelsApi(Resource):
@setup_required
@login_required
@account_initialization_required
@enterprise_license_required
def get(self):
return jsonable_encoder(ToolLabelsService.list_tool_labels())

View File

@ -8,10 +8,10 @@ from flask_login import current_user
from configs import dify_config
from controllers.console.workspace.error import AccountNotInitializedError
from models.model import DifySetup
from services.feature_service import FeatureService
from services.feature_service import FeatureService, LicenseStatus
from services.operation_service import OperationService
from .error import NotInitValidateError, NotSetupError
from .error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogout
def account_initialization_required(view):
@ -142,3 +142,15 @@ def setup_required(view):
return view(*args, **kwargs)
return decorated
def enterprise_license_required(view):
@wraps(view)
def decorated(*args, **kwargs):
settings = FeatureService.get_system_features()
if settings.license.status in [LicenseStatus.INACTIVE, LicenseStatus.EXPIRED, LicenseStatus.LOST]:
raise UnauthorizedAndForceLogout("Your license is invalid. Please contact your administrator.")
return view(*args, **kwargs)
return decorated

View File

@ -59,7 +59,7 @@ class AudioApi(WebApiResource):
except ValueError as e:
raise e
except Exception as e:
logging.exception(f"internal server error: {str(e)}")
logging.exception("Failed to handle post request to AudioApi")
raise InternalServerError()
@ -117,7 +117,7 @@ class TextApi(WebApiResource):
except ValueError as e:
raise e
except Exception as e:
logging.exception(f"internal server error: {str(e)}")
logging.exception("Failed to handle post request to TextApi")
raise InternalServerError()

View File

@ -16,9 +16,7 @@ class FileUploadConfigManager:
file_upload_dict = config.get("file_upload")
if file_upload_dict:
if file_upload_dict.get("enabled"):
transform_methods = file_upload_dict.get("allowed_file_upload_methods") or file_upload_dict.get(
"allowed_upload_methods", []
)
transform_methods = file_upload_dict.get("allowed_file_upload_methods", [])
data = {
"image_config": {
"number_limits": file_upload_dict["number_limits"],

View File

@ -362,5 +362,5 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
if e.args[0] == "I/O operation on closed file.": # ignore this error
raise GenerateTaskStoppedError()
else:
logger.exception(e)
logger.exception(f"Failed to process generate task pipeline, conversation_id: {conversation.id}")
raise e

View File

@ -242,7 +242,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
start_listener_time = time.time()
yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id)
except Exception as e:
logger.exception(e)
logger.exception(f"Failed to listen audio message, task_id: {task_id}")
break
if tts_publisher:
yield MessageAudioEndStreamResponse(audio="", task_id=task_id)

View File

@ -33,8 +33,8 @@ class BaseAppGenerator:
tenant_id=app_config.tenant_id,
config=FileUploadConfig(
allowed_file_types=entity_dictionary[k].allowed_file_types,
allowed_extensions=entity_dictionary[k].allowed_file_extensions,
allowed_upload_methods=entity_dictionary[k].allowed_file_upload_methods,
allowed_file_extensions=entity_dictionary[k].allowed_file_extensions,
allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods,
),
)
for k, v in user_inputs.items()
@ -47,8 +47,8 @@ class BaseAppGenerator:
tenant_id=app_config.tenant_id,
config=FileUploadConfig(
allowed_file_types=entity_dictionary[k].allowed_file_types,
allowed_extensions=entity_dictionary[k].allowed_file_extensions,
allowed_upload_methods=entity_dictionary[k].allowed_file_upload_methods,
allowed_file_extensions=entity_dictionary[k].allowed_file_extensions,
allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods,
),
)
for k, v in user_inputs.items()
@ -91,6 +91,9 @@ class BaseAppGenerator:
)
if variable_entity.type == VariableEntityType.NUMBER and isinstance(value, str):
# handle empty string case
if not value.strip():
return None
# may raise ValueError if user_input_value is not a valid number
try:
if "." in value:

View File

@ -80,7 +80,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
if e.args[0] == "I/O operation on closed file.": # ignore this error
raise GenerateTaskStoppedError()
else:
logger.exception(e)
logger.exception(f"Failed to handle response, conversation_id: {conversation.id}")
raise e
def _get_conversation_by_user(

View File

@ -298,5 +298,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
if e.args[0] == "I/O operation on closed file.": # ignore this error
raise GenerateTaskStoppedError()
else:
logger.exception(e)
logger.exception(
f"Fails to process generate task pipeline, task_id: {application_generate_entity.task_id}"
)
raise e

View File

@ -216,7 +216,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
else:
yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id)
except Exception as e:
logger.exception(e)
logger.exception(f"Fails to get audio trunk, task_id: {task_id}")
break
if tts_publisher:
yield MessageAudioEndStreamResponse(audio="", task_id=task_id)

View File

@ -86,7 +86,7 @@ class MessageCycleManage:
conversation.name = name
except Exception as e:
if dify_config.DEBUG:
logging.exception(f"generate conversation name failed: {e}")
logging.exception(f"generate conversation name failed, conversation_id: {conversation_id}")
pass
db.session.merge(conversation)

View File

@ -28,8 +28,8 @@ class FileUploadConfig(BaseModel):
image_config: Optional[ImageConfig] = None
allowed_file_types: Sequence[FileType] = Field(default_factory=list)
allowed_extensions: Sequence[str] = Field(default_factory=list)
allowed_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list)
allowed_file_extensions: Sequence[str] = Field(default_factory=list)
allowed_file_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list)
number_limits: int = 0

View File

@ -41,7 +41,7 @@ def check_moderation(model_config: ModelConfigWithCredentialsEntity, text: str)
if moderation_result is True:
return True
except Exception as ex:
logger.exception(ex)
logger.exception(f"Fails to check moderation, provider_name: {provider_name}")
raise InvokeBadRequestError("Rate limit exceeded, please try again later.")
return False

View File

@ -29,7 +29,7 @@ def import_module_from_source(*, module_name: str, py_file_path: AnyStr, use_laz
spec.loader.exec_module(module)
return module
except Exception as e:
logging.exception(f"Failed to load module {module_name} from {py_file_path}: {str(e)}")
logging.exception(f"Failed to load module {module_name} from script file '{py_file_path}'")
raise e

View File

@ -39,6 +39,7 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
)
retries = 0
stream = kwargs.pop("stream", False)
while retries <= max_retries:
try:
if dify_config.SSRF_PROXY_ALL_URL:
@ -52,6 +53,8 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
response = client.request(method=method, url=url, **kwargs)
if response.status_code not in STATUS_FORCELIST:
if stream:
return response.iter_bytes()
return response
else:
logging.warning(

View File

@ -29,6 +29,7 @@ from core.rag.splitter.fixed_text_splitter import (
FixedRecursiveCharacterTextSplitter,
)
from core.rag.splitter.text_splitter import TextSplitter
from core.tools.utils.text_processing_utils import remove_leading_symbols
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
@ -500,11 +501,7 @@ class IndexingRunner:
document_node.metadata["doc_hash"] = hash
# delete Splitter character
page_content = document_node.page_content
if page_content.startswith(".") or page_content.startswith(""):
page_content = page_content[1:]
else:
page_content = page_content
document_node.page_content = page_content
document_node.page_content = remove_leading_symbols(page_content)
if document_node.page_content:
split_documents.append(document_node)
@ -554,7 +551,7 @@ class IndexingRunner:
qa_documents.append(qa_document)
format_documents.extend(qa_documents)
except Exception as e:
logging.exception(e)
logging.exception("Failed to format qa document")
all_qa_documents.extend(format_documents)

View File

@ -102,7 +102,7 @@ class LLMGenerator:
except InvokeError:
questions = []
except Exception as e:
logging.exception(e)
logging.exception("Failed to generate suggested questions after answer")
questions = []
return questions
@ -148,7 +148,7 @@ class LLMGenerator:
error = str(e)
error_step = "generate rule config"
except Exception as e:
logging.exception(e)
logging.exception(f"Failed to generate rule config, model: {model_config.get('name')}")
rule_config["error"] = str(e)
rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else ""
@ -234,7 +234,7 @@ class LLMGenerator:
error_step = "generate conversation opener"
except Exception as e:
logging.exception(e)
logging.exception(f"Failed to generate rule config, model: {model_config.get('name')}")
rule_config["error"] = str(e)
rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else ""
@ -286,7 +286,9 @@ class LLMGenerator:
error = str(e)
return {"code": "", "language": code_language, "error": f"Failed to generate code. Error: {error}"}
except Exception as e:
logging.exception(e)
logging.exception(
f"Failed to invoke LLM model, model: {model_config.get('name')}, language: {code_language}"
)
return {"code": "", "language": code_language, "error": f"An unexpected error occurred: {str(e)}"}
@classmethod

View File

@ -325,14 +325,13 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
assistant_prompt_message.tool_calls.append(tool_call)
# calculate num tokens
if response.usage:
# transform usage
prompt_tokens = response.usage.input_tokens
completion_tokens = response.usage.output_tokens
else:
# calculate num tokens
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
prompt_tokens = (response.usage and response.usage.input_tokens) or self.get_num_tokens(
model, credentials, prompt_messages
)
completion_tokens = (response.usage and response.usage.output_tokens) or self.get_num_tokens(
model, credentials, [assistant_prompt_message]
)
# transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)

View File

@ -103,7 +103,7 @@ class AzureRerankModel(RerankModel):
return RerankResult(model=model, docs=rerank_documents)
except Exception as e:
logger.exception(f"Exception in Azure rerank: {e}")
logger.exception(f"Failed to invoke rerank model, model: {model}")
raise
def validate_credentials(self, model: str, credentials: dict) -> None:

View File

@ -2,13 +2,11 @@
import base64
import json
import logging
import mimetypes
from collections.abc import Generator
from typing import Optional, Union, cast
# 3rd import
import boto3
import requests
from botocore.config import Config
from botocore.exceptions import (
ClientError,
@ -439,22 +437,10 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
sub_messages.append(sub_message_dict)
elif message_content.type == PromptMessageContentType.IMAGE:
message_content = cast(ImagePromptMessageContent, message_content)
if not message_content.data.startswith("data:"):
# fetch image data from url
try:
url = message_content.data
image_content = requests.get(url).content
if "?" in url:
url = url.split("?")[0]
mime_type, _ = mimetypes.guess_type(url)
base64_data = base64.b64encode(image_content).decode("utf-8")
except Exception as ex:
raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
else:
data_split = message_content.data.split(";base64,")
mime_type = data_split[0].replace("data:", "")
base64_data = data_split[1]
image_content = base64.b64decode(base64_data)
data_split = message_content.data.split(";base64,")
mime_type = data_split[0].replace("data:", "")
base64_data = data_split[1]
image_content = base64.b64decode(base64_data)
if mime_type not in {"image/jpeg", "image/png", "image/gif", "image/webp"}:
raise ValueError(

View File

@ -11,5 +11,6 @@
- gemini-1.5-flash-exp-0827
- gemini-1.5-flash-8b-exp-0827
- gemini-1.5-flash-8b-exp-0924
- gemini-exp-1114
- gemini-pro
- gemini-pro-vision

View File

@ -24,14 +24,13 @@ parameter_rules:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens_to_sample
- name: max_output_tokens
use_template: max_tokens
required: true
default: 8192
min: 1
max: 8192
- name: response_format
use_template: response_format
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'

View File

@ -24,14 +24,13 @@ parameter_rules:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens_to_sample
- name: max_output_tokens
use_template: max_tokens
required: true
default: 8192
min: 1
max: 8192
- name: response_format
use_template: response_format
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'

View File

@ -24,14 +24,13 @@ parameter_rules:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens_to_sample
- name: max_output_tokens
use_template: max_tokens
required: true
default: 8192
min: 1
max: 8192
- name: response_format
use_template: response_format
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'

View File

@ -24,14 +24,13 @@ parameter_rules:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens_to_sample
- name: max_output_tokens
use_template: max_tokens
required: true
default: 8192
min: 1
max: 8192
- name: response_format
use_template: response_format
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'

View File

@ -24,14 +24,13 @@ parameter_rules:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens_to_sample
- name: max_output_tokens
use_template: max_tokens
required: true
default: 8192
min: 1
max: 8192
- name: response_format
use_template: response_format
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'

View File

@ -24,14 +24,13 @@ parameter_rules:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens_to_sample
- name: max_output_tokens
use_template: max_tokens
required: true
default: 8192
min: 1
max: 8192
- name: response_format
use_template: response_format
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'

View File

@ -24,14 +24,13 @@ parameter_rules:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens_to_sample
- name: max_output_tokens
use_template: max_tokens
required: true
default: 8192
min: 1
max: 8192
- name: response_format
use_template: response_format
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'

View File

@ -24,14 +24,13 @@ parameter_rules:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens_to_sample
- name: max_output_tokens
use_template: max_tokens
required: true
default: 8192
min: 1
max: 8192
- name: response_format
use_template: response_format
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'

View File

@ -24,14 +24,13 @@ parameter_rules:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens_to_sample
- name: max_output_tokens
use_template: max_tokens
required: true
default: 8192
min: 1
max: 8192
- name: response_format
use_template: response_format
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'

View File

@ -24,14 +24,13 @@ parameter_rules:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens_to_sample
- name: max_output_tokens
use_template: max_tokens
required: true
default: 8192
min: 1
max: 8192
- name: response_format
use_template: response_format
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'

View File

@ -24,14 +24,13 @@ parameter_rules:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens_to_sample
- name: max_output_tokens
use_template: max_tokens
required: true
default: 8192
min: 1
max: 8192
- name: response_format
use_template: response_format
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'

View File

@ -24,14 +24,13 @@ parameter_rules:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens_to_sample
- name: max_output_tokens
use_template: max_tokens
required: true
default: 8192
min: 1
max: 8192
- name: response_format
use_template: response_format
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'

View File

@ -24,14 +24,13 @@ parameter_rules:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens_to_sample
- name: max_output_tokens
use_template: max_tokens
required: true
default: 8192
min: 1
max: 8192
- name: response_format
use_template: response_format
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'

View File

@ -0,0 +1,38 @@
model: gemini-exp-1114
label:
en_US: Gemini exp 1114
model_type: llm
features:
- agent-thought
- vision
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 2097152
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_output_tokens
use_template: max_tokens
default: 8192
min: 1
max: 8192
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'
unit: '0.000001'
currency: USD

View File

@ -32,3 +32,4 @@ pricing:
output: '0.00'
unit: '0.000001'
currency: USD
deprecated: true

View File

@ -36,3 +36,4 @@ pricing:
output: '0.00'
unit: '0.000001'
currency: USD
deprecated: true

View File

@ -1,7 +1,6 @@
import base64
import io
import json
import logging
from collections.abc import Generator
from typing import Optional, Union, cast
@ -36,17 +35,6 @@ from core.model_runtime.errors.invoke import (
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
logger = logging.getLogger(__name__)
GEMINI_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure
if you are not sure about the structure.
<instructions>
{{instructions}}
</instructions>
""" # noqa: E501
class GoogleLargeLanguageModel(LargeLanguageModel):
def _invoke(
@ -155,7 +143,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
try:
ping_message = SystemPromptMessage(content="ping")
self._generate(model, credentials, [ping_message], {"max_tokens_to_sample": 5})
self._generate(model, credentials, [ping_message], {"max_output_tokens": 5})
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@ -184,7 +172,15 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
:return: full response or stream response chunk generator result
"""
config_kwargs = model_parameters.copy()
config_kwargs["max_output_tokens"] = config_kwargs.pop("max_tokens_to_sample", None)
if schema := config_kwargs.pop("json_schema", None):
try:
schema = json.loads(schema)
except:
raise exceptions.InvalidArgument("Invalid JSON Schema")
if tools:
raise exceptions.InvalidArgument("gemini not support use Tools and JSON Schema at same time")
config_kwargs["response_schema"] = schema
config_kwargs["response_mime_type"] = "application/json"
if stop:
config_kwargs["stop_sequences"] = stop

View File

@ -22,6 +22,7 @@ from core.model_runtime.entities.message_entities import (
PromptMessageTool,
SystemPromptMessage,
TextPromptMessageContent,
ToolPromptMessage,
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import (
@ -86,6 +87,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user,
@ -153,6 +155,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
@ -196,6 +199,8 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
if completion_type is LLMMode.CHAT:
endpoint_url = urljoin(endpoint_url, "api/chat")
data["messages"] = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]
if tools:
data["tools"] = [self._convert_prompt_message_tool_to_dict(tool) for tool in tools]
else:
endpoint_url = urljoin(endpoint_url, "api/generate")
first_prompt_message = prompt_messages[0]
@ -232,7 +237,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
if stream:
return self._handle_generate_stream_response(model, credentials, completion_type, response, prompt_messages)
return self._handle_generate_response(model, credentials, completion_type, response, prompt_messages)
return self._handle_generate_response(model, credentials, completion_type, response, prompt_messages, tools)
def _handle_generate_response(
self,
@ -241,6 +246,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
completion_type: LLMMode,
response: requests.Response,
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]],
) -> LLMResult:
"""
Handle llm completion response
@ -253,14 +259,16 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
:return: llm result
"""
response_json = response.json()
tool_calls = []
if completion_type is LLMMode.CHAT:
message = response_json.get("message", {})
response_content = message.get("content", "")
response_tool_calls = message.get("tool_calls", [])
tool_calls = [self._extract_response_tool_call(tool_call) for tool_call in response_tool_calls]
else:
response_content = response_json["response"]
assistant_message = AssistantPromptMessage(content=response_content)
assistant_message = AssistantPromptMessage(content=response_content, tool_calls=tool_calls)
if "prompt_eval_count" in response_json and "eval_count" in response_json:
# transform usage
@ -405,9 +413,28 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
chunk_index += 1
def _convert_prompt_message_tool_to_dict(self, tool: PromptMessageTool) -> dict:
    """
    Serialize a PromptMessageTool into the function-tool dict shape the Ollama API expects.

    :param tool: tool definition to serialize
    :return: dict of the form {"type": "function", "function": {name, description, parameters}}
    """
    function_spec = {
        "name": tool.name,
        "description": tool.description,
        "parameters": tool.parameters,
    }
    return {"type": "function", "function": function_spec}
def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
"""
Convert PromptMessage to dict for Ollama API
:param message: prompt message
:return: message dict
"""
if isinstance(message, UserPromptMessage):
message = cast(UserPromptMessage, message)
@ -432,6 +459,9 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
elif isinstance(message, SystemPromptMessage):
message = cast(SystemPromptMessage, message)
message_dict = {"role": "system", "content": message.content}
elif isinstance(message, ToolPromptMessage):
message = cast(ToolPromptMessage, message)
message_dict = {"role": "tool", "content": message.content}
else:
raise ValueError(f"Got unknown type {message}")
@ -452,6 +482,29 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
return num_tokens
def _extract_response_tool_call(self, response_tool_call: dict) -> AssistantPromptMessage.ToolCall:
    """
    Build an AssistantPromptMessage.ToolCall from an Ollama response tool-call dict.

    Returns None when the payload is empty or lacks a "function" entry.
    Dict-valued arguments are serialized to a JSON string; the function name
    is reused as the tool-call id (Ollama supplies no separate id).
    """
    if not response_tool_call or "function" not in response_tool_call:
        return None

    function_data = response_tool_call["function"]
    name = function_data.get("name")
    arguments = function_data.get("arguments")
    if isinstance(arguments, dict):
        # Ollama may return structured arguments; downstream expects a JSON string.
        arguments = json.dumps(arguments)

    function = AssistantPromptMessage.ToolCall.ToolCallFunction(name=name, arguments=arguments)
    return AssistantPromptMessage.ToolCall(id=name, type="function", function=function)
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
"""
Get customizable model schema.
@ -461,10 +514,15 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
:return: model schema
"""
extras = {}
extras = {
"features": [],
}
if "vision_support" in credentials and credentials["vision_support"] == "true":
extras["features"] = [ModelFeature.VISION]
extras["features"].append(ModelFeature.VISION)
if "function_call_support" in credentials and credentials["function_call_support"] == "true":
extras["features"].append(ModelFeature.TOOL_CALL)
extras["features"].append(ModelFeature.MULTI_TOOL_CALL)
entity = AIModelEntity(
model=model,

View File

@ -96,3 +96,22 @@ model_credential_schema:
label:
en_US: 'No'
zh_Hans:
- variable: function_call_support
label:
zh_Hans: 是否支持函数调用
en_US: Function call support
show_on:
- variable: __model_type
value: llm
default: 'false'
type: radio
required: false
options:
- value: 'true'
label:
en_US: 'Yes'
zh_Hans:
- value: 'false'
label:
en_US: 'No'
zh_Hans:

View File

@ -615,19 +615,11 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
prompt_messages = self._clear_illegal_prompt_messages(model, prompt_messages)
# o1 compatibility
block_as_stream = False
if model.startswith("o1"):
if "max_tokens" in model_parameters:
model_parameters["max_completion_tokens"] = model_parameters["max_tokens"]
del model_parameters["max_tokens"]
if stream:
block_as_stream = True
stream = False
if "stream_options" in extra_model_kwargs:
del extra_model_kwargs["stream_options"]
if "stop" in extra_model_kwargs:
del extra_model_kwargs["stop"]
@ -644,47 +636,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
if stream:
return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, tools)
block_result = self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)
if block_as_stream:
return self._handle_chat_block_as_stream_response(block_result, prompt_messages, stop)
return block_result
def _handle_chat_block_as_stream_response(
    self,
    block_result: LLMResult,
    prompt_messages: list[PromptMessage],
    stop: Optional[list[str]] = None,
) -> Generator[LLMResultChunk, None, None]:
    """
    Adapt a blocking (non-streamed) chat result into a one-chunk stream.

    :param block_result: the complete LLM result obtained without streaming
    :param prompt_messages: prompt messages the result was generated from
    :param stop: optional stop words; the text is truncated locally at the
        first occurrence since the blocking call did not apply them
    :return: generator yielding exactly one chunk carrying the full answer
    """
    text = block_result.message.content
    text = cast(str, text)  # content is treated as plain text here

    if stop:
        # Enforce stop words client-side on the already-complete text.
        text = self.enforce_stop_tokens(text, stop)

    yield LLMResultChunk(
        model=block_result.model,
        prompt_messages=prompt_messages,
        system_fingerprint=block_result.system_fingerprint,
        delta=LLMResultChunkDelta(
            index=0,
            message=AssistantPromptMessage(content=text),
            finish_reason="stop",
            usage=block_result.usage,
        ),
    )
return self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)
def _handle_chat_generate_response(
self,

View File

@ -45,19 +45,7 @@ class OpenRouterLargeLanguageModel(OAIAPICompatLargeLanguageModel):
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
self._update_credential(model, credentials)
block_as_stream = False
if model.startswith("openai/o1"):
block_as_stream = True
stop = None
# invoke block as stream
if stream and block_as_stream:
return self._generate_block_as_stream(
model, credentials, prompt_messages, model_parameters, tools, stop, user
)
else:
return super()._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
return super()._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
def _generate_block_as_stream(
self,
@ -69,9 +57,7 @@ class OpenRouterLargeLanguageModel(OAIAPICompatLargeLanguageModel):
stop: Optional[list[str]] = None,
user: Optional[str] = None,
) -> Generator:
resp: LLMResult = super()._generate(
model, credentials, prompt_messages, model_parameters, tools, stop, False, user
)
resp = super()._generate(model, credentials, prompt_messages, model_parameters, tools, stop, False, user)
yield LLMResultChunk(
model=model,

View File

@ -113,7 +113,7 @@ class SageMakerRerankModel(RerankModel):
return RerankResult(model=model, docs=rerank_documents)
except Exception as e:
logger.exception(f"Exception {e}, line : {line}")
logger.exception(f"Failed to invoke rerank model, model: {model}")
def validate_credentials(self, model: str, credentials: dict) -> None:
"""

View File

@ -78,7 +78,7 @@ class SageMakerSpeech2TextModel(Speech2TextModel):
json_obj = json.loads(json_str)
asr_text = json_obj["text"]
except Exception as e:
logger.exception(f"failed to invoke speech2text model, {e}")
logger.exception(f"failed to invoke speech2text model, model: {model}")
raise CredentialsValidateFailedError(str(e))
return asr_text

View File

@ -117,7 +117,7 @@ class SageMakerEmbeddingModel(TextEmbeddingModel):
return TextEmbeddingResult(embeddings=all_embeddings, usage=usage, model=model)
except Exception as e:
logger.exception(f"Exception {e}, line : {line}")
logger.exception(f"Failed to invoke text embedding model, model: {model}, line: {line}")
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
"""

View File

@ -65,6 +65,8 @@ class GTERerankModel(RerankModel):
)
rerank_documents = []
if not response.output:
return RerankResult(model=model, docs=rerank_documents)
for _, result in enumerate(response.output.results):
# format document
rerank_document = RerankDocument(

View File

@ -1,3 +1,6 @@
from collections.abc import Sequence
from typing import Any
from core.moderation.base import Moderation, ModerationAction, ModerationInputsResult, ModerationOutputsResult
@ -62,5 +65,5 @@ class KeywordsModeration(Moderation):
def _is_violated(self, inputs: dict, keywords_list: list) -> bool:
    """Return True as soon as any input value triggers the keyword check."""
    for value in inputs.values():
        if self._check_keywords_in_value(keywords_list, value):
            return True
    return False
def _check_keywords_in_value(self, keywords_list, value) -> bool:
return any(keyword.lower() in value.lower() for keyword in keywords_list)
def _check_keywords_in_value(self, keywords_list: Sequence[str], value: Any) -> bool:
return any(keyword.lower() in str(value).lower() for keyword in keywords_list)

View File

@ -126,6 +126,6 @@ class OutputModeration(BaseModel):
result: ModerationOutputsResult = moderation_factory.moderation_for_outputs(moderation_buffer)
return result
except Exception as e:
logger.exception("Moderation Output error: %s", e)
logger.exception(f"Moderation Output error, app_id: {app_id}")
return None

View File

@ -49,6 +49,7 @@ class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel):
reference_example_id: Optional[str] = Field(None, description="Reference example ID associated with the run")
input_attachments: Optional[dict[str, Any]] = Field(None, description="Input attachments of the run")
output_attachments: Optional[dict[str, Any]] = Field(None, description="Output attachments of the run")
dotted_order: Optional[str] = Field(None, description="Dotted order of the run")
@field_validator("inputs", "outputs")
@classmethod

View File

@ -25,7 +25,7 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import (
LangSmithRunType,
LangSmithRunUpdateModel,
)
from core.ops.utils import filter_none_values
from core.ops.utils import filter_none_values, generate_dotted_order
from extensions.ext_database import db
from models.model import EndUser, MessageFile
from models.workflow import WorkflowNodeExecution
@ -62,6 +62,16 @@ class LangSmithDataTrace(BaseTraceInstance):
self.generate_name_trace(trace_info)
def workflow_trace(self, trace_info: WorkflowTraceInfo):
trace_id = trace_info.message_id or trace_info.workflow_app_log_id or trace_info.workflow_run_id
message_dotted_order = (
generate_dotted_order(trace_info.message_id, trace_info.start_time) if trace_info.message_id else None
)
workflow_dotted_order = generate_dotted_order(
trace_info.workflow_app_log_id or trace_info.workflow_run_id,
trace_info.workflow_data.created_at,
message_dotted_order,
)
if trace_info.message_id:
message_run = LangSmithRunModel(
id=trace_info.message_id,
@ -76,6 +86,8 @@ class LangSmithDataTrace(BaseTraceInstance):
},
tags=["message", "workflow"],
error=trace_info.error,
trace_id=trace_id,
dotted_order=message_dotted_order,
)
self.add_run(message_run)
@ -95,6 +107,8 @@ class LangSmithDataTrace(BaseTraceInstance):
error=trace_info.error,
tags=["workflow"],
parent_run_id=trace_info.message_id or None,
trace_id=trace_id,
dotted_order=workflow_dotted_order,
)
self.add_run(langsmith_run)
@ -177,6 +191,7 @@ class LangSmithDataTrace(BaseTraceInstance):
else:
run_type = LangSmithRunType.tool
node_dotted_order = generate_dotted_order(node_execution_id, created_at, workflow_dotted_order)
langsmith_run = LangSmithRunModel(
total_tokens=node_total_tokens,
name=node_type,
@ -191,6 +206,9 @@ class LangSmithDataTrace(BaseTraceInstance):
},
parent_run_id=trace_info.workflow_app_log_id or trace_info.workflow_run_id,
tags=["node_execution"],
id=node_execution_id,
trace_id=trace_id,
dotted_order=node_dotted_order,
)
self.add_run(langsmith_run)

View File

@ -711,7 +711,7 @@ class TraceQueueManager:
trace_task.app_id = self.app_id
trace_manager_queue.put(trace_task)
except Exception as e:
logging.exception(f"Error adding trace task: {e}")
logging.exception(f"Error adding trace task, trace_type {trace_task.trace_type}")
finally:
self.start_timer()
@ -730,7 +730,7 @@ class TraceQueueManager:
if tasks:
self.send_to_celery(tasks)
except Exception as e:
logging.exception(f"Error processing trace tasks: {e}")
logging.exception("Error processing trace tasks")
def start_timer(self):
global trace_manager_timer

View File

@ -1,5 +1,6 @@
from contextlib import contextmanager
from datetime import datetime
from typing import Optional, Union
from extensions.ext_database import db
from models.model import Message
@ -43,3 +44,19 @@ def replace_text_with_content(data):
return [replace_text_with_content(item) for item in data]
else:
return data
def generate_dotted_order(
    run_id: str, start_time: Union[str, datetime], parent_dotted_order: Optional[str] = None
) -> str:
    """
    Build a LangSmith dotted_order segment for a run.

    A segment is a millisecond-precision timestamp (YYYYMMDDTHHMMSSmmm + 'Z')
    immediately followed by the run id; child segments are appended to the
    parent's dotted order with a '.' separator.

    :param run_id: identifier of the run
    :param start_time: run start, either a datetime or an ISO-8601 string
    :param parent_dotted_order: dotted order of the parent run, if any
    :return: the dotted order string for this run
    """
    if isinstance(start_time, str):
        start_time = datetime.fromisoformat(start_time)
    # strftime %f gives microseconds; trim the last three digits to milliseconds.
    segment = start_time.strftime("%Y%m%dT%H%M%S%f")[:-3] + "Z" + run_id
    if parent_dotted_order is None:
        return segment
    return f"{parent_dotted_order}.{segment}"

View File

@ -1,310 +1,62 @@
import json
from typing import Any
from pydantic import BaseModel
_import_err_msg = (
"`alibabacloud_gpdb20160503` and `alibabacloud_tea_openapi` packages not found, "
"please run `pip install alibabacloud_gpdb20160503 alibabacloud_tea_openapi`"
)
from configs import dify_config
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi import (
AnalyticdbVectorOpenAPI,
AnalyticdbVectorOpenAPIConfig,
)
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySql, AnalyticdbVectorBySqlConfig
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset
class AnalyticdbConfig(BaseModel):
    """Connection settings for an Alibaba Cloud AnalyticDB vector-store instance.

    Credentials and instance identity are required; namespace/metrics have
    sensible defaults matching the Dify deployment.
    """

    access_key_id: str
    access_key_secret: str
    region_id: str
    instance_id: str
    account: str
    account_password: str
    # BUG FIX: the previous defaults were accidental one-element tuples
    # (namespace = ("dify",), namespace_password = (None,), metrics = ("cosine",)),
    # which silently gave these str fields tuple values when left unset.
    namespace: str = "dify"
    namespace_password: str | None = None
    metrics: str = "cosine"
    # Client read timeout in milliseconds.
    read_timeout: int = 60000

    def to_analyticdb_client_params(self):
        """Keyword arguments for the Alibaba Cloud OpenAPI client constructor."""
        return {
            "access_key_id": self.access_key_id,
            "access_key_secret": self.access_key_secret,
            "region_id": self.region_id,
            "read_timeout": self.read_timeout,
        }
class AnalyticdbVector(BaseVector):
def __init__(self, collection_name: str, config: AnalyticdbConfig):
    """
    Create an AnalyticDB vector client and run one-time initialization.

    :param collection_name: target collection name; stored lowercased.
    :param config: connection settings for the instance.
    :raises ImportError: when the Alibaba Cloud SDK packages are not installed.
    """
    self._collection_name = collection_name.lower()
    try:
        from alibabacloud_gpdb20160503.client import Client
        from alibabacloud_tea_openapi import models as open_api_models
    except ImportError:
        # BUG FIX: was a bare `except:`, which would also swallow
        # KeyboardInterrupt/SystemExit; only a missing SDK should be translated.
        raise ImportError(_import_err_msg)
    self.config = config
    self._client_config = open_api_models.Config(user_agent="dify", **config.to_analyticdb_client_params())
    self._client = Client(self._client_config)
    self._initialize()
def _initialize(self) -> None:
cache_key = f"vector_indexing_{self.config.instance_id}"
lock_name = f"{cache_key}_lock"
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = f"vector_indexing_{self.config.instance_id}"
if redis_client.get(collection_exist_cache_key):
return
self._initialize_vector_database()
self._create_namespace_if_not_exists()
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def _initialize_vector_database(self) -> None:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
request = gpdb_20160503_models.InitVectorDatabaseRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
manager_account=self.config.account,
manager_account_password=self.config.account_password,
)
self._client.init_vector_database(request)
def _create_namespace_if_not_exists(self) -> None:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
from Tea.exceptions import TeaException
try:
request = gpdb_20160503_models.DescribeNamespaceRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
manager_account=self.config.account,
manager_account_password=self.config.account_password,
)
self._client.describe_namespace(request)
except TeaException as e:
if e.statusCode == 404:
request = gpdb_20160503_models.CreateNamespaceRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
manager_account=self.config.account,
manager_account_password=self.config.account_password,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
)
self._client.create_namespace(request)
else:
raise ValueError(f"failed to create namespace {self.config.namespace}: {e}")
def _create_collection_if_not_exists(self, embedding_dimension: int):
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
from Tea.exceptions import TeaException
cache_key = f"vector_indexing_{self._collection_name}"
lock_name = f"{cache_key}_lock"
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
if redis_client.get(collection_exist_cache_key):
return
try:
request = gpdb_20160503_models.DescribeCollectionRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
)
self._client.describe_collection(request)
except TeaException as e:
if e.statusCode == 404:
metadata = '{"ref_doc_id":"text","page_content":"text","metadata_":"jsonb"}'
full_text_retrieval_fields = "page_content"
request = gpdb_20160503_models.CreateCollectionRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
manager_account=self.config.account,
manager_account_password=self.config.account_password,
namespace=self.config.namespace,
collection=self._collection_name,
dimension=embedding_dimension,
metrics=self.config.metrics,
metadata=metadata,
full_text_retrieval_fields=full_text_retrieval_fields,
)
self._client.create_collection(request)
else:
raise ValueError(f"failed to create collection {self._collection_name}: {e}")
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def __init__(
    self, collection_name: str, api_config: AnalyticdbVectorOpenAPIConfig, sql_config: AnalyticdbVectorBySqlConfig
):
    """Select the OpenAPI-backed implementation when an API config is given; otherwise use direct SQL."""
    super().__init__(collection_name)
    if api_config is None:
        self.analyticdb_vector = AnalyticdbVectorBySql(collection_name, sql_config)
    else:
        self.analyticdb_vector = AnalyticdbVectorOpenAPI(collection_name, api_config)
def get_type(self) -> str:
    """Return the vector-store type identifier (VectorType.ANALYTICDB)."""
    return VectorType.ANALYTICDB
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
dimension = len(embeddings[0])
self._create_collection_if_not_exists(dimension)
self.add_texts(texts, embeddings)
self.analyticdb_vector._create_collection_if_not_exists(dimension)
self.analyticdb_vector.add_texts(texts, embeddings)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
rows: list[gpdb_20160503_models.UpsertCollectionDataRequestRows] = []
for doc, embedding in zip(documents, embeddings, strict=True):
metadata = {
"ref_doc_id": doc.metadata["doc_id"],
"page_content": doc.page_content,
"metadata_": json.dumps(doc.metadata),
}
rows.append(
gpdb_20160503_models.UpsertCollectionDataRequestRows(
vector=embedding,
metadata=metadata,
)
)
request = gpdb_20160503_models.UpsertCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
rows=rows,
)
self._client.upsert_collection_data(request)
def add_texts(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
self.analyticdb_vector.add_texts(texts, embeddings)
def text_exists(self, id: str) -> bool:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
request = gpdb_20160503_models.QueryCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
metrics=self.config.metrics,
include_values=True,
vector=None,
content=None,
top_k=1,
filter=f"ref_doc_id='{id}'",
)
response = self._client.query_collection_data(request)
return len(response.body.matches.match) > 0
return self.analyticdb_vector.text_exists(id)
def delete_by_ids(self, ids: list[str]) -> None:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
ids_str = ",".join(f"'{id}'" for id in ids)
ids_str = f"({ids_str})"
request = gpdb_20160503_models.DeleteCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
collection_data=None,
collection_data_filter=f"ref_doc_id IN {ids_str}",
)
self._client.delete_collection_data(request)
self.analyticdb_vector.delete_by_ids(ids)
def delete_by_metadata_field(self, key: str, value: str) -> None:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
request = gpdb_20160503_models.DeleteCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
collection_data=None,
collection_data_filter=f"metadata_ ->> '{key}' = '{value}'",
)
self._client.delete_collection_data(request)
self.analyticdb_vector.delete_by_metadata_field(key, value)
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
score_threshold = kwargs.get("score_threshold") or 0.0
request = gpdb_20160503_models.QueryCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
include_values=kwargs.pop("include_values", True),
metrics=self.config.metrics,
vector=query_vector,
content=None,
top_k=kwargs.get("top_k", 4),
filter=None,
)
response = self._client.query_collection_data(request)
documents = []
for match in response.body.matches.match:
if match.score > score_threshold:
metadata = json.loads(match.metadata.get("metadata_"))
metadata["score"] = match.score
doc = Document(
page_content=match.metadata.get("page_content"),
metadata=metadata,
)
documents.append(doc)
documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True)
return documents
return self.analyticdb_vector.search_by_vector(query_vector)
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
score_threshold = float(kwargs.get("score_threshold") or 0.0)
request = gpdb_20160503_models.QueryCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
include_values=kwargs.pop("include_values", True),
metrics=self.config.metrics,
vector=None,
content=query,
top_k=kwargs.get("top_k", 4),
filter=None,
)
response = self._client.query_collection_data(request)
documents = []
for match in response.body.matches.match:
if match.score > score_threshold:
metadata = json.loads(match.metadata.get("metadata_"))
metadata["score"] = match.score
doc = Document(
page_content=match.metadata.get("page_content"),
vector=match.metadata.get("vector"),
metadata=metadata,
)
documents.append(doc)
documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True)
return documents
return self.analyticdb_vector.search_by_full_text(query, **kwargs)
def delete(self) -> None:
try:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
request = gpdb_20160503_models.DeleteCollectionRequest(
collection=self._collection_name,
dbinstance_id=self.config.instance_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
region_id=self.config.region_id,
)
self._client.delete_collection(request)
except Exception as e:
raise e
self.analyticdb_vector.delete()
class AnalyticdbVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> AnalyticdbVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix.lower()
@ -313,26 +65,9 @@ class AnalyticdbVectorFactory(AbstractVectorFactory):
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.ANALYTICDB, collection_name))
# handle optional params
if dify_config.ANALYTICDB_KEY_ID is None:
raise ValueError("ANALYTICDB_KEY_ID should not be None")
if dify_config.ANALYTICDB_KEY_SECRET is None:
raise ValueError("ANALYTICDB_KEY_SECRET should not be None")
if dify_config.ANALYTICDB_REGION_ID is None:
raise ValueError("ANALYTICDB_REGION_ID should not be None")
if dify_config.ANALYTICDB_INSTANCE_ID is None:
raise ValueError("ANALYTICDB_INSTANCE_ID should not be None")
if dify_config.ANALYTICDB_ACCOUNT is None:
raise ValueError("ANALYTICDB_ACCOUNT should not be None")
if dify_config.ANALYTICDB_PASSWORD is None:
raise ValueError("ANALYTICDB_PASSWORD should not be None")
if dify_config.ANALYTICDB_NAMESPACE is None:
raise ValueError("ANALYTICDB_NAMESPACE should not be None")
if dify_config.ANALYTICDB_NAMESPACE_PASSWORD is None:
raise ValueError("ANALYTICDB_NAMESPACE_PASSWORD should not be None")
return AnalyticdbVector(
collection_name,
AnalyticdbConfig(
if dify_config.ANALYTICDB_HOST is None:
# implemented through OpenAPI
apiConfig = AnalyticdbVectorOpenAPIConfig(
access_key_id=dify_config.ANALYTICDB_KEY_ID,
access_key_secret=dify_config.ANALYTICDB_KEY_SECRET,
region_id=dify_config.ANALYTICDB_REGION_ID,
@ -341,5 +76,22 @@ class AnalyticdbVectorFactory(AbstractVectorFactory):
account_password=dify_config.ANALYTICDB_PASSWORD,
namespace=dify_config.ANALYTICDB_NAMESPACE,
namespace_password=dify_config.ANALYTICDB_NAMESPACE_PASSWORD,
),
)
sqlConfig = None
else:
# implemented through sql
sqlConfig = AnalyticdbVectorBySqlConfig(
host=dify_config.ANALYTICDB_HOST,
port=dify_config.ANALYTICDB_PORT,
account=dify_config.ANALYTICDB_ACCOUNT,
account_password=dify_config.ANALYTICDB_PASSWORD,
min_connection=dify_config.ANALYTICDB_MIN_CONNECTION,
max_connection=dify_config.ANALYTICDB_MAX_CONNECTION,
namespace=dify_config.ANALYTICDB_NAMESPACE,
)
apiConfig = None
return AnalyticdbVector(
collection_name,
apiConfig,
sqlConfig,
)

View File

@ -0,0 +1,309 @@
import json
from typing import Any
from pydantic import BaseModel, model_validator
_import_err_msg = (
"`alibabacloud_gpdb20160503` and `alibabacloud_tea_openapi` packages not found, "
"please run `pip install alibabacloud_gpdb20160503 alibabacloud_tea_openapi`"
)
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
class AnalyticdbVectorOpenAPIConfig(BaseModel):
    """Validated connection settings for the OpenAPI-backed AnalyticDB vector store.

    All credential/instance fields are required and checked by the
    before-mode validator so misconfiguration fails fast with the
    corresponding environment-variable name in the message.
    """

    access_key_id: str
    access_key_secret: str
    region_id: str
    instance_id: str
    account: str
    account_password: str
    namespace: str = "dify"
    # BUG FIX: default was the accidental one-element tuple (None,), which gave
    # this str field a tuple value when unset. None is rejected by the validator.
    namespace_password: str | None = None
    metrics: str = "cosine"
    # Client read timeout in milliseconds.
    read_timeout: int = 60000

    @model_validator(mode="before")
    @classmethod
    def validate_config(cls, values: dict) -> dict:
        """Reject missing/empty required settings with the env-var name that must be set."""
        # Use .get so an absent key reports the friendly message instead of a KeyError.
        if not values.get("access_key_id"):
            raise ValueError("config ANALYTICDB_KEY_ID is required")
        if not values.get("access_key_secret"):
            raise ValueError("config ANALYTICDB_KEY_SECRET is required")
        if not values.get("region_id"):
            raise ValueError("config ANALYTICDB_REGION_ID is required")
        if not values.get("instance_id"):
            raise ValueError("config ANALYTICDB_INSTANCE_ID is required")
        if not values.get("account"):
            raise ValueError("config ANALYTICDB_ACCOUNT is required")
        if not values.get("account_password"):
            raise ValueError("config ANALYTICDB_PASSWORD is required")
        if not values.get("namespace_password"):
            raise ValueError("config ANALYTICDB_NAMESPACE_PASSWORD is required")
        return values

    def to_analyticdb_client_params(self):
        """Keyword arguments for the Alibaba Cloud OpenAPI client constructor."""
        return {
            "access_key_id": self.access_key_id,
            "access_key_secret": self.access_key_secret,
            "region_id": self.region_id,
            "read_timeout": self.read_timeout,
        }
class AnalyticdbVectorOpenAPI:
def __init__(self, collection_name: str, config: AnalyticdbVectorOpenAPIConfig):
    """
    Create an OpenAPI-backed AnalyticDB vector client and run one-time initialization.

    :param collection_name: target collection name; stored lowercased.
    :param config: validated connection settings.
    :raises ImportError: when the Alibaba Cloud SDK packages are not installed.
    """
    try:
        from alibabacloud_gpdb20160503.client import Client
        from alibabacloud_tea_openapi import models as open_api_models
    except ImportError:
        # BUG FIX: was a bare `except:`, which would also swallow
        # KeyboardInterrupt/SystemExit; only a missing SDK should be translated.
        raise ImportError(_import_err_msg)
    self._collection_name = collection_name.lower()
    self.config = config
    self._client_config = open_api_models.Config(user_agent="dify", **config.to_analyticdb_client_params())
    self._client = Client(self._client_config)
    self._initialize()
def _initialize(self) -> None:
    """Run one-time database and namespace setup, guarded by a Redis lock.

    A Redis flag keyed on the instance id records that setup already ran, so
    concurrent workers holding the lock in turn skip the expensive calls.
    """
    cache_key = f"vector_initialize_{self.config.instance_id}"
    lock_name = f"{cache_key}_lock"
    with redis_client.lock(lock_name, timeout=20):
        database_exist_cache_key = f"vector_initialize_{self.config.instance_id}"
        if redis_client.get(database_exist_cache_key):
            # Another worker already initialized this instance.
            return
        self._initialize_vector_database()
        self._create_namespace_if_not_exists()
        # Flag expires after an hour; setup is idempotent and re-checked then.
        redis_client.set(database_exist_cache_key, 1, ex=3600)
def _initialize_vector_database(self) -> None:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
request = gpdb_20160503_models.InitVectorDatabaseRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
manager_account=self.config.account,
manager_account_password=self.config.account_password,
)
self._client.init_vector_database(request)
def _create_namespace_if_not_exists(self) -> None:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
from Tea.exceptions import TeaException
try:
request = gpdb_20160503_models.DescribeNamespaceRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
manager_account=self.config.account,
manager_account_password=self.config.account_password,
)
self._client.describe_namespace(request)
except TeaException as e:
if e.statusCode == 404:
request = gpdb_20160503_models.CreateNamespaceRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
manager_account=self.config.account,
manager_account_password=self.config.account_password,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
)
self._client.create_namespace(request)
else:
raise ValueError(f"failed to create namespace {self.config.namespace}: {e}")
def _create_collection_if_not_exists(self, embedding_dimension: int):
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
from Tea.exceptions import TeaException
cache_key = f"vector_indexing_{self._collection_name}"
lock_name = f"{cache_key}_lock"
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
if redis_client.get(collection_exist_cache_key):
return
try:
request = gpdb_20160503_models.DescribeCollectionRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
)
self._client.describe_collection(request)
except TeaException as e:
if e.statusCode == 404:
metadata = '{"ref_doc_id":"text","page_content":"text","metadata_":"jsonb"}'
full_text_retrieval_fields = "page_content"
request = gpdb_20160503_models.CreateCollectionRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
manager_account=self.config.account,
manager_account_password=self.config.account_password,
namespace=self.config.namespace,
collection=self._collection_name,
dimension=embedding_dimension,
metrics=self.config.metrics,
metadata=metadata,
full_text_retrieval_fields=full_text_retrieval_fields,
)
self._client.create_collection(request)
else:
raise ValueError(f"failed to create collection {self._collection_name}: {e}")
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
rows: list[gpdb_20160503_models.UpsertCollectionDataRequestRows] = []
for doc, embedding in zip(documents, embeddings, strict=True):
metadata = {
"ref_doc_id": doc.metadata["doc_id"],
"page_content": doc.page_content,
"metadata_": json.dumps(doc.metadata),
}
rows.append(
gpdb_20160503_models.UpsertCollectionDataRequestRows(
vector=embedding,
metadata=metadata,
)
)
request = gpdb_20160503_models.UpsertCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
rows=rows,
)
self._client.upsert_collection_data(request)
def text_exists(self, id: str) -> bool:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
request = gpdb_20160503_models.QueryCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
metrics=self.config.metrics,
include_values=True,
vector=None,
content=None,
top_k=1,
filter=f"ref_doc_id='{id}'",
)
response = self._client.query_collection_data(request)
return len(response.body.matches.match) > 0
def delete_by_ids(self, ids: list[str]) -> None:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
ids_str = ",".join(f"'{id}'" for id in ids)
ids_str = f"({ids_str})"
request = gpdb_20160503_models.DeleteCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
collection_data=None,
collection_data_filter=f"ref_doc_id IN {ids_str}",
)
self._client.delete_collection_data(request)
def delete_by_metadata_field(self, key: str, value: str) -> None:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
request = gpdb_20160503_models.DeleteCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
collection_data=None,
collection_data_filter=f"metadata_ ->> '{key}' = '{value}'",
)
self._client.delete_collection_data(request)
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
score_threshold = kwargs.get("score_threshold") or 0.0
request = gpdb_20160503_models.QueryCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
include_values=kwargs.pop("include_values", True),
metrics=self.config.metrics,
vector=query_vector,
content=None,
top_k=kwargs.get("top_k", 4),
filter=None,
)
response = self._client.query_collection_data(request)
documents = []
for match in response.body.matches.match:
if match.score > score_threshold:
metadata = json.loads(match.metadata.get("metadata_"))
metadata["score"] = match.score
doc = Document(
page_content=match.metadata.get("page_content"),
vector=match.values.value,
metadata=metadata,
)
documents.append(doc)
documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True)
return documents
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
score_threshold = float(kwargs.get("score_threshold") or 0.0)
request = gpdb_20160503_models.QueryCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
include_values=kwargs.pop("include_values", True),
metrics=self.config.metrics,
vector=None,
content=query,
top_k=kwargs.get("top_k", 4),
filter=None,
)
response = self._client.query_collection_data(request)
documents = []
for match in response.body.matches.match:
if match.score > score_threshold:
metadata = json.loads(match.metadata.get("metadata_"))
metadata["score"] = match.score
doc = Document(
page_content=match.metadata.get("page_content"),
vector=match.values.value,
metadata=metadata,
)
documents.append(doc)
documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True)
return documents
def delete(self) -> None:
try:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
request = gpdb_20160503_models.DeleteCollectionRequest(
collection=self._collection_name,
dbinstance_id=self.config.instance_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
region_id=self.config.region_id,
)
self._client.delete_collection(request)
except Exception as e:
raise e

View File

@ -0,0 +1,245 @@
import json
import uuid
from contextlib import contextmanager
from typing import Any
import psycopg2.extras
import psycopg2.pool
from pydantic import BaseModel, model_validator
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
class AnalyticdbVectorBySqlConfig(BaseModel):
    """Connection settings for AnalyticDB for PostgreSQL accessed over the plain PostgreSQL protocol."""

    host: str
    port: int
    account: str
    account_password: str
    min_connection: int  # lower bound of the psycopg2 connection pool
    max_connection: int  # upper bound of the psycopg2 connection pool
    namespace: str = "dify"  # schema name that holds the collection tables
    metrics: str = "cosine"  # distance metric used by the ANN index

    @model_validator(mode="before")
    @classmethod
    def validate_config(cls, values: dict) -> dict:
        """Fail fast with the matching environment-variable name for any missing setting.

        Uses ``values.get`` so an absent key raises the descriptive ``ValueError``
        instead of a ``KeyError``, and checks pool bounds are consistent.

        :param values: raw input values (``mode="before"`` — prior to field validation).
        :return: the unchanged ``values`` dict.
        :raises ValueError: when a setting is missing/falsy or min > max connection.
        """
        required = {
            "host": "ANALYTICDB_HOST",
            "port": "ANALYTICDB_PORT",
            "account": "ANALYTICDB_ACCOUNT",
            "account_password": "ANALYTICDB_PASSWORD",
            "min_connection": "ANALYTICDB_MIN_CONNECTION",
            "max_connection": "ANALYTICDB_MAX_CONNECTION",
        }
        for field_name, env_name in required.items():
            if not values.get(field_name):
                raise ValueError(f"config {env_name} is required")
        if values["min_connection"] > values["max_connection"]:
            # fix: the original message read "should less than" (ungrammatical) and
            # implied a strict inequality, while equality is actually allowed.
            raise ValueError("config ANALYTICDB_MIN_CONNECTION should not be greater than ANALYTICDB_MAX_CONNECTION")
        return values
class AnalyticdbVectorBySql:
    """Vector store backed by AnalyticDB for PostgreSQL over the native PostgreSQL protocol.

    Uses a psycopg2 connection pool against a dedicated ``knowledgebase`` database.
    Each collection is a table ``<namespace>.<collection_name>`` with an ANN index
    on the ``vector`` column and a GIN index on a ``zh_cn`` tsvector column for
    full-text search. One-time server setup is guarded by a Redis lock plus a
    one-hour cache flag so concurrent workers run it only once.
    """

    def __init__(self, collection_name: str, config: AnalyticdbVectorBySqlConfig):
        self._collection_name = collection_name.lower()
        self.databaseName = "knowledgebase"  # fixed database holding all collections
        self.config = config
        self.table_name = f"{self.config.namespace}.{self._collection_name}"
        self.pool = None
        self._initialize()
        # _initialize only builds a pool on the first run per host (Redis-cached);
        # subsequent instances build their own pool here.
        if not self.pool:
            self.pool = self._create_connection_pool()

    def _initialize(self) -> None:
        """Run one-time database/schema setup per host, guarded by a Redis lock + 1h cache flag."""
        cache_key = f"vector_initialize_{self.config.host}"
        lock_name = f"{cache_key}_lock"
        with redis_client.lock(lock_name, timeout=20):
            database_exist_cache_key = f"vector_initialize_{self.config.host}"
            if redis_client.get(database_exist_cache_key):
                return
            self._initialize_vector_database()
            redis_client.set(database_exist_cache_key, 1, ex=3600)

    def _create_connection_pool(self):
        """Build a psycopg2 SimpleConnectionPool sized by the configured min/max."""
        return psycopg2.pool.SimpleConnectionPool(
            self.config.min_connection,
            self.config.max_connection,
            host=self.config.host,
            port=self.config.port,
            user=self.config.account,
            password=self.config.account_password,
            database=self.databaseName,
        )

    @contextmanager
    def _get_cursor(self):
        """Yield a cursor from the pool; close it, commit, and return the connection.

        NOTE(review): the commit sits in ``finally``, so it also runs after the body
        raised — presumably intentional best-effort behavior; confirm before changing.
        """
        conn = self.pool.getconn()
        cur = conn.cursor()
        try:
            yield cur
        finally:
            cur.close()
            conn.commit()
            self.pool.putconn(conn)

    def _initialize_vector_database(self) -> None:
        """Create the knowledgebase database, zh_cn text-search config, helper function, and schema."""
        conn = psycopg2.connect(
            host=self.config.host,
            port=self.config.port,
            user=self.config.account,
            password=self.config.account_password,
            database="postgres",
        )
        conn.autocommit = True  # CREATE DATABASE cannot run inside a transaction
        cur = conn.cursor()
        try:
            cur.execute(f"CREATE DATABASE {self.databaseName}")
        except Exception as e:
            if "already exists" in str(e):
                # Another worker created the database first; the remaining setup
                # below was performed during that first creation, so stop here.
                return
            raise e
        finally:
            cur.close()
            conn.close()
        self.pool = self._create_connection_pool()
        with self._get_cursor() as cur:
            try:
                cur.execute("CREATE TEXT SEARCH CONFIGURATION zh_cn (PARSER = zhparser)")
                cur.execute("ALTER TEXT SEARCH CONFIGURATION zh_cn ADD MAPPING FOR n,v,a,i,e,l,x WITH simple")
            except Exception as e:
                if "already exists" not in str(e):
                    raise e
            # Helper function that turns arbitrary text into an OR-combined tsquery.
            cur.execute(
                "CREATE OR REPLACE FUNCTION "
                "public.to_tsquery_from_text(txt text, lang regconfig DEFAULT 'english'::regconfig) "
                "RETURNS tsquery LANGUAGE sql IMMUTABLE STRICT AS $function$ "
                "SELECT to_tsquery(lang, COALESCE(string_agg(split_part(word, ':', 1), ' | '), '')) "
                "FROM (SELECT unnest(string_to_array(to_tsvector(lang, txt)::text, ' ')) AS word) "
                "AS words_only;$function$"
            )
            cur.execute(f"CREATE SCHEMA IF NOT EXISTS {self.config.namespace}")

    def _create_collection_if_not_exists(self, embedding_dimension: int):
        """Create the collection table and its ANN/GIN indexes if absent (Redis-guarded).

        :param embedding_dimension: dimensionality of the stored vectors; index
            creation is skipped when None.
        """
        cache_key = f"vector_indexing_{self._collection_name}"
        lock_name = f"{cache_key}_lock"
        with redis_client.lock(lock_name, timeout=20):
            collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
            if redis_client.get(collection_exist_cache_key):
                return
            with self._get_cursor() as cur:
                cur.execute(
                    f"CREATE TABLE IF NOT EXISTS {self.table_name}("
                    f"id text PRIMARY KEY,"
                    f"vector real[], ref_doc_id text, page_content text, metadata_ jsonb, "
                    f"to_tsvector TSVECTOR"
                    f") WITH (fillfactor=70) DISTRIBUTED BY (id);"
                )
                if embedding_dimension is not None:
                    index_name = f"{self._collection_name}_embedding_idx"
                    # PLAIN storage keeps vectors un-TOASTed, as required by the ANN index.
                    cur.execute(f"ALTER TABLE {self.table_name} ALTER COLUMN vector SET STORAGE PLAIN")
                    cur.execute(
                        f"CREATE INDEX {index_name} ON {self.table_name} USING ann(vector) "
                        f"WITH(dim='{embedding_dimension}', distancemeasure='{self.config.metrics}', "
                        f"pq_enable=0, external_storage=0)"
                    )
                    cur.execute(f"CREATE INDEX ON {self.table_name} USING gin(to_tsvector)")
            redis_client.set(collection_exist_cache_key, 1, ex=3600)

    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
        """Batch-insert documents and their embeddings.

        Row ids are ``<uuid4>_<index>`` so each call produces a unique, ordered id
        block; the tsvector column is populated with the zh_cn configuration.
        """
        values = []
        id_prefix = str(uuid.uuid4()) + "_"
        sql = f"""
        INSERT INTO {self.table_name}
        (id, ref_doc_id, vector, page_content, metadata_, to_tsvector)
        VALUES (%s, %s, %s, %s, %s, to_tsvector('zh_cn', %s));
        """
        for i, doc in enumerate(documents):
            values.append(
                (
                    id_prefix + str(i),
                    doc.metadata.get("doc_id", str(uuid.uuid4())),
                    embeddings[i],
                    doc.page_content,
                    json.dumps(doc.metadata),
                    doc.page_content,
                )
            )
        with self._get_cursor() as cur:
            psycopg2.extras.execute_batch(cur, sql, values)

    def text_exists(self, id: str) -> bool:
        """Return True if a row with the given ref_doc_id exists."""
        with self._get_cursor() as cur:
            cur.execute(f"SELECT id FROM {self.table_name} WHERE ref_doc_id = %s", (id,))
            return cur.fetchone() is not None

    def delete_by_ids(self, ids: list[str]) -> None:
        """Delete all rows whose ref_doc_id is in ``ids``; a missing table is ignored."""
        if not ids:
            # fix: an empty tuple would render as `IN ()`, a PostgreSQL syntax error
            # that the "does not exist" filter below would not swallow.
            return
        with self._get_cursor() as cur:
            try:
                cur.execute(f"DELETE FROM {self.table_name} WHERE ref_doc_id IN %s", (tuple(ids),))
            except Exception as e:
                if "does not exist" not in str(e):
                    raise e

    def delete_by_metadata_field(self, key: str, value: str) -> None:
        """Delete all rows whose jsonb metadata has ``metadata_->>key == value``."""
        with self._get_cursor() as cur:
            try:
                cur.execute(f"DELETE FROM {self.table_name} WHERE metadata_->>%s = %s", (key, value))
            except Exception as e:
                if "does not exist" not in str(e):
                    raise e

    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
        """ANN search by embedding.

        :param kwargs: ``top_k`` (default 4), ``score_threshold`` (default 0.0).
        :return: Documents with score above the threshold, in index order.

        NOTE(review): ``1.0 - t.score`` converts the ``<=>`` distance into a
        similarity — assumes the operator yields a distance in [0, 1] (cosine);
        confirm against the configured metric.
        """
        top_k = kwargs.get("top_k", 4)
        score_threshold = float(kwargs.get("score_threshold") or 0.0)
        with self._get_cursor() as cur:
            # real[] parameters must be passed in PostgreSQL array literal form: {a,b,...}
            query_vector_str = json.dumps(query_vector)
            query_vector_str = "{" + query_vector_str[1:-1] + "}"
            cur.execute(
                f"SELECT t.id AS id, t.vector AS vector, (1.0 - t.score) AS score, "
                f"t.page_content as page_content, t.metadata_ AS metadata_ "
                f"FROM (SELECT id, vector, page_content, metadata_, vector <=> %s AS score "
                f"FROM {self.table_name} ORDER BY score LIMIT {top_k} ) t",
                (query_vector_str,),
            )
            documents = []
            for record in cur:
                id, vector, score, page_content, metadata = record
                if score > score_threshold:
                    # metadata_ is jsonb, which psycopg2 returns as a dict
                    metadata["score"] = score
                    doc = Document(
                        page_content=page_content,
                        vector=vector,
                        metadata=metadata,
                    )
                    documents.append(doc)
        return documents

    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
        """Full-text search via the zh_cn tsvector column, ranked by ts_rank.

        :param kwargs: ``top_k`` (default 4).
        :return: the top-k matching Documents, best rank first.
        """
        top_k = kwargs.get("top_k", 4)
        with self._get_cursor() as cur:
            cur.execute(
                f"""SELECT id, vector, page_content, metadata_,
                ts_rank(to_tsvector, to_tsquery_from_text(%s, 'zh_cn'), 32) AS score
                FROM {self.table_name}
                WHERE to_tsvector@@to_tsquery_from_text(%s, 'zh_cn')
                ORDER BY score DESC
                LIMIT {top_k}""",
                (f"'{query}'", f"'{query}'"),
            )
            documents = []
            for record in cur:
                id, vector, page_content, metadata, score = record
                metadata["score"] = score
                doc = Document(
                    page_content=page_content,
                    vector=vector,
                    metadata=metadata,
                )
                documents.append(doc)
        return documents

    def delete(self) -> None:
        """Drop the collection table if it exists."""
        with self._get_cursor() as cur:
            cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")

View File

@ -242,7 +242,7 @@ class CouchbaseVector(BaseVector):
try:
self._cluster.query(query, named_parameters={"doc_ids": ids}).execute()
except Exception as e:
logger.exception(e)
logger.exception(f"Failed to delete documents, ids: {ids}")
def delete_by_document_id(self, document_id: str):
query = f"""

View File

@ -81,7 +81,7 @@ class LindormVectorStore(BaseVector):
"ids": batch_ids}, _source=False)
return {doc["_id"] for doc in existing_docs["docs"] if doc["found"]}
except Exception as e:
logger.exception(f"Error fetching batch {batch_ids}: {e}")
logger.exception(f"Error fetching batch {batch_ids}")
return set()
@retry(stop=stop_after_attempt(3), wait=wait_fixed(60))
@ -99,7 +99,7 @@ class LindormVectorStore(BaseVector):
)
return {doc["_id"] for doc in existing_docs["docs"] if doc["found"]}
except Exception as e:
logger.exception(f"Error fetching batch {batch_ids}: {e}")
logger.exception(f"Error fetching batch ids: {batch_ids}")
return set()
if ids is None:
@ -187,7 +187,7 @@ class LindormVectorStore(BaseVector):
logger.warning(
f"Index '{self._collection_name}' does not exist. No deletion performed.")
except Exception as e:
logger.exception(f"Error occurred while deleting the index: {e}")
logger.exception(f"Error occurred while deleting the index: {self._collection_name}")
raise e
def text_exists(self, id: str) -> bool:
@ -213,7 +213,7 @@ class LindormVectorStore(BaseVector):
response = self._client.search(
index=self._collection_name, body=query)
except Exception as e:
logger.exception(f"Error executing search: {e}")
logger.exception(f"Error executing vector search, query: {query}")
raise
docs_and_scores = []

View File

@ -142,7 +142,7 @@ class MyScaleVector(BaseVector):
for r in self._client.query(sql).named_results()
]
except Exception as e:
logging.exception(f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m")
logging.exception(f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m") # noqa:TRY401
return []
def delete(self) -> None:

View File

@ -158,7 +158,7 @@ class OpenSearchVector(BaseVector):
try:
response = self._client.search(index=self._collection_name.lower(), body=query)
except Exception as e:
logger.exception(f"Error executing search: {e}")
logger.exception(f"Error executing vector search, query: {query}")
raise
docs = []

View File

@ -69,7 +69,7 @@ class CacheEmbedding(Embeddings):
except IntegrityError:
db.session.rollback()
except Exception as e:
logging.exception("Failed transform embedding: %s", e)
logging.exception("Failed transform embedding")
cache_embeddings = []
try:
for i, embedding in zip(embedding_queue_indices, embedding_queue_embeddings):
@ -89,7 +89,7 @@ class CacheEmbedding(Embeddings):
db.session.rollback()
except Exception as ex:
db.session.rollback()
logger.exception("Failed to embed documents: %s", ex)
logger.exception("Failed to embed documents: %s")
raise ex
return text_embeddings
@ -112,7 +112,7 @@ class CacheEmbedding(Embeddings):
embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist()
except Exception as ex:
if dify_config.DEBUG:
logging.exception(f"Failed to embed query text: {ex}")
logging.exception(f"Failed to embed query text '{text[:10]}...({len(text)} chars)'")
raise ex
try:
@ -126,7 +126,7 @@ class CacheEmbedding(Embeddings):
redis_client.setex(embedding_cache_key, 600, encoded_str)
except Exception as ex:
if dify_config.DEBUG:
logging.exception("Failed to add embedding to redis %s", ex)
logging.exception(f"Failed to add embedding to redis for the text '{text[:10]}...({len(text)} chars)'")
raise ex
return embedding_results

View File

@ -229,7 +229,7 @@ class WordExtractor(BaseExtractor):
for i in url_pattern.findall(x.text):
hyperlinks_url = str(i)
except Exception as e:
logger.exception(e)
logger.exception("Failed to parse HYPERLINK xml")
def parse_paragraph(paragraph):
paragraph_content = []

View File

@ -11,6 +11,7 @@ from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.models.document import Document
from core.tools.utils.text_processing_utils import remove_leading_symbols
from libs import helper
from models.dataset import Dataset
@ -43,11 +44,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
document_node.metadata["doc_id"] = doc_id
document_node.metadata["doc_hash"] = hash
# delete Splitter character
page_content = document_node.page_content
if page_content.startswith(".") or page_content.startswith(""):
page_content = page_content[1:].strip()
else:
page_content = page_content
page_content = remove_leading_symbols(document_node.page_content).strip()
if len(page_content) > 0:
document_node.page_content = page_content
split_documents.append(document_node)

View File

@ -18,6 +18,7 @@ from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.models.document import Document
from core.tools.utils.text_processing_utils import remove_leading_symbols
from libs import helper
from models.dataset import Dataset
@ -53,11 +54,7 @@ class QAIndexProcessor(BaseIndexProcessor):
document_node.metadata["doc_hash"] = hash
# delete Splitter character
page_content = document_node.page_content
if page_content.startswith(".") or page_content.startswith(""):
page_content = page_content[1:]
else:
page_content = page_content
document_node.page_content = page_content
document_node.page_content = remove_leading_symbols(page_content)
split_documents.append(document_node)
all_documents.extend(split_documents)
for i in range(0, len(all_documents), 10):
@ -159,7 +156,7 @@ class QAIndexProcessor(BaseIndexProcessor):
qa_documents.append(qa_document)
format_documents.extend(qa_documents)
except Exception as e:
logging.exception(e)
logging.exception("Failed to format qa document")
all_qa_documents.extend(format_documents)

View File

@ -36,23 +36,21 @@ class WeightRerankRunner(BaseRerankRunner):
:return:
"""
docs = []
doc_id = []
unique_documents = []
doc_id = set()
for document in documents:
if document.metadata["doc_id"] not in doc_id:
doc_id.append(document.metadata["doc_id"])
docs.append(document.page_content)
doc_id = document.metadata.get("doc_id")
if doc_id not in doc_id:
doc_id.add(doc_id)
unique_documents.append(document)
documents = unique_documents
rerank_documents = []
query_scores = self._calculate_keyword_score(query, documents)
query_vector_scores = self._calculate_cosine(self.tenant_id, query, documents, self.weights.vector_setting)
rerank_documents = []
for document, query_score, query_vector_score in zip(documents, query_scores, query_vector_scores):
# format document
score = (
self.weights.vector_setting.vector_weight * query_vector_score
+ self.weights.keyword_setting.keyword_weight * query_score
@ -61,7 +59,8 @@ class WeightRerankRunner(BaseRerankRunner):
continue
document.metadata["score"] = score
rerank_documents.append(document)
rerank_documents = sorted(rerank_documents, key=lambda x: x.metadata["score"], reverse=True)
rerank_documents.sort(key=lambda x: x.metadata["score"], reverse=True)
return rerank_documents[:top_n] if top_n else rerank_documents
def _calculate_keyword_score(self, query: str, documents: list[Document]) -> list[float]:

View File

@ -0,0 +1,87 @@
from typing import Any
from duckduckgo_search import DDGS
from core.model_runtime.entities.message_entities import SystemPromptMessage
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
# Prompt template sent as a system message when the user enables summarization;
# {query} and {content} are filled in by DuckDuckGoNewsSearchTool.summary_results.
SUMMARY_PROMPT = """
User's query:
{query}
Here are the news results:
{content}
Please summarize the news in a few sentences.
"""
class DuckDuckGoNewsSearchTool(BuiltinTool):
    """
    Tool for performing a news search using DuckDuckGo search engine.

    Returns either a single LLM-generated summary text message (when
    require_summary is set) or a markdown digest plus one JSON message per
    news item.
    """

    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
        # Fixed safesearch/region; query, time limit and result count come from the caller.
        query_dict = {
            "keywords": tool_parameters.get("query"),
            "timelimit": tool_parameters.get("timelimit"),
            "max_results": tool_parameters.get("max_results"),
            "safesearch": "moderate",
            "region": "wt-wt",
        }
        try:
            # DDGS().news returns a generator; materialize it so emptiness can be checked.
            response = list(DDGS().news(**query_dict))
            if not response:
                return [self.create_text_message("No news found matching your criteria.")]
        except Exception as e:
            # Search failures are reported back to the model as text, not raised.
            return [self.create_text_message(f"Error searching news: {str(e)}")]

        require_summary = tool_parameters.get("require_summary", False)
        if require_summary:
            # Collapse results to "title: body" lines and let the LLM summarize them.
            results = "\n".join([f"{res.get('title')}: {res.get('body')}" for res in response])
            results = self.summary_results(user_id=user_id, content=results, query=query_dict["keywords"])
            return self.create_text_message(text=results)

        # Create rich markdown content for each news item
        markdown_result = "\n\n"
        json_result = []
        for res in response:
            markdown_result += f"### {res.get('title', 'Untitled')}\n\n"
            if res.get("date"):
                markdown_result += f"**Date:** {res.get('date')}\n\n"
            if res.get("body"):
                markdown_result += f"{res.get('body')}\n\n"
            if res.get("source"):
                markdown_result += f"*Source: {res.get('source')}*\n\n"
            if res.get("image"):
                markdown_result += f"![{res.get('title', '')}]({res.get('image')})\n\n"
            markdown_result += f"[Read more]({res.get('url', '')})\n\n---\n\n"
            # Mirror each item as structured JSON alongside the markdown digest.
            json_result.append(
                self.create_json_message(
                    {
                        "title": res.get("title", ""),
                        "date": res.get("date", ""),
                        "body": res.get("body", ""),
                        "url": res.get("url", ""),
                        "image": res.get("image", ""),
                        "source": res.get("source", ""),
                    }
                )
            )
        return [self.create_text_message(markdown_result)] + json_result

    def summary_results(self, user_id: str, content: str, query: str) -> str:
        """Ask the workspace's model to summarize ``content`` with respect to ``query``.

        :param user_id: ID of the invoking user, forwarded to the model invocation.
        :param content: newline-joined "title: body" lines from the search results.
        :param query: the user's original search query.
        :return: the model's summary text.
        """
        prompt = SUMMARY_PROMPT.format(query=query, content=content)
        summary = self.invoke_model(
            user_id=user_id,
            prompt_messages=[
                SystemPromptMessage(content=prompt),
            ],
            stop=[],
        )
        return summary.message.content

View File

@ -0,0 +1,71 @@
identity:
name: ddgo_news
author: Assistant
label:
en_US: DuckDuckGo News Search
zh_Hans: DuckDuckGo 新闻搜索
description:
human:
en_US: Perform news searches on DuckDuckGo and get results.
zh_Hans: 在 DuckDuckGo 上进行新闻搜索并获取结果。
llm: Perform news searches on DuckDuckGo and get results.
parameters:
- name: query
type: string
required: true
label:
en_US: Query String
zh_Hans: 查询语句
human_description:
en_US: Search Query.
zh_Hans: 搜索查询语句。
llm_description: Key words for searching
form: llm
- name: max_results
type: number
required: true
default: 5
label:
en_US: Max Results
zh_Hans: 最大结果数量
human_description:
en_US: The Max Results
zh_Hans: 最大结果数量
form: form
- name: timelimit
type: select
required: false
options:
- value: Day
label:
en_US: Current Day
zh_Hans: 当天
- value: Week
label:
en_US: Current Week
zh_Hans: 本周
- value: Month
label:
en_US: Current Month
zh_Hans: 当月
- value: Year
label:
en_US: Current Year
zh_Hans: 今年
label:
en_US: Result Time Limit
zh_Hans: 结果时间限制
human_description:
en_US: Use when querying results within a specific time range only.
zh_Hans: 只查询一定时间范围内的结果时使用
form: form
- name: require_summary
type: boolean
default: false
label:
en_US: Require Summary
zh_Hans: 是否总结
human_description:
en_US: Whether to pass the news results to llm for summarization.
zh_Hans: 是否需要将新闻结果传给大模型总结
form: form

View File

@ -0,0 +1,75 @@
from typing import Any, ClassVar
from duckduckgo_search import DDGS
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
class DuckDuckGoVideoSearchTool(BuiltinTool):
    """
    Tool for performing a video search using DuckDuckGo search engine.

    Returns one markdown text message with an embedded, responsive iframe per
    playable result, plus each raw result as a JSON message. An optional proxy
    URL can be prefixed to embed sources.
    """

    # Responsive 16:9 wrapper used to embed each playable video result.
    IFRAME_TEMPLATE: ClassVar[str] = """
<div style="position: relative; padding-bottom: 56.25%; height: 0; overflow: hidden; \
max-width: 100%; border-radius: 8px;">
<iframe
style="position: absolute; top: 0; left: 0; width: 100%; height: 100%;"
src="{src}"
frameborder="0"
allowfullscreen>
</iframe>
</div>"""

    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> list[ToolInvokeMessage]:
        """Search DuckDuckGo videos and return [markdown message] + one JSON message per result.

        :param user_id: ID of the invoking user (part of the tool contract; unused here).
        :param tool_parameters: ``query`` plus optional region/safesearch/timelimit/
            resolution/duration/license_videos/max_results/proxy_url settings.
        """
        query_dict = {
            "keywords": tool_parameters.get("query"),
            "region": tool_parameters.get("region", "wt-wt"),
            "safesearch": tool_parameters.get("safesearch", "moderate"),
            "timelimit": tool_parameters.get("timelimit"),
            "resolution": tool_parameters.get("resolution"),
            "duration": tool_parameters.get("duration"),
            "license_videos": tool_parameters.get("license_videos"),
            "max_results": tool_parameters.get("max_results"),
        }
        # Remove None values to use API defaults
        query_dict = {k: v for k, v in query_dict.items() if v is not None}

        # Optional prefix applied to embed URLs (e.g. to route through a proxy).
        proxy_url = tool_parameters.get("proxy_url", "").strip()

        response = DDGS().videos(**query_dict)

        # Create HTML result with embedded iframes
        markdown_result = "\n\n"
        json_result = []
        for res in response:
            title = res.get("title", "")
            embed_html = res.get("embed_html", "")
            content_url = res.get("content", "")
            # fix: dropped the unused `description` local the original assigned.

            # Handle TED.com videos: they come without embed_html, so derive the
            # embeddable URL from the talk's content URL.
            if not embed_html and "ted.com/talks" in content_url:
                embed_url = content_url.replace("www.ted.com", "embed.ted.com")
                if proxy_url:
                    embed_url = f"{proxy_url}{embed_url}"
                embed_html = self.IFRAME_TEMPLATE.format(src=embed_url)
            # Original YouTube/other platform handling: rebuild the iframe from
            # embed_url so the proxy prefix (if any) is honored.
            elif embed_html:
                embed_url = res.get("embed_url", "")
                if proxy_url and embed_url:
                    embed_url = f"{proxy_url}{embed_url}"
                embed_html = self.IFRAME_TEMPLATE.format(src=embed_url)

            markdown_result += f"{title}\n\n"
            markdown_result += f"{embed_html}\n\n"
            markdown_result += "---\n\n"
            json_result.append(self.create_json_message(res))

        return [self.create_text_message(markdown_result)] + json_result

View File

@ -0,0 +1,97 @@
identity:
name: ddgo_video
author: Tao Wang
label:
en_US: DuckDuckGo Video Search
zh_Hans: DuckDuckGo 视频搜索
description:
human:
en_US: Search and embedded videos.
zh_Hans: 搜索并嵌入视频
llm: Search videos on duckduckgo and embed videos in iframe
parameters:
- name: query
label:
en_US: Query String
zh_Hans: 查询语句
type: string
required: true
human_description:
en_US: Search Query
zh_Hans: 搜索查询语句
llm_description: Key words for searching
form: llm
- name: max_results
label:
en_US: Max Results
zh_Hans: 最大结果数量
type: number
required: true
default: 3
minimum: 1
maximum: 10
human_description:
en_US: The max results (1-10)
zh_Hans: 最大结果数量1-10
form: form
- name: timelimit
label:
en_US: Result Time Limit
zh_Hans: 结果时间限制
type: select
required: false
options:
- value: Day
label:
en_US: Current Day
zh_Hans: 当天
- value: Week
label:
en_US: Current Week
zh_Hans: 本周
- value: Month
label:
en_US: Current Month
zh_Hans: 当月
- value: Year
label:
en_US: Current Year
zh_Hans: 今年
human_description:
en_US: Query results within a specific time range only
zh_Hans: 只查询一定时间范围内的结果时使用
form: form
- name: duration
label:
en_US: Video Duration
zh_Hans: 视频时长
type: select
required: false
options:
- value: short
label:
en_US: Short (<4 minutes)
zh_Hans: 短视频(<4分钟
- value: medium
label:
en_US: Medium (4-20 minutes)
zh_Hans: 中等4-20分钟
- value: long
label:
en_US: Long (>20 minutes)
zh_Hans: 长视频(>20分钟
human_description:
en_US: Filter videos by duration
zh_Hans: 按时长筛选视频
form: form
- name: proxy_url
label:
en_US: Proxy URL
zh_Hans: 视频代理地址
type: string
required: false
default: ""
human_description:
en_US: Proxy URL
zh_Hans: 视频代理地址
form: form

View File

@ -38,7 +38,7 @@ def send_mail(parmas: SendEmailToolParameters):
server.sendmail(parmas.email_account, parmas.sender_to, msg.as_string())
return True
except Exception as e:
logging.exception("send email failed: %s", e)
logging.exception("send email failed")
return False
else: # NONE or TLS
try:
@ -49,5 +49,5 @@ def send_mail(parmas: SendEmailToolParameters):
server.sendmail(parmas.email_account, parmas.sender_to, msg.as_string())
return True
except Exception as e:
logging.exception("send email failed: %s", e)
logging.exception("send email failed")
return False

View File

@ -17,7 +17,7 @@ class SendMailTool(BuiltinTool):
invoke tools
"""
sender = self.runtime.credentials.get("email_account", "")
email_rgx = re.compile(r"^[a-zA-Z0-9_-]+@[a-zA-Z0-9_-]+(\.[a-zA-Z0-9_-]+)+$")
email_rgx = re.compile(r"^[a-zA-Z0-9._-]+@[a-zA-Z0-9_-]+(\.[a-zA-Z0-9_-]+)+$")
password = self.runtime.credentials.get("email_password", "")
smtp_server = self.runtime.credentials.get("smtp_server", "")
if not smtp_server:

View File

@ -18,7 +18,7 @@ class SendMailTool(BuiltinTool):
invoke tools
"""
sender = self.runtime.credentials.get("email_account", "")
email_rgx = re.compile(r"^[a-zA-Z0-9_-]+@[a-zA-Z0-9_-]+(\.[a-zA-Z0-9_-]+)+$")
email_rgx = re.compile(r"^[a-zA-Z0-9._-]+@[a-zA-Z0-9_-]+(\.[a-zA-Z0-9_-]+)+$")
password = self.runtime.credentials.get("email_password", "")
smtp_server = self.runtime.credentials.get("smtp_server", "")
if not smtp_server:

View File

@ -19,7 +19,7 @@ class WizperTool(BuiltinTool):
version = tool_parameters.get("version", "3")
if audio_file.type != FileType.AUDIO:
return [self.create_text_message("Not a valid audio file.")]
return self.create_text_message("Not a valid audio file.")
api_key = self.runtime.credentials["fal_api_key"]
@ -31,9 +31,8 @@ class WizperTool(BuiltinTool):
try:
audio_url = fal_client.upload(file_data, mime_type)
except Exception as e:
return [self.create_text_message(f"Error uploading audio file: {str(e)}")]
return self.create_text_message(f"Error uploading audio file: {str(e)}")
arguments = {
"audio_url": audio_url,
@ -49,4 +48,9 @@ class WizperTool(BuiltinTool):
with_logs=False,
)
return self.create_json_message(result)
json_message = self.create_json_message(result)
text = result.get("text", "")
text_message = self.create_text_message(text)
return [json_message, text_message]

View File

@ -0,0 +1,25 @@
from typing import Any, Union
import requests
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
class GiteeAIToolEmbedding(BuiltinTool):
    """Tool that generates text embeddings via the Gitee AI serverless API."""

    def _invoke(
        self, user_id: str, tool_parameters: dict[str, Any]
    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        """Call the serverless embeddings endpoint and return its raw JSON body.

        :param user_id: ID of the invoking user (unused; required by the tool interface).
        :param tool_parameters: expects "inputs" (the text to embed) and optionally
            "model" (defaults to "bge-m3").
        :return: a single text message containing the error body on failure, or a
            list with one text message holding the endpoint's JSON response.
        """
        headers = {
            "content-type": "application/json",
            "authorization": f"Bearer {self.runtime.credentials['api_key']}",
        }

        payload = {"inputs": tool_parameters.get("inputs")}
        model = tool_parameters.get("model", "bge-m3")
        url = f"https://ai.gitee.com/api/serverless/{model}/embeddings"

        # Bug fix: the original call had no timeout, so a stalled upstream
        # service could hang the worker indefinitely. Bound connect/read times.
        response = requests.post(url, json=payload, headers=headers, timeout=(10, 60))
        if response.status_code != 200:
            return self.create_text_message(f"Got Error Response:{response.text}")

        return [self.create_text_message(response.content.decode("utf-8"))]

View File

@ -0,0 +1,37 @@
identity:
name: embedding
author: gitee_ai
label:
en_US: embedding
icon: icon.svg
description:
human:
en_US: Generate word embeddings using Serverless-supported models (compatible with OpenAI)
llm: This tool is used to generate word embeddings from text input.
parameters:
- name: model
type: string
required: true
in: path
description:
en_US: Supported Embedding (compatible with OpenAI) interface models
enum:
- bge-m3
- bge-large-zh-v1.5
- bge-small-zh-v1.5
label:
en_US: Service Model
zh_Hans: 服务模型
default: bge-m3
form: form
- name: inputs
type: string
required: true
label:
en_US: Input Text
zh_Hans: 输入文本
human_description:
en_US: The text input used to generate embeddings.
zh_Hans: 用于生成词向量的输入文本。
llm_description: This text input will be used to generate embeddings.
form: llm

View File

@ -6,7 +6,7 @@ from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
class GiteeAITool(BuiltinTool):
class GiteeAIToolText2Image(BuiltinTool):
def _invoke(
self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:

View File

@ -1,14 +1,12 @@
identity:
author: Yash Parmar
author: Yash Parmar, Kalo Chin
name: tavily
label:
en_US: Tavily
zh_Hans: Tavily
pt_BR: Tavily
en_US: Tavily Search & Extract
zh_Hans: Tavily 搜索和提取
description:
en_US: Tavily
zh_Hans: Tavily
pt_BR: Tavily
en_US: A powerful AI-native search engine and web content extraction tool that provides highly relevant search results and raw content extraction from web pages.
zh_Hans: 一个强大的原生AI搜索引擎和网页内容提取工具提供高度相关的搜索结果和网页原始内容提取。
icon: icon.png
tags:
- search
@ -19,13 +17,10 @@ credentials_for_provider:
label:
en_US: Tavily API key
zh_Hans: Tavily API key
pt_BR: Tavily API key
placeholder:
en_US: Please input your Tavily API key
zh_Hans: 请输入你的 Tavily API key
pt_BR: Please input your Tavily API key
help:
en_US: Get your Tavily API key from Tavily
zh_Hans: 从 TavilyApi 获取您的 Tavily API key
pt_BR: Get your Tavily API key from Tavily
url: https://docs.tavily.com/docs/welcome
url: https://app.tavily.com/home

View File

@ -0,0 +1,145 @@
from typing import Any
import requests
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
TAVILY_API_URL = "https://api.tavily.com"
class TavilyExtract:
    """
    A class for extracting content from web pages using the Tavily Extract API.

    Args:
        api_key (str): The API key for accessing the Tavily Extract API.

    Methods:
        extract_content: Retrieves extracted content from the Tavily Extract API.
    """

    def __init__(self, api_key: str) -> None:
        self.api_key = api_key

    def extract_content(self, params: dict[str, Any]) -> dict:
        """
        Retrieves extracted content from the Tavily Extract API.

        Args:
            params (Dict[str, Any]): The extraction parameters; must contain "urls".

        Returns:
            dict: The extracted content.

        Raises:
            ValueError: If the "urls" parameter is missing or of an unsupported type.
            requests.HTTPError: If the API responds with an error status.
        """
        # Ensure required parameters are set
        if "api_key" not in params:
            params["api_key"] = self.api_key

        # Process parameters
        processed_params = self._process_params(params)

        # Bound the request so a stalled upstream cannot hang the caller.
        response = requests.post(f"{TAVILY_API_URL}/extract", json=processed_params, timeout=(10, 60))
        response.raise_for_status()

        return response.json()

    def _process_params(self, params: dict[str, Any]) -> dict:
        """
        Processes and validates the extraction parameters.

        Args:
            params (Dict[str, Any]): The extraction parameters.

        Returns:
            dict: The processed parameters ("urls" plus "api_key"); all other
            keys are deliberately dropped.

        Raises:
            ValueError: If "urls" is missing or neither a string nor a list.
        """
        # Bug fix: the original only raised when "urls" had an unexpected type;
        # a payload with the key absent was silently forwarded without any URLs.
        # The API requires "urls", so fail fast with a clear error instead.
        if "urls" not in params:
            raise ValueError("The 'urls' parameter is required.")

        urls = params["urls"]
        if isinstance(urls, str):
            # Accept a comma- and/or whitespace-separated string of URLs.
            processed_urls = [url.strip() for url in urls.replace(",", " ").split()]
        elif isinstance(urls, list):
            processed_urls = urls
        else:
            raise ValueError("The 'urls' parameter must be a string or a list of strings.")

        # Only 'urls' and 'api_key' are forwarded to the API.
        return {
            "urls": processed_urls,
            "api_key": params.get("api_key", self.api_key),
        }
class TavilyExtractTool(BuiltinTool):
    """
    A tool for extracting raw page content from URLs via the Tavily Extract API.
    """

    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
        """
        Run a Tavily extraction for the URLs supplied in the tool parameters.

        Args:
            user_id (str): The ID of the user invoking the tool.
            tool_parameters (Dict[str, Any]): The parameters for the Tavily Extract tool.

        Returns:
            ToolInvokeMessage | list[ToolInvokeMessage]: An error text message, or a
            [json, text] message pair carrying the extraction results.
        """
        api_key = self.runtime.credentials.get("tavily_api_key")
        if not api_key:
            return self.create_text_message(
                "Tavily API key is missing. Please set the 'tavily_api_key' in credentials."
            )

        if not tool_parameters.get("urls", ""):
            return self.create_text_message("Please input at least one URL to extract.")

        try:
            raw_results = TavilyExtract(api_key).extract_content(tool_parameters)
        except requests.HTTPError as e:
            return self.create_text_message(f"Error occurred while extracting content: {str(e)}")

        if not raw_results.get("results"):
            return self.create_text_message("No content could be extracted from the provided URLs.")

        # Ship the complete payload as JSON alongside a human-readable rendering.
        return [
            self.create_json_message(raw_results),
            self.create_text_message(text=self._format_results_as_text(raw_results)),
        ]

    def _format_results_as_text(self, raw_results: dict) -> str:
        """
        Render the raw extraction payload as markdown.

        Args:
            raw_results (dict): The raw extraction results.

        Returns:
            str: The formatted markdown text.
        """
        sections: list[str] = []

        for position, entry in enumerate(raw_results.get("results", []), start=1):
            page_url = entry.get("url", "")
            content = entry.get("raw_content", "")
            sections.extend(
                [
                    f"## Extracted Content {position}: {page_url}\n",
                    f"**Raw Content:**\n{content}\n",
                    "---\n",
                ]
            )

        if raw_results.get("failed_results"):
            sections.append("## Failed URLs:\n")
            sections.extend(
                f"- {failure.get('url', '')}: {failure.get('error', 'Unknown error')}\n"
                for failure in raw_results["failed_results"]
            )

        return "\n".join(sections)

View File

@ -0,0 +1,23 @@
identity:
name: tavily_extract
author: Kalo Chin
label:
en_US: Tavily Extract
zh_Hans: Tavily Extract
description:
human:
en_US: A web extraction tool built specifically for AI agents (LLMs), delivering raw content from web pages.
zh_Hans: 专为人工智能代理 (LLM) 构建的网页提取工具,提供网页的原始内容。
llm: A tool for extracting raw content from web pages, designed for AI agents (LLMs).
parameters:
- name: urls
type: string
required: true
label:
en_US: URLs
zh_Hans: URLs
human_description:
en_US: A comma-separated list of URLs to extract content from.
zh_Hans: 要从中提取内容的 URL 的逗号分隔列表。
llm_description: A comma-separated list of URLs to extract content from.
form: llm

View File

@ -17,8 +17,6 @@ class TavilySearch:
Methods:
raw_results: Retrieves raw search results from the Tavily Search API.
results: Retrieves cleaned search results from the Tavily Search API.
clean_results: Cleans the raw search results.
"""
def __init__(self, api_key: str) -> None:
@ -35,63 +33,62 @@ class TavilySearch:
dict: The raw search results.
"""
# Ensure required parameters are set
params["api_key"] = self.api_key
if (
"exclude_domains" in params
and isinstance(params["exclude_domains"], str)
and params["exclude_domains"] != "None"
):
params["exclude_domains"] = params["exclude_domains"].split()
else:
params["exclude_domains"] = []
if (
"include_domains" in params
and isinstance(params["include_domains"], str)
and params["include_domains"] != "None"
):
params["include_domains"] = params["include_domains"].split()
else:
params["include_domains"] = []
response = requests.post(f"{TAVILY_API_URL}/search", json=params)
# Process parameters to ensure correct types
processed_params = self._process_params(params)
response = requests.post(f"{TAVILY_API_URL}/search", json=processed_params)
response.raise_for_status()
return response.json()
def results(self, params: dict[str, Any]) -> list[dict]:
def _process_params(self, params: dict[str, Any]) -> dict:
"""
Retrieves cleaned search results from the Tavily Search API.
Processes and validates the search parameters.
Args:
params (Dict[str, Any]): The search parameters.
Returns:
list: The cleaned search results.
dict: The processed parameters.
"""
raw_search_results = self.raw_results(params)
return self.clean_results(raw_search_results["results"])
processed_params = {}
def clean_results(self, results: list[dict]) -> list[dict]:
"""
Cleans the raw search results.
for key, value in params.items():
if value is None or value == "None":
continue
if key in ["include_domains", "exclude_domains"]:
if isinstance(value, str):
# Split the string by commas or spaces and strip whitespace
processed_params[key] = [domain.strip() for domain in value.replace(",", " ").split()]
elif key in ["include_images", "include_image_descriptions", "include_answer", "include_raw_content"]:
# Ensure boolean type
if isinstance(value, str):
processed_params[key] = value.lower() == "true"
else:
processed_params[key] = bool(value)
elif key in ["max_results", "days"]:
if isinstance(value, str):
processed_params[key] = int(value)
else:
processed_params[key] = value
elif key in ["search_depth", "topic", "query", "api_key"]:
processed_params[key] = value
else:
# Unrecognized parameter
pass
Args:
results (list): The raw search results.
# Set defaults if not present
processed_params.setdefault("search_depth", "basic")
processed_params.setdefault("topic", "general")
processed_params.setdefault("max_results", 5)
Returns:
list: The cleaned search results.
# If topic is 'news', ensure 'days' is set
if processed_params.get("topic") == "news":
processed_params.setdefault("days", 3)
"""
clean_results = []
for result in results:
clean_results.append(
{
"url": result["url"],
"content": result["content"],
}
)
# return clean results as a string
return "\n".join([f"{res['url']}\n{res['content']}" for res in clean_results])
return processed_params
class TavilySearchTool(BuiltinTool):
@ -111,14 +108,88 @@ class TavilySearchTool(BuiltinTool):
ToolInvokeMessage | list[ToolInvokeMessage]: The result of the Tavily search tool invocation.
"""
query = tool_parameters.get("query", "")
api_key = self.runtime.credentials["tavily_api_key"]
api_key = self.runtime.credentials.get("tavily_api_key")
if not api_key:
return self.create_text_message(
"Tavily API key is missing. Please set the 'tavily_api_key' in credentials."
)
if not query:
return self.create_text_message("Please input query")
return self.create_text_message("Please input a query.")
tavily_search = TavilySearch(api_key)
results = tavily_search.results(tool_parameters)
print(results)
if not results:
return self.create_text_message(f"No results found for '{query}' in Tavily")
try:
raw_results = tavily_search.raw_results(tool_parameters)
except requests.HTTPError as e:
return self.create_text_message(f"Error occurred while searching: {str(e)}")
if not raw_results.get("results"):
return self.create_text_message(f"No results found for '{query}' in Tavily.")
else:
return self.create_text_message(text=results)
# Always return JSON message with all data
json_message = self.create_json_message(raw_results)
# Create text message based on user-selected parameters
text_message_content = self._format_results_as_text(raw_results, tool_parameters)
text_message = self.create_text_message(text=text_message_content)
return [json_message, text_message]
def _format_results_as_text(self, raw_results: dict, tool_parameters: dict[str, Any]) -> str:
"""
Formats the raw results into a markdown text based on user-selected parameters.
Args:
raw_results (dict): The raw search results.
tool_parameters (dict): The tool parameters selected by the user.
Returns:
str: The formatted markdown text.
"""
output_lines = []
# Include answer if requested
if tool_parameters.get("include_answer", False) and raw_results.get("answer"):
output_lines.append(f"**Answer:** {raw_results['answer']}\n")
# Include images if requested
if tool_parameters.get("include_images", False) and raw_results.get("images"):
output_lines.append("**Images:**\n")
for image in raw_results["images"]:
if tool_parameters.get("include_image_descriptions", False) and "description" in image:
output_lines.append(f"![{image['description']}]({image['url']})\n")
else:
output_lines.append(f"![]({image['url']})\n")
# Process each result
if "results" in raw_results:
for idx, result in enumerate(raw_results["results"], 1):
title = result.get("title", "No Title")
url = result.get("url", "")
content = result.get("content", "")
published_date = result.get("published_date", "")
score = result.get("score", "")
output_lines.append(f"### Result {idx}: [{title}]({url})\n")
# Include published date if available and topic is 'news'
if tool_parameters.get("topic") == "news" and published_date:
output_lines.append(f"**Published Date:** {published_date}\n")
output_lines.append(f"**URL:** {url}\n")
# Include score (relevance)
if score:
output_lines.append(f"**Relevance Score:** {score}\n")
# Include content
if content:
output_lines.append(f"**Content:**\n{content}\n")
# Include raw content if requested
if tool_parameters.get("include_raw_content", False) and result.get("raw_content"):
output_lines.append(f"**Raw Content:**\n{result['raw_content']}\n")
# Add a separator
output_lines.append("---\n")
return "\n".join(output_lines)

View File

@ -2,28 +2,24 @@ identity:
name: tavily_search
author: Yash Parmar
label:
en_US: TavilySearch
zh_Hans: TavilySearch
pt_BR: TavilySearch
en_US: Tavily Search
zh_Hans: Tavily Search
description:
human:
en_US: A tool for search engine built specifically for AI agents (LLMs), delivering real-time, accurate, and factual results at speed.
en_US: A search engine tool built specifically for AI agents (LLMs), delivering real-time, accurate, and factual results at speed.
zh_Hans: 专为人工智能代理 (LLM) 构建的搜索引擎工具,可快速提供实时、准确和真实的结果。
pt_BR: A tool for search engine built specifically for AI agents (LLMs), delivering real-time, accurate, and factual results at speed.
llm: A tool for search engine built specifically for AI agents (LLMs), delivering real-time, accurate, and factual results at speed.
parameters:
- name: query
type: string
required: true
label:
en_US: Query string
zh_Hans: 查询语句
pt_BR: Query string
en_US: Query
zh_Hans: 查询
human_description:
en_US: used for searching
zh_Hans: 用于搜索网页内容
pt_BR: used for searching
llm_description: key words for searching
en_US: The search query you want to execute with Tavily.
zh_Hans: 您想用 Tavily 执行的搜索查询。
llm_description: The search query.
form: llm
- name: search_depth
type: select
@ -31,122 +27,118 @@ parameters:
label:
en_US: Search Depth
zh_Hans: 搜索深度
pt_BR: Search Depth
human_description:
en_US: The depth of search results
zh_Hans: 搜索结果的深度
pt_BR: The depth of search results
en_US: The depth of the search.
zh_Hans: 搜索的深度。
form: form
options:
- value: basic
label:
en_US: Basic
zh_Hans: 基本
pt_BR: Basic
- value: advanced
label:
en_US: Advanced
zh_Hans: 高级
pt_BR: Advanced
default: basic
- name: topic
type: select
required: false
label:
en_US: Topic
zh_Hans: 主题
human_description:
en_US: The category of the search.
zh_Hans: 搜索的类别。
form: form
options:
- value: general
label:
en_US: General
zh_Hans: 一般
- value: news
label:
en_US: News
zh_Hans: 新闻
default: general
- name: days
type: number
required: false
label:
en_US: Days
zh_Hans: 天数
human_description:
en_US: The number of days back from the current date to include in the search results (only applicable when "topic" is "news").
zh_Hans: 从当前日期起向前追溯的天数,以包含在搜索结果中(仅当“topic”为“news”时适用)。
form: form
min: 1
default: 3
- name: max_results
type: number
required: false
label:
en_US: Max Results
zh_Hans: 最大结果数
human_description:
en_US: The maximum number of search results to return.
zh_Hans: 要返回的最大搜索结果数。
form: form
min: 1
max: 20
default: 5
- name: include_images
type: boolean
required: false
label:
en_US: Include Images
zh_Hans: 包含图片
pt_BR: Include Images
human_description:
en_US: Include images in the search results
zh_Hans: 在搜索结果中包含图片
pt_BR: Include images in the search results
en_US: Include a list of query-related images in the response.
zh_Hans: 在响应中包含与查询相关的图片列表。
form: form
options:
- value: 'true'
label:
en_US: 'Yes'
zh_Hans:
pt_BR: 'Yes'
- value: 'false'
label:
en_US: 'No'
zh_Hans:
pt_BR: 'No'
default: 'false'
default: false
- name: include_image_descriptions
type: boolean
required: false
label:
en_US: Include Image Descriptions
zh_Hans: 包含图片描述
human_description:
en_US: When include_images is True, adds descriptive text for each image.
zh_Hans: 当 include_images 为 True 时,为每个图像添加描述文本。
form: form
default: false
- name: include_answer
type: boolean
required: false
label:
en_US: Include Answer
zh_Hans: 包含答案
pt_BR: Include Answer
human_description:
en_US: Include answers in the search results
zh_Hans: 在搜索结果中包含答案
pt_BR: Include answers in the search results
en_US: Include a short answer to the original query in the response.
zh_Hans: 在响应中包含对原始查询的简短回答。
form: form
options:
- value: 'true'
label:
en_US: 'Yes'
zh_Hans:
pt_BR: 'Yes'
- value: 'false'
label:
en_US: 'No'
zh_Hans:
pt_BR: 'No'
default: 'false'
default: false
- name: include_raw_content
type: boolean
required: false
label:
en_US: Include Raw Content
zh_Hans: 包含原始内容
pt_BR: Include Raw Content
human_description:
en_US: Include raw content in the search results
zh_Hans: 在搜索结果中包含原始内容
pt_BR: Include raw content in the search results
en_US: Include the cleaned and parsed HTML content of each search result.
zh_Hans: 包含每个搜索结果的已清理和解析的HTML内容。
form: form
options:
- value: 'true'
label:
en_US: 'Yes'
zh_Hans:
pt_BR: 'Yes'
- value: 'false'
label:
en_US: 'No'
zh_Hans:
pt_BR: 'No'
default: 'false'
- name: max_results
type: number
required: false
label:
en_US: Max Results
zh_Hans: 最大结果
pt_BR: Max Results
human_description:
en_US: The number of maximum search results to return
zh_Hans: 返回的最大搜索结果数
pt_BR: The number of maximum search results to return
form: form
min: 1
max: 20
default: 5
default: false
- name: include_domains
type: string
required: false
label:
en_US: Include Domains
zh_Hans: 包含域
pt_BR: Include Domains
human_description:
en_US: A list of domains to specifically include in the search results
zh_Hans: 在搜索结果中特别包含的域名列表
pt_BR: A list of domains to specifically include in the search results
en_US: A comma-separated list of domains to specifically include in the search results.
zh_Hans: 要在搜索结果中特别包含的域的逗号分隔列表。
form: form
- name: exclude_domains
type: string
@ -154,9 +146,7 @@ parameters:
label:
en_US: Exclude Domains
zh_Hans: 排除域
pt_BR: Exclude Domains
human_description:
en_US: A list of domains to specifically exclude from the search results
zh_Hans: 从搜索结果中特别排除的域名列表
pt_BR: A list of domains to specifically exclude from the search results
en_US: A comma-separated list of domains to specifically exclude from the search results.
zh_Hans: 要从搜索结果中特别排除的域的逗号分隔列表。
form: form

View File

@ -0,0 +1,11 @@
<?xml version="1.0" encoding="UTF-8"?>
<svg width="800px" height="800px" viewBox="0 -38 256 256" version="1.1" xmlns="http://www.w3.org/2000/svg"
xmlns:xlink="http://www.w3.org/1999/xlink" preserveAspectRatio="xMidYMid">
<g>
<path d="M250.346231,28.0746923 C247.358133,17.0320558 238.732098,8.40602109 227.689461,5.41792308 C207.823743,0 127.868333,0 127.868333,0 C127.868333,0 47.9129229,0.164179487 28.0472049,5.58210256 C17.0045684,8.57020058 8.37853373,17.1962353 5.39043571,28.2388718 C-0.618533519,63.5374615 -2.94988224,117.322662 5.5546152,151.209308 C8.54271322,162.251944 17.1687479,170.877979 28.2113844,173.866077 C48.0771024,179.284 128.032513,179.284 128.032513,179.284 C128.032513,179.284 207.987923,179.284 227.853641,173.866077 C238.896277,170.877979 247.522312,162.251944 250.51041,151.209308 C256.847738,115.861464 258.801474,62.1091 250.346231,28.0746923 Z"
fill="#FF0000">
</path>
<polygon fill="#FFFFFF" points="102.420513 128.06 168.749025 89.642 102.420513 51.224">
</polygon>
</g>
</svg>

After

Width:  |  Height:  |  Size: 1.0 KiB

View File

@ -0,0 +1,81 @@
from typing import Any, Union
from urllib.parse import parse_qs, urlparse
from youtube_transcript_api import YouTubeTranscriptApi
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
class YouTubeTranscriptTool(BuiltinTool):
    """Fetch the transcript of a YouTube video and render it in one of several formats."""

    def _invoke(
        self, user_id: str, tool_parameters: dict[str, Any]
    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        """
        Invoke the YouTube transcript tool.

        :param user_id: ID of the invoking user (unused; required by the tool interface).
        :param tool_parameters: supports "video_id" (bare ID or URL, required),
            "language", "format" (text/json/pretty/srt/vtt, default "text"),
            "preserve_formatting", "proxy" and "cookies".
        :return: a text message with the formatted transcript, or an error text message.
        """
        try:
            # Extract parameters with defaults
            video_input = tool_parameters["video_id"]
            language = tool_parameters.get("language")
            output_format = tool_parameters.get("format", "text")
            preserve_formatting = tool_parameters.get("preserve_formatting", False)
            proxy = tool_parameters.get("proxy")
            cookies = tool_parameters.get("cookies")

            # Extract video ID from URL if needed
            video_id = self._extract_video_id(video_input)

            # Common kwargs for both API entry points
            kwargs = {"proxies": {"https": proxy} if proxy else None, "cookies": cookies}

            try:
                if language:
                    transcript_list = YouTubeTranscriptApi.list_transcripts(video_id, **kwargs)
                    try:
                        transcript = transcript_list.find_transcript([language])
                    except Exception:
                        # Bug fix: was a bare `except:`, which also swallows
                        # SystemExit/KeyboardInterrupt. If the requested language
                        # is unavailable, fall back to translating from English.
                        transcript = transcript_list.find_transcript(["en"]).translate(language)
                    transcript_data = transcript.fetch()
                else:
                    transcript_data = YouTubeTranscriptApi.get_transcript(
                        video_id, preserve_formatting=preserve_formatting, **kwargs
                    )

                # Map the requested output format onto a library formatter class.
                formatter_class = {
                    "json": "JSONFormatter",
                    "pretty": "PrettyPrintFormatter",
                    "srt": "SRTFormatter",
                    "vtt": "WebVTTFormatter",
                }.get(output_format)

                if formatter_class:
                    from youtube_transcript_api import formatters

                    formatter = getattr(formatters, formatter_class)()
                    formatted_transcript = formatter.format_transcript(transcript_data)
                else:
                    # Default ("text"): join all caption snippets with spaces.
                    formatted_transcript = " ".join(entry["text"] for entry in transcript_data)

                return self.create_text_message(text=formatted_transcript)
            except Exception as e:
                return self.create_text_message(text=f"Error getting transcript: {str(e)}")

        except Exception as e:
            return self.create_text_message(text=f"Error processing request: {str(e)}")

    def _extract_video_id(self, video_input: str) -> str:
        """
        Extract the video ID from a YouTube URL, or return the input unchanged
        if it already looks like a bare video ID.

        :raises ValueError: if a youtube.com URL carries no "v" query parameter.
        """
        if "youtube.com" in video_input or "youtu.be" in video_input:
            # Parse URL
            parsed_url = urlparse(video_input)
            if "youtube.com" in parsed_url.netloc:
                # Bug fix: a youtube.com URL without ?v=... used to raise a bare
                # KeyError; raise a descriptive error the caller reports verbatim.
                video_ids = parse_qs(parsed_url.query).get("v")
                if not video_ids:
                    raise ValueError(f"Could not find a video ID in URL: {video_input}")
                return video_ids[0]
            else:  # youtu.be short link: the ID is the path component
                return parsed_url.path[1:]
        return video_input  # Assume it's already a video ID

View File

@ -0,0 +1,101 @@
identity:
name: free_youtube_transcript
author: Tao Wang
label:
en_US: Free YouTube Transcript API
zh_Hans: 免费获取 YouTube 转录
description:
human:
en_US: Get transcript from a YouTube video for free.
zh_Hans: 免费获取 YouTube 视频的转录文案。
llm: A tool for retrieving transcript from YouTube videos.
parameters:
- name: video_id
type: string
required: true
label:
en_US: Video ID/URL
zh_Hans: 视频ID
human_description:
en_US: Used to define the video from which the transcript will be fetched. You can find the id in the video url. For example - https://www.youtube.com/watch?v=video_id.
zh_Hans: 您要获取哪条视频的转录文案?您可以在视频链接中找到id。例如 - https://www.youtube.com/watch?v=video_id。
llm_description: Used to define the video from which the transcript will be fetched. For example - https://www.youtube.com/watch?v=video_id.
form: llm
- name: language
type: string
required: false
label:
en_US: Language Code
zh_Hans: 语言
human_description:
en_US: Language code (e.g. 'en', 'zh') for the transcript.
zh_Hans: 字幕语言代码(如'en'、'zh')。留空则自动选择。
llm_description: Used to set the language for transcripts.
form: form
- name: format
type: select
required: false
default: text
options:
- value: text
label:
en_US: Plain Text
zh_Hans: 纯文本
- value: json
label:
en_US: JSON Format
zh_Hans: JSON 格式
- value: pretty
label:
en_US: Pretty Print Format
zh_Hans: 美化格式
- value: srt
label:
en_US: SRT Format
zh_Hans: SRT 格式
- value: vtt
label:
en_US: WebVTT Format
zh_Hans: WebVTT 格式
label:
en_US: Output Format
zh_Hans: 输出格式
human_description:
en_US: Format of the transcript output
zh_Hans: 字幕输出格式
llm_description: The format to output the transcript in. Options are text (plain text), json (raw transcript data), srt (SubRip format), or vtt (WebVTT format)
form: form
- name: preserve_formatting
type: boolean
required: false
default: false
label:
en_US: Preserve Formatting
zh_Hans: 保留格式
human_description:
en_US: Keep HTML formatting elements like <i> (italics) and <b> (bold)
zh_Hans: 保留HTML格式元素如<i>(斜体)和<b>(粗体)
llm_description: Whether to preserve HTML formatting elements in the transcript text
form: form
- name: proxy
type: string
required: false
label:
en_US: HTTPS Proxy
zh_Hans: HTTPS 代理
human_description:
en_US: HTTPS proxy URL (e.g. https://user:pass@domain:port)
zh_Hans: HTTPS 代理地址(如 https://user:pass@domain:port
llm_description: HTTPS proxy to use for the request. Format should be https://user:pass@domain:port
form: form
- name: cookies
type: string
required: false
label:
en_US: Cookies File Path
zh_Hans: Cookies 文件路径
human_description:
en_US: Path to cookies.txt file for accessing age-restricted videos
zh_Hans: 用于访问年龄限制视频的 cookies.txt 文件路径
llm_description: Path to a cookies.txt file containing YouTube cookies, needed for accessing age-restricted videos
form: form

Some files were not shown because too many files have changed in this diff Show More