mirror of https://github.com/langgenius/dify.git
merge main
commit 99ffe43e91
@@ -42,6 +42,11 @@ REDIS_SENTINEL_USERNAME=
 REDIS_SENTINEL_PASSWORD=
 REDIS_SENTINEL_SOCKET_TIMEOUT=0.1
+
+# redis Cluster configuration.
+REDIS_USE_CLUSTERS=false
+REDIS_CLUSTERS=
+REDIS_CLUSTERS_PASSWORD=
 
 # PostgreSQL database configuration
 DB_USERNAME=postgres
 DB_PASSWORD=difyai123456
@@ -234,6 +239,10 @@ ANALYTICDB_ACCOUNT=testaccount
 ANALYTICDB_PASSWORD=testpassword
 ANALYTICDB_NAMESPACE=dify
 ANALYTICDB_NAMESPACE_PASSWORD=difypassword
+ANALYTICDB_HOST=gp-test.aliyuncs.com
+ANALYTICDB_PORT=5432
+ANALYTICDB_MIN_CONNECTION=1
+ANALYTICDB_MAX_CONNECTION=5
 
 # OpenSearch configuration
 OPENSEARCH_HOST=127.0.0.1
@@ -589,7 +589,7 @@ def upgrade_db():
             click.echo(click.style("Database migration successful!", fg="green"))
 
         except Exception as e:
-            logging.exception(f"Database migration failed: {e}")
+            logging.exception("Failed to execute database migration")
         finally:
             lock.release()
     else:
@@ -633,7 +633,7 @@ where sites.id is null limit 1000"""
                 except Exception as e:
                     failed_app_ids.append(app_id)
                     click.echo(click.style("Failed to fix missing site for app {}".format(app_id), fg="red"))
-                    logging.exception(f"Fix app related site missing issue failed, error: {e}")
+                    logging.exception(f"Failed to fix app related site missing issue, app_id: {app_id}")
                     continue
 
         if not processed_count:
@@ -616,6 +616,11 @@ class DataSetConfig(BaseSettings):
         default=False,
     )
 
+    PLAN_SANDBOX_CLEAN_MESSAGE_DAY_SETTING: PositiveInt = Field(
+        description="Interval in days for message cleanup operations - plan: sandbox",
+        default=30,
+    )
+
 
 class WorkspaceConfig(BaseSettings):
     """
@@ -68,3 +68,18 @@ class RedisConfig(BaseSettings):
         description="Socket timeout in seconds for Redis Sentinel connections",
         default=0.1,
     )
+
+    REDIS_USE_CLUSTERS: bool = Field(
+        description="Enable Redis Clusters mode for high availability",
+        default=False,
+    )
+
+    REDIS_CLUSTERS: Optional[str] = Field(
+        description="Comma-separated list of Redis Clusters nodes (host:port)",
+        default=None,
+    )
+
+    REDIS_CLUSTERS_PASSWORD: Optional[str] = Field(
+        description="Password for Redis Clusters authentication (if required)",
+        default=None,
+    )
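
These cluster settings map naturally onto redis-py's cluster client. A minimal sketch of how they could be consumed, assuming redis-py >= 4.3 and the "host:port,host:port" format given in the field description (the node values are placeholders, not part of the commit):

from redis.cluster import ClusterNode, RedisCluster

# REDIS_CLUSTERS would supply this comma-separated "host:port" string.
clusters = "127.0.0.1:7000,127.0.0.1:7001"
nodes = [
    ClusterNode(host, int(port))
    for host, port in (item.split(":") for item in clusters.split(","))
]
# REDIS_CLUSTERS_PASSWORD would supply the password, if any.
client = RedisCluster(startup_nodes=nodes, password=None)
client.ping()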
@@ -1,6 +1,6 @@
 from typing import Optional
 
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, PositiveInt
 
 
 class AnalyticdbConfig(BaseModel):
@@ -40,3 +40,11 @@ class AnalyticdbConfig(BaseModel):
         description="The password for accessing the specified namespace within the AnalyticDB instance"
         " (if namespace feature is enabled).",
     )
+    ANALYTICDB_HOST: Optional[str] = Field(
+        default=None, description="The host of the AnalyticDB instance you want to connect to."
+    )
+    ANALYTICDB_PORT: PositiveInt = Field(
+        default=5432, description="The port of the AnalyticDB instance you want to connect to."
+    )
+    ANALYTICDB_MIN_CONNECTION: PositiveInt = Field(default=1, description="Min connection of the AnalyticDB database.")
+    ANALYTICDB_MAX_CONNECTION: PositiveInt = Field(default=5, description="Max connection of the AnalyticDB database.")
@@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
 
     CURRENT_VERSION: str = Field(
         description="Dify version",
-        default="0.11.1",
+        default="0.11.2",
    )
 
     COMMIT_SHA: str = Field(
@@ -9,6 +9,7 @@ from controllers.console.app.wraps import get_app_model
 from controllers.console.wraps import (
     account_initialization_required,
     cloud_edition_billing_resource_check,
+    enterprise_license_required,
     setup_required,
 )
 from core.ops.ops_trace_manager import OpsTraceManager
@@ -28,6 +29,7 @@ class AppListApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
+    @enterprise_license_required
     def get(self):
         """Get app list"""
 
@@ -149,6 +151,7 @@ class AppApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
+    @enterprise_license_required
     @get_app_model
     @marshal_with(app_detail_fields_with_site)
     def get(self, app_model):
@@ -70,7 +70,7 @@ class ChatMessageAudioApi(Resource):
        except ValueError as e:
             raise e
         except Exception as e:
-            logging.exception(f"internal server error, {str(e)}.")
+            logging.exception("Failed to handle post request to ChatMessageAudioApi")
             raise InternalServerError()
 
 
@@ -128,7 +128,7 @@ class ChatMessageTextApi(Resource):
         except ValueError as e:
             raise e
         except Exception as e:
-            logging.exception(f"internal server error, {str(e)}.")
+            logging.exception("Failed to handle post request to ChatMessageTextApi")
             raise InternalServerError()
 
 
@@ -170,7 +170,7 @@ class TextModesApi(Resource):
         except ValueError as e:
             raise e
         except Exception as e:
-            logging.exception(f"internal server error, {str(e)}.")
+            logging.exception("Failed to handle get request to TextModesApi")
             raise InternalServerError()
 
 
@@ -12,7 +12,7 @@ from controllers.console.auth.error import (
     InvalidTokenError,
     PasswordMismatchError,
 )
-from controllers.console.error import EmailSendIpLimitError, NotAllowedRegister
+from controllers.console.error import AccountNotFound, EmailSendIpLimitError
 from controllers.console.wraps import setup_required
 from events.tenant_event import tenant_was_created
 from extensions.ext_database import db
@@ -48,7 +48,7 @@ class ForgotPasswordSendEmailApi(Resource):
                 token = AccountService.send_reset_password_email(email=args["email"], language=language)
                 return {"result": "fail", "data": token, "code": "account_not_found"}
             else:
-                raise NotAllowedRegister()
+                raise AccountNotFound()
         else:
             token = AccountService.send_reset_password_email(account=account, email=args["email"], language=language)
 
@@ -16,9 +16,9 @@ from controllers.console.auth.error import (
 )
 from controllers.console.error import (
     AccountBannedError,
+    AccountNotFound,
     EmailSendIpLimitError,
     NotAllowedCreateWorkspace,
-    NotAllowedRegister,
 )
 from controllers.console.wraps import setup_required
 from events.tenant_event import tenant_was_created
@@ -76,7 +76,7 @@ class LoginApi(Resource):
                 token = AccountService.send_reset_password_email(email=args["email"], language=language)
                 return {"result": "fail", "data": token, "code": "account_not_found"}
             else:
-                raise NotAllowedRegister()
+                raise AccountNotFound()
         # SELF_HOSTED only have one workspace
         tenants = TenantService.get_join_tenants(account)
         if len(tenants) == 0:
@@ -119,7 +119,7 @@ class ResetPasswordSendEmailApi(Resource):
             if FeatureService.get_system_features().is_allow_register:
                 token = AccountService.send_reset_password_email(email=args["email"], language=language)
             else:
-                raise NotAllowedRegister()
+                raise AccountNotFound()
         else:
             token = AccountService.send_reset_password_email(account=account, language=language)
 
@@ -148,7 +148,7 @@ class EmailCodeLoginSendEmailApi(Resource):
             if FeatureService.get_system_features().is_allow_register:
                 token = AccountService.send_email_code_login_email(email=args["email"], language=language)
             else:
-                raise NotAllowedRegister()
+                raise AccountNotFound()
         else:
             token = AccountService.send_email_code_login_email(account=account, language=language)
 
@@ -10,7 +10,7 @@ from controllers.console import api
 from controllers.console.apikey import api_key_fields, api_key_list
 from controllers.console.app.error import ProviderNotInitializeError
 from controllers.console.datasets.error import DatasetInUseError, DatasetNameDuplicateError, IndexingEstimateError
-from controllers.console.wraps import account_initialization_required, setup_required
+from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
 from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
 from core.indexing_runner import IndexingRunner
 from core.model_runtime.entities.model_entities import ModelType
@@ -44,6 +44,7 @@ class DatasetListApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
+    @enterprise_license_required
     def get(self):
         page = request.args.get("page", default=1, type=int)
         limit = request.args.get("limit", default=20, type=int)
@@ -948,7 +948,7 @@ class DocumentRetryApi(DocumentResource):
                     raise DocumentAlreadyFinishedError()
                 retry_documents.append(document)
             except Exception as e:
-                logging.exception(f"Document {document_id} retry failed: {str(e)}")
+                logging.exception(f"Failed to retry document, document id: {document_id}")
                 continue
         # retry document
         DocumentService.retry_document(dataset_id, retry_documents)
@@ -52,8 +52,8 @@ class AccountBannedError(BaseHTTPException):
     code = 400
 
 
-class NotAllowedRegister(BaseHTTPException):
-    error_code = "unauthorized"
+class AccountNotFound(BaseHTTPException):
+    error_code = "account_not_found"
     description = "Account not found."
     code = 400
 
@@ -86,3 +86,9 @@ class NoFileUploadedError(BaseHTTPException):
     error_code = "no_file_uploaded"
     description = "Please upload your file."
     code = 400
+
+
+class UnauthorizedAndForceLogout(BaseHTTPException):
+    error_code = "unauthorized_and_force_logout"
+    description = "Unauthorized and force logout."
+    code = 401
@@ -45,7 +45,7 @@ class RemoteFileUploadApi(Resource):
 
         resp = ssrf_proxy.head(url=url)
         if resp.status_code != httpx.codes.OK:
-            resp = ssrf_proxy.get(url=url, timeout=3)
+            resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True)
             resp.raise_for_status()
 
         file_info = helpers.guess_file_info_from_response(resp)
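
The added follow_redirects=True matters because httpx does not follow redirects by default, so a 301/302 from the remote host would previously surface as the final response and fail raise_for_status(). A minimal illustration (the URL is a placeholder):

import httpx

# Without follow_redirects=True, a redirect response is returned as-is.
resp = httpx.get("https://example.com/file", timeout=3, follow_redirects=True)
resp.raise_for_status()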
@@ -14,7 +14,7 @@ from controllers.console.workspace.error import (
     InvalidInvitationCodeError,
     RepeatPasswordNotMatchError,
 )
-from controllers.console.wraps import account_initialization_required, setup_required
+from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
 from extensions.ext_database import db
 from fields.member_fields import account_fields
 from libs.helper import TimestampField, timezone
@@ -79,6 +79,7 @@ class AccountProfileApi(Resource):
     @login_required
     @account_initialization_required
     @marshal_with(account_fields)
+    @enterprise_license_required
     def get(self):
         return current_user
 
@@ -1,3 +1,5 @@
+from urllib import parse
+
 from flask_login import current_user
 from flask_restful import Resource, abort, marshal_with, reqparse
 
@@ -57,11 +59,12 @@ class MemberInviteEmailApi(Resource):
                 token = RegisterService.invite_new_member(
                     inviter.current_tenant, invitee_email, interface_language, role=invitee_role, inviter=inviter
                 )
+                encoded_invitee_email = parse.quote(invitee_email)
                 invitation_results.append(
                     {
                         "status": "success",
                         "email": invitee_email,
-                        "url": f"{console_web_url}/activate?email={invitee_email}&token={token}",
+                        "url": f"{console_web_url}/activate?email={encoded_invitee_email}&token={token}",
                     }
                 )
             except AccountAlreadyInTenantError:
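
The quoting fixes activation links for addresses containing query-significant characters: "+" would otherwise be decoded as a space in the query string, and "@" is reserved. A quick illustration:

from urllib import parse

parse.quote("user+tag@example.com")  # 'user%2Btag%40example.com'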
@@ -72,7 +72,10 @@ class DefaultModelApi(Resource):
                     model=model_setting["model"],
                 )
             except Exception as ex:
-                logging.exception(f"{model_setting['model_type']} save error: {ex}")
+                logging.exception(
+                    f"Failed to update default model, model type: {model_setting['model_type']},"
+                    f" model:{model_setting.get('model')}"
+                )
                 raise ex
 
         return {"result": "success"}
@@ -156,7 +159,10 @@ class ModelProviderModelApi(Resource):
                 credentials=args["credentials"],
             )
         except CredentialsValidateFailedError as ex:
-            logging.exception(f"save model credentials error: {ex}")
+            logging.exception(
+                f"Failed to save model credentials, tenant_id: {tenant_id},"
+                f" model: {args.get('model')}, model_type: {args.get('model_type')}"
+            )
             raise ValueError(str(ex))
 
         return {"result": "success"}, 200
@@ -7,7 +7,7 @@ from werkzeug.exceptions import Forbidden
 
 from configs import dify_config
 from controllers.console import api
-from controllers.console.wraps import account_initialization_required, setup_required
+from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
 from core.model_runtime.utils.encoders import jsonable_encoder
 from libs.helper import alphanumeric, uuid_value
 from libs.login import login_required
@@ -549,6 +549,7 @@ class ToolLabelsApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
+    @enterprise_license_required
     def get(self):
         return jsonable_encoder(ToolLabelsService.list_tool_labels())
 
@@ -8,10 +8,10 @@ from flask_login import current_user
 from configs import dify_config
 from controllers.console.workspace.error import AccountNotInitializedError
 from models.model import DifySetup
-from services.feature_service import FeatureService
+from services.feature_service import FeatureService, LicenseStatus
 from services.operation_service import OperationService
 
-from .error import NotInitValidateError, NotSetupError
+from .error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogout
 
 
 def account_initialization_required(view):
@@ -142,3 +142,15 @@ def setup_required(view):
         return view(*args, **kwargs)
 
     return decorated
+
+
+def enterprise_license_required(view):
+    @wraps(view)
+    def decorated(*args, **kwargs):
+        settings = FeatureService.get_system_features()
+        if settings.license.status in [LicenseStatus.INACTIVE, LicenseStatus.EXPIRED, LicenseStatus.LOST]:
+            raise UnauthorizedAndForceLogout("Your license is invalid. Please contact your administrator.")
+
+        return view(*args, **kwargs)
+
+    return decorated
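
The new decorator composes with the existing wrappers. A hypothetical resource showing the stacking order used by the call sites added elsewhere in this commit (the resource itself is illustrative, not part of the commit):

class ExampleApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @enterprise_license_required
    def get(self):
        # Only reached when the license is not inactive, expired, or lost.
        return {"result": "success"}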
@@ -59,7 +59,7 @@ class AudioApi(WebApiResource):
         except ValueError as e:
             raise e
         except Exception as e:
-            logging.exception(f"internal server error: {str(e)}")
+            logging.exception("Failed to handle post request to AudioApi")
             raise InternalServerError()
 
 
@@ -117,7 +117,7 @@ class TextApi(WebApiResource):
         except ValueError as e:
             raise e
         except Exception as e:
-            logging.exception(f"internal server error: {str(e)}")
+            logging.exception("Failed to handle post request to TextApi")
             raise InternalServerError()
 
 
@@ -16,9 +16,7 @@ class FileUploadConfigManager:
         file_upload_dict = config.get("file_upload")
         if file_upload_dict:
             if file_upload_dict.get("enabled"):
-                transform_methods = file_upload_dict.get("allowed_file_upload_methods") or file_upload_dict.get(
-                    "allowed_upload_methods", []
-                )
+                transform_methods = file_upload_dict.get("allowed_file_upload_methods", [])
                 data = {
                     "image_config": {
                         "number_limits": file_upload_dict["number_limits"],
@@ -362,5 +362,5 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
             if e.args[0] == "I/O operation on closed file.":  # ignore this error
                 raise GenerateTaskStoppedError()
             else:
-                logger.exception(e)
+                logger.exception(f"Failed to process generate task pipeline, conversation_id: {conversation.id}")
                 raise e
@@ -242,7 +242,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                     start_listener_time = time.time()
                     yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id)
             except Exception as e:
-                logger.exception(e)
+                logger.exception(f"Failed to listen audio message, task_id: {task_id}")
                 break
         if tts_publisher:
             yield MessageAudioEndStreamResponse(audio="", task_id=task_id)
@@ -33,8 +33,8 @@ class BaseAppGenerator:
                 tenant_id=app_config.tenant_id,
                 config=FileUploadConfig(
                     allowed_file_types=entity_dictionary[k].allowed_file_types,
-                    allowed_extensions=entity_dictionary[k].allowed_file_extensions,
-                    allowed_upload_methods=entity_dictionary[k].allowed_file_upload_methods,
+                    allowed_file_extensions=entity_dictionary[k].allowed_file_extensions,
+                    allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods,
                 ),
             )
             for k, v in user_inputs.items()
@@ -47,8 +47,8 @@ class BaseAppGenerator:
                 tenant_id=app_config.tenant_id,
                 config=FileUploadConfig(
                     allowed_file_types=entity_dictionary[k].allowed_file_types,
-                    allowed_extensions=entity_dictionary[k].allowed_file_extensions,
-                    allowed_upload_methods=entity_dictionary[k].allowed_file_upload_methods,
+                    allowed_file_extensions=entity_dictionary[k].allowed_file_extensions,
+                    allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods,
                 ),
             )
             for k, v in user_inputs.items()
@@ -91,6 +91,9 @@ class BaseAppGenerator:
         )
 
         if variable_entity.type == VariableEntityType.NUMBER and isinstance(value, str):
+            # handle empty string case
+            if not value.strip():
+                return None
             # may raise ValueError if user_input_value is not a valid number
             try:
                 if "." in value:
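
The new branch short-circuits empty strings before the numeric conversion that follows. A standalone sketch of the same conversion rule:

def parse_number(value: str) -> int | float | None:
    # Empty strings become None; a decimal point selects float;
    # ValueError propagates for non-numeric input, as the comment above notes.
    if not value.strip():
        return None
    return float(value) if "." in value else int(value)

parse_number("3.14")  # 3.14
parse_number("42")    # 42
parse_number("  ")    # None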
@@ -80,7 +80,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
             if e.args[0] == "I/O operation on closed file.":  # ignore this error
                 raise GenerateTaskStoppedError()
             else:
-                logger.exception(e)
+                logger.exception(f"Failed to handle response, conversation_id: {conversation.id}")
                 raise e
 
     def _get_conversation_by_user(
@@ -298,5 +298,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
             if e.args[0] == "I/O operation on closed file.":  # ignore this error
                 raise GenerateTaskStoppedError()
             else:
-                logger.exception(e)
+                logger.exception(
+                    f"Fails to process generate task pipeline, task_id: {application_generate_entity.task_id}"
+                )
                 raise e
@@ -216,7 +216,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
                 else:
                     yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id)
             except Exception as e:
-                logger.exception(e)
+                logger.exception(f"Fails to get audio trunk, task_id: {task_id}")
                 break
         if tts_publisher:
             yield MessageAudioEndStreamResponse(audio="", task_id=task_id)
@@ -86,7 +86,7 @@ class MessageCycleManage:
                 conversation.name = name
             except Exception as e:
                 if dify_config.DEBUG:
-                    logging.exception(f"generate conversation name failed: {e}")
+                    logging.exception(f"generate conversation name failed, conversation_id: {conversation_id}")
                 pass
 
             db.session.merge(conversation)
@@ -28,8 +28,8 @@ class FileUploadConfig(BaseModel):
 
     image_config: Optional[ImageConfig] = None
     allowed_file_types: Sequence[FileType] = Field(default_factory=list)
-    allowed_extensions: Sequence[str] = Field(default_factory=list)
-    allowed_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list)
+    allowed_file_extensions: Sequence[str] = Field(default_factory=list)
+    allowed_file_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list)
     number_limits: int = 0
 
@@ -41,7 +41,7 @@ def check_moderation(model_config: ModelConfigWithCredentialsEntity, text: str)
             if moderation_result is True:
                 return True
     except Exception as ex:
-        logger.exception(ex)
+        logger.exception(f"Fails to check moderation, provider_name: {provider_name}")
         raise InvokeBadRequestError("Rate limit exceeded, please try again later.")
 
     return False
@@ -29,7 +29,7 @@ def import_module_from_source(*, module_name: str, py_file_path: AnyStr, use_laz
         spec.loader.exec_module(module)
         return module
     except Exception as e:
-        logging.exception(f"Failed to load module {module_name} from {py_file_path}: {str(e)}")
+        logging.exception(f"Failed to load module {module_name} from script file '{py_file_path}'")
         raise e
 
@@ -39,6 +39,7 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
         )
 
     retries = 0
+    stream = kwargs.pop("stream", False)
     while retries <= max_retries:
         try:
             if dify_config.SSRF_PROXY_ALL_URL:
@@ -52,6 +53,8 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
                 response = client.request(method=method, url=url, **kwargs)
 
             if response.status_code not in STATUS_FORCELIST:
+                if stream:
+                    return response.iter_bytes()
                 return response
             else:
                 logging.warning(
@@ -29,6 +29,7 @@ from core.rag.splitter.fixed_text_splitter import (
     FixedRecursiveCharacterTextSplitter,
 )
 from core.rag.splitter.text_splitter import TextSplitter
+from core.tools.utils.text_processing_utils import remove_leading_symbols
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from extensions.ext_storage import storage
@@ -500,11 +501,7 @@ class IndexingRunner:
                     document_node.metadata["doc_hash"] = hash
                     # delete Splitter character
                     page_content = document_node.page_content
-                    if page_content.startswith(".") or page_content.startswith("。"):
-                        page_content = page_content[1:]
-                    else:
-                        page_content = page_content
-                    document_node.page_content = page_content
+                    document_node.page_content = remove_leading_symbols(page_content)
 
                     if document_node.page_content:
                         split_documents.append(document_node)
@@ -554,7 +551,7 @@ class IndexingRunner:
                         qa_documents.append(qa_document)
                     format_documents.extend(qa_documents)
                 except Exception as e:
-                    logging.exception(e)
+                    logging.exception("Failed to format qa document")
 
                 all_qa_documents.extend(format_documents)
 
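
The ad-hoc leading-dot check is replaced by the shared remove_leading_symbols helper. A minimal sketch of what such a helper plausibly does; the exact character set is an assumption, not the actual implementation in core/tools/utils/text_processing_utils.py:

import re

def remove_leading_symbols_sketch(text: str) -> str:
    # Assumed behavior: strip leading whitespace and ASCII/full-width
    # punctuation left behind by the splitter; the real helper may differ.
    return re.sub(r"^[\s.。,，!！?？;；:：]+", "", text)

remove_leading_symbols_sketch("。Hello")  # 'Hello'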
|
|
@ -102,7 +102,7 @@ class LLMGenerator:
|
|||
except InvokeError:
|
||||
questions = []
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
logging.exception("Failed to generate suggested questions after answer")
|
||||
questions = []
|
||||
|
||||
return questions
|
||||
|
|
@ -148,7 +148,7 @@ class LLMGenerator:
|
|||
error = str(e)
|
||||
error_step = "generate rule config"
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
logging.exception(f"Failed to generate rule config, model: {model_config.get('name')}")
|
||||
rule_config["error"] = str(e)
|
||||
|
||||
rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else ""
|
||||
|
|
@ -234,7 +234,7 @@ class LLMGenerator:
|
|||
error_step = "generate conversation opener"
|
||||
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
logging.exception(f"Failed to generate rule config, model: {model_config.get('name')}")
|
||||
rule_config["error"] = str(e)
|
||||
|
||||
rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else ""
|
||||
|
|
@ -286,7 +286,9 @@ class LLMGenerator:
|
|||
error = str(e)
|
||||
return {"code": "", "language": code_language, "error": f"Failed to generate code. Error: {error}"}
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
logging.exception(
|
||||
f"Failed to invoke LLM model, model: {model_config.get('name')}, language: {code_language}"
|
||||
)
|
||||
return {"code": "", "language": code_language, "error": f"An unexpected error occurred: {str(e)}"}
|
||||
|
||||
@classmethod
|
||||
|
|
|
|||
|
|
@@ -325,14 +325,13 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
                 assistant_prompt_message.tool_calls.append(tool_call)
 
-        # calculate num tokens
-        if response.usage:
-            # transform usage
-            prompt_tokens = response.usage.input_tokens
-            completion_tokens = response.usage.output_tokens
-        else:
-            # calculate num tokens
-            prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
-            completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
+        prompt_tokens = (response.usage and response.usage.input_tokens) or self.get_num_tokens(
+            model, credentials, prompt_messages
+        )
+
+        completion_tokens = (response.usage and response.usage.output_tokens) or self.get_num_tokens(
+            model, credentials, [assistant_prompt_message]
+        )
 
         # transform usage
         usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
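
One caveat of the new "or" fallback worth noting: a reported token count of 0 is falsy, so it would fall through to the local estimate rather than being used directly. A tiny illustration of the pattern:

usage_tokens = 0
# 0 is treated like "missing" and triggers the fallback estimate.
prompt_tokens = usage_tokens or 987  # -> 987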
@@ -103,7 +103,7 @@ class AzureRerankModel(RerankModel):
             return RerankResult(model=model, docs=rerank_documents)
 
         except Exception as e:
-            logger.exception(f"Exception in Azure rerank: {e}")
+            logger.exception(f"Failed to invoke rerank model, model: {model}")
             raise
 
     def validate_credentials(self, model: str, credentials: dict) -> None:
@@ -2,13 +2,11 @@
 import base64
 import json
 import logging
-import mimetypes
 from collections.abc import Generator
 from typing import Optional, Union, cast
 
 # 3rd import
 import boto3
-import requests
 from botocore.config import Config
 from botocore.exceptions import (
     ClientError,
@@ -439,22 +437,10 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
                     sub_messages.append(sub_message_dict)
                 elif message_content.type == PromptMessageContentType.IMAGE:
                     message_content = cast(ImagePromptMessageContent, message_content)
-                    if not message_content.data.startswith("data:"):
-                        # fetch image data from url
-                        try:
-                            url = message_content.data
-                            image_content = requests.get(url).content
-                            if "?" in url:
-                                url = url.split("?")[0]
-                            mime_type, _ = mimetypes.guess_type(url)
-                            base64_data = base64.b64encode(image_content).decode("utf-8")
-                        except Exception as ex:
-                            raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
-                    else:
-                        data_split = message_content.data.split(";base64,")
-                        mime_type = data_split[0].replace("data:", "")
-                        base64_data = data_split[1]
-                        image_content = base64.b64decode(base64_data)
+                    data_split = message_content.data.split(";base64,")
+                    mime_type = data_split[0].replace("data:", "")
+                    base64_data = data_split[1]
+                    image_content = base64.b64decode(base64_data)
 
                     if mime_type not in {"image/jpeg", "image/png", "image/gif", "image/webp"}:
                         raise ValueError(
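
The simplified branch assumes image content always arrives as a data URL. How the split recovers the MIME type and payload (the payload here is a placeholder):

import base64

data_url = "data:image/png;base64,iVBORw0KGgo="
head, base64_data = data_url.split(";base64,")
mime_type = head.replace("data:", "")        # 'image/png'
image_content = base64.b64decode(base64_data)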
@@ -11,5 +11,6 @@
 - gemini-1.5-flash-exp-0827
 - gemini-1.5-flash-8b-exp-0827
 - gemini-1.5-flash-8b-exp-0924
+- gemini-exp-1114
 - gemini-pro
 - gemini-pro-vision
@@ -24,14 +24,13 @@ parameter_rules:
       zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
       en_US: Only sample from the top K options for each subsequent token.
     required: false
-  - name: max_tokens_to_sample
+  - name: max_output_tokens
     use_template: max_tokens
-    required: true
     default: 8192
     min: 1
     max: 8192
-  - name: response_format
-    use_template: response_format
+  - name: json_schema
+    use_template: json_schema
 pricing:
   input: '0.00'
   output: '0.00'
@@ -24,14 +24,13 @@ parameter_rules:
       zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
     required: false
-  - name: max_tokens_to_sample
+  - name: max_output_tokens
     use_template: max_tokens
-    required: true
     default: 8192
     min: 1
     max: 8192
-  - name: response_format
-    use_template: response_format
+  - name: json_schema
+    use_template: json_schema
 pricing:
   input: '0.00'
   output: '0.00'
@@ -24,14 +24,13 @@ parameter_rules:
       zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
       en_US: Only sample from the top K options for each subsequent token.
     required: false
-  - name: max_tokens_to_sample
+  - name: max_output_tokens
     use_template: max_tokens
-    required: true
     default: 8192
     min: 1
     max: 8192
-  - name: response_format
-    use_template: response_format
+  - name: json_schema
+    use_template: json_schema
 pricing:
   input: '0.00'
   output: '0.00'
@@ -24,14 +24,13 @@ parameter_rules:
       zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
       en_US: Only sample from the top K options for each subsequent token.
     required: false
-  - name: max_tokens_to_sample
+  - name: max_output_tokens
     use_template: max_tokens
-    required: true
     default: 8192
     min: 1
     max: 8192
-  - name: response_format
-    use_template: response_format
+  - name: json_schema
+    use_template: json_schema
 pricing:
   input: '0.00'
   output: '0.00'
@@ -24,14 +24,13 @@ parameter_rules:
       zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
       en_US: Only sample from the top K options for each subsequent token.
     required: false
-  - name: max_tokens_to_sample
+  - name: max_output_tokens
     use_template: max_tokens
-    required: true
     default: 8192
     min: 1
     max: 8192
-  - name: response_format
-    use_template: response_format
+  - name: json_schema
+    use_template: json_schema
 pricing:
   input: '0.00'
   output: '0.00'
@@ -24,14 +24,13 @@ parameter_rules:
       zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
       en_US: Only sample from the top K options for each subsequent token.
     required: false
-  - name: max_tokens_to_sample
+  - name: max_output_tokens
     use_template: max_tokens
-    required: true
     default: 8192
     min: 1
     max: 8192
-  - name: response_format
-    use_template: response_format
+  - name: json_schema
+    use_template: json_schema
 pricing:
   input: '0.00'
   output: '0.00'
@@ -24,14 +24,13 @@ parameter_rules:
       zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
       en_US: Only sample from the top K options for each subsequent token.
     required: false
-  - name: max_tokens_to_sample
+  - name: max_output_tokens
     use_template: max_tokens
-    required: true
     default: 8192
     min: 1
     max: 8192
-  - name: response_format
-    use_template: response_format
+  - name: json_schema
+    use_template: json_schema
 pricing:
   input: '0.00'
   output: '0.00'
@@ -24,14 +24,13 @@ parameter_rules:
       zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
       en_US: Only sample from the top K options for each subsequent token.
     required: false
-  - name: max_tokens_to_sample
+  - name: max_output_tokens
     use_template: max_tokens
-    required: true
     default: 8192
     min: 1
     max: 8192
-  - name: response_format
-    use_template: response_format
+  - name: json_schema
+    use_template: json_schema
 pricing:
   input: '0.00'
   output: '0.00'
@@ -24,14 +24,13 @@ parameter_rules:
       zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
       en_US: Only sample from the top K options for each subsequent token.
     required: false
-  - name: max_tokens_to_sample
+  - name: max_output_tokens
     use_template: max_tokens
-    required: true
     default: 8192
     min: 1
     max: 8192
-  - name: response_format
-    use_template: response_format
+  - name: json_schema
+    use_template: json_schema
 pricing:
   input: '0.00'
   output: '0.00'
@@ -24,14 +24,13 @@ parameter_rules:
       zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
       en_US: Only sample from the top K options for each subsequent token.
     required: false
-  - name: max_tokens_to_sample
+  - name: max_output_tokens
     use_template: max_tokens
-    required: true
     default: 8192
     min: 1
     max: 8192
-  - name: response_format
-    use_template: response_format
+  - name: json_schema
+    use_template: json_schema
 pricing:
   input: '0.00'
   output: '0.00'
@@ -24,14 +24,13 @@ parameter_rules:
       zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
       en_US: Only sample from the top K options for each subsequent token.
     required: false
-  - name: max_tokens_to_sample
+  - name: max_output_tokens
     use_template: max_tokens
-    required: true
     default: 8192
     min: 1
     max: 8192
-  - name: response_format
-    use_template: response_format
+  - name: json_schema
+    use_template: json_schema
 pricing:
   input: '0.00'
   output: '0.00'
@@ -24,14 +24,13 @@ parameter_rules:
       zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
       en_US: Only sample from the top K options for each subsequent token.
     required: false
-  - name: max_tokens_to_sample
+  - name: max_output_tokens
     use_template: max_tokens
-    required: true
     default: 8192
     min: 1
     max: 8192
-  - name: response_format
-    use_template: response_format
+  - name: json_schema
+    use_template: json_schema
 pricing:
   input: '0.00'
   output: '0.00'
@@ -24,14 +24,13 @@ parameter_rules:
       zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
       en_US: Only sample from the top K options for each subsequent token.
     required: false
-  - name: max_tokens_to_sample
+  - name: max_output_tokens
     use_template: max_tokens
-    required: true
     default: 8192
     min: 1
     max: 8192
-  - name: response_format
-    use_template: response_format
+  - name: json_schema
+    use_template: json_schema
 pricing:
   input: '0.00'
   output: '0.00'
@@ -0,0 +1,38 @@
+model: gemini-exp-1114
+label:
+  en_US: Gemini exp 1114
+model_type: llm
+features:
+  - agent-thought
+  - vision
+  - tool-call
+  - stream-tool-call
+model_properties:
+  mode: chat
+  context_size: 2097152
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+  - name: top_p
+    use_template: top_p
+  - name: top_k
+    label:
+      zh_Hans: 取样数量
+      en_US: Top k
+    type: int
+    help:
+      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+      en_US: Only sample from the top K options for each subsequent token.
+    required: false
+  - name: max_output_tokens
+    use_template: max_tokens
+    default: 8192
+    min: 1
+    max: 8192
+  - name: json_schema
+    use_template: json_schema
+pricing:
+  input: '0.00'
+  output: '0.00'
+  unit: '0.000001'
+  currency: USD
@@ -32,3 +32,4 @@ pricing:
   output: '0.00'
   unit: '0.000001'
   currency: USD
+deprecated: true
@@ -36,3 +36,4 @@ pricing:
   output: '0.00'
   unit: '0.000001'
   currency: USD
+deprecated: true
@@ -1,7 +1,6 @@
 import base64
-import io
 import json
 import logging
 from collections.abc import Generator
 from typing import Optional, Union, cast
 
@@ -36,17 +35,6 @@ from core.model_runtime.errors.invoke import (
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 
 logger = logging.getLogger(__name__)
 
-GEMINI_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
-The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure
-if you are not sure about the structure.
-
-<instructions>
-{{instructions}}
-</instructions>
-"""  # noqa: E501
-
-
 class GoogleLargeLanguageModel(LargeLanguageModel):
     def _invoke(
@@ -155,7 +143,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
 
         try:
             ping_message = SystemPromptMessage(content="ping")
-            self._generate(model, credentials, [ping_message], {"max_tokens_to_sample": 5})
+            self._generate(model, credentials, [ping_message], {"max_output_tokens": 5})
 
         except Exception as ex:
             raise CredentialsValidateFailedError(str(ex))
@@ -184,7 +172,15 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
         :return: full response or stream response chunk generator result
         """
         config_kwargs = model_parameters.copy()
-        config_kwargs["max_output_tokens"] = config_kwargs.pop("max_tokens_to_sample", None)
+        if schema := config_kwargs.pop("json_schema", None):
+            try:
+                schema = json.loads(schema)
+            except:
+                raise exceptions.InvalidArgument("Invalid JSON Schema")
+            if tools:
+                raise exceptions.InvalidArgument("gemini not support use Tools and JSON Schema at same time")
+            config_kwargs["response_schema"] = schema
+            config_kwargs["response_mime_type"] = "application/json"
 
         if stop:
             config_kwargs["stop_sequences"] = stop
@@ -22,6 +22,7 @@ from core.model_runtime.entities.message_entities import (
     PromptMessageTool,
     SystemPromptMessage,
     TextPromptMessageContent,
+    ToolPromptMessage,
     UserPromptMessage,
 )
 from core.model_runtime.entities.model_entities import (
@@ -86,6 +87,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
             credentials=credentials,
             prompt_messages=prompt_messages,
             model_parameters=model_parameters,
+            tools=tools,
             stop=stop,
             stream=stream,
             user=user,
@@ -153,6 +155,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
         credentials: dict,
         prompt_messages: list[PromptMessage],
         model_parameters: dict,
+        tools: Optional[list[PromptMessageTool]] = None,
         stop: Optional[list[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
@@ -196,6 +199,8 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
         if completion_type is LLMMode.CHAT:
             endpoint_url = urljoin(endpoint_url, "api/chat")
             data["messages"] = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]
+            if tools:
+                data["tools"] = [self._convert_prompt_message_tool_to_dict(tool) for tool in tools]
         else:
             endpoint_url = urljoin(endpoint_url, "api/generate")
             first_prompt_message = prompt_messages[0]
@@ -232,7 +237,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
         if stream:
             return self._handle_generate_stream_response(model, credentials, completion_type, response, prompt_messages)
 
-        return self._handle_generate_response(model, credentials, completion_type, response, prompt_messages)
+        return self._handle_generate_response(model, credentials, completion_type, response, prompt_messages, tools)
 
     def _handle_generate_response(
         self,
@@ -241,6 +246,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
         completion_type: LLMMode,
         response: requests.Response,
         prompt_messages: list[PromptMessage],
+        tools: Optional[list[PromptMessageTool]],
     ) -> LLMResult:
         """
         Handle llm completion response
@@ -253,14 +259,16 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
         :return: llm result
         """
         response_json = response.json()
 
+        tool_calls = []
         if completion_type is LLMMode.CHAT:
             message = response_json.get("message", {})
             response_content = message.get("content", "")
+            response_tool_calls = message.get("tool_calls", [])
+            tool_calls = [self._extract_response_tool_call(tool_call) for tool_call in response_tool_calls]
         else:
             response_content = response_json["response"]
 
-        assistant_message = AssistantPromptMessage(content=response_content)
+        assistant_message = AssistantPromptMessage(content=response_content, tool_calls=tool_calls)
 
         if "prompt_eval_count" in response_json and "eval_count" in response_json:
             # transform usage
@@ -405,9 +413,28 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
 
             chunk_index += 1
 
+    def _convert_prompt_message_tool_to_dict(self, tool: PromptMessageTool) -> dict:
+        """
+        Convert PromptMessageTool to dict for Ollama API
+
+        :param tool: tool
+        :return: tool dict
+        """
+        return {
+            "type": "function",
+            "function": {
+                "name": tool.name,
+                "description": tool.description,
+                "parameters": tool.parameters,
+            },
+        }
+
     def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
         """
         Convert PromptMessage to dict for Ollama API
 
         :param message: prompt message
         :return: message dict
         """
         if isinstance(message, UserPromptMessage):
             message = cast(UserPromptMessage, message)
@@ -432,6 +459,9 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
         elif isinstance(message, SystemPromptMessage):
             message = cast(SystemPromptMessage, message)
             message_dict = {"role": "system", "content": message.content}
+        elif isinstance(message, ToolPromptMessage):
+            message = cast(ToolPromptMessage, message)
+            message_dict = {"role": "tool", "content": message.content}
         else:
             raise ValueError(f"Got unknown type {message}")
 
@@ -452,6 +482,29 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
 
         return num_tokens
 
+    def _extract_response_tool_call(self, response_tool_call: dict) -> AssistantPromptMessage.ToolCall:
+        """
+        Extract response tool call
+        """
+        tool_call = None
+        if response_tool_call and "function" in response_tool_call:
+            # Convert arguments to JSON string if it's a dict
+            arguments = response_tool_call.get("function").get("arguments")
+            if isinstance(arguments, dict):
+                arguments = json.dumps(arguments)
+
+            function = AssistantPromptMessage.ToolCall.ToolCallFunction(
+                name=response_tool_call.get("function").get("name"),
+                arguments=arguments,
+            )
+            tool_call = AssistantPromptMessage.ToolCall(
+                id=response_tool_call.get("function").get("name"),
+                type="function",
+                function=function,
+            )
+
+        return tool_call
+
     def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
         """
         Get customizable model schema.
@@ -461,10 +514,15 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
 
         :return: model schema
         """
-        extras = {}
+        extras = {
+            "features": [],
+        }
 
         if "vision_support" in credentials and credentials["vision_support"] == "true":
-            extras["features"] = [ModelFeature.VISION]
+            extras["features"].append(ModelFeature.VISION)
+        if "function_call_support" in credentials and credentials["function_call_support"] == "true":
+            extras["features"].append(ModelFeature.TOOL_CALL)
+            extras["features"].append(ModelFeature.MULTI_TOOL_CALL)
 
         entity = AIModelEntity(
             model=model,
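
For context on what _extract_response_tool_call consumes: Ollama's /api/chat endpoint returns tool calls as nested function objects whose arguments field is a dict rather than a JSON string, which is why the method serializes it. A small sketch of the same extraction on a made-up response:

import json

response_tool_call = {
    "function": {
        "name": "get_weather",
        "arguments": {"city": "Berlin"},
    }
}

function = response_tool_call["function"]
arguments = function["arguments"]
if isinstance(arguments, dict):  # Ollama sends a dict, not a string
    arguments = json.dumps(arguments)

print(function["name"], arguments)  # get_weather {"city": "Berlin"}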
@@ -96,3 +96,22 @@ model_credential_schema:
         label:
           en_US: 'No'
           zh_Hans: 否
+  - variable: function_call_support
+    label:
+      zh_Hans: 是否支持函数调用
+      en_US: Function call support
+    show_on:
+      - variable: __model_type
+        value: llm
+    default: 'false'
+    type: radio
+    required: false
+    options:
+      - value: 'true'
+        label:
+          en_US: 'Yes'
+          zh_Hans: 是
+      - value: 'false'
+        label:
+          en_US: 'No'
+          zh_Hans: 否
@@ -615,19 +615,11 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
         prompt_messages = self._clear_illegal_prompt_messages(model, prompt_messages)
 
         # o1 compatibility
-        block_as_stream = False
         if model.startswith("o1"):
             if "max_tokens" in model_parameters:
                 model_parameters["max_completion_tokens"] = model_parameters["max_tokens"]
                 del model_parameters["max_tokens"]
 
-            if stream:
-                block_as_stream = True
-                stream = False
-
-                if "stream_options" in extra_model_kwargs:
-                    del extra_model_kwargs["stream_options"]
-
             if "stop" in extra_model_kwargs:
                 del extra_model_kwargs["stop"]
 
@@ -644,47 +636,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
         if stream:
             return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, tools)
 
-        block_result = self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)
-
-        if block_as_stream:
-            return self._handle_chat_block_as_stream_response(block_result, prompt_messages, stop)
-
-        return block_result
-
-    def _handle_chat_block_as_stream_response(
-        self,
-        block_result: LLMResult,
-        prompt_messages: list[PromptMessage],
-        stop: Optional[list[str]] = None,
-    ) -> Generator[LLMResultChunk, None, None]:
-        """
-        Handle llm chat response
-
-        :param model: model name
-        :param credentials: credentials
-        :param response: response
-        :param prompt_messages: prompt messages
-        :param tools: tools for tool calling
-        :param stop: stop words
-        :return: llm response chunk generator
-        """
-        text = block_result.message.content
-        text = cast(str, text)
-
-        if stop:
-            text = self.enforce_stop_tokens(text, stop)
-
-        yield LLMResultChunk(
-            model=block_result.model,
-            prompt_messages=prompt_messages,
-            system_fingerprint=block_result.system_fingerprint,
-            delta=LLMResultChunkDelta(
-                index=0,
-                message=AssistantPromptMessage(content=text),
-                finish_reason="stop",
-                usage=block_result.usage,
-            ),
-        )
+        return self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)
 
     def _handle_chat_generate_response(
         self,
@@ -45,19 +45,7 @@ class OpenRouterLargeLanguageModel(OAIAPICompatLargeLanguageModel):
         user: Optional[str] = None,
     ) -> Union[LLMResult, Generator]:
         self._update_credential(model, credentials)
 
-        block_as_stream = False
-        if model.startswith("openai/o1"):
-            block_as_stream = True
-            stop = None
-
-        # invoke block as stream
-        if stream and block_as_stream:
-            return self._generate_block_as_stream(
-                model, credentials, prompt_messages, model_parameters, tools, stop, user
-            )
-        else:
-            return super()._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
+        return super()._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
 
     def _generate_block_as_stream(
         self,
@@ -69,9 +57,7 @@ class OpenRouterLargeLanguageModel(OAIAPICompatLargeLanguageModel):
         stop: Optional[list[str]] = None,
         user: Optional[str] = None,
     ) -> Generator:
-        resp: LLMResult = super()._generate(
-            model, credentials, prompt_messages, model_parameters, tools, stop, False, user
-        )
+        resp = super()._generate(model, credentials, prompt_messages, model_parameters, tools, stop, False, user)
 
         yield LLMResultChunk(
             model=model,
@@ -113,7 +113,7 @@ class SageMakerRerankModel(RerankModel):
             return RerankResult(model=model, docs=rerank_documents)
 
         except Exception as e:
-            logger.exception(f"Exception {e}, line : {line}")
+            logger.exception(f"Failed to invoke rerank model, model: {model}")
 
     def validate_credentials(self, model: str, credentials: dict) -> None:
         """
@@ -78,7 +78,7 @@ class SageMakerSpeech2TextModel(Speech2TextModel):
             json_obj = json.loads(json_str)
             asr_text = json_obj["text"]
         except Exception as e:
-            logger.exception(f"failed to invoke speech2text model, {e}")
+            logger.exception(f"failed to invoke speech2text model, model: {model}")
             raise CredentialsValidateFailedError(str(e))
 
         return asr_text
@@ -117,7 +117,7 @@ class SageMakerEmbeddingModel(TextEmbeddingModel):
             return TextEmbeddingResult(embeddings=all_embeddings, usage=usage, model=model)
 
         except Exception as e:
-            logger.exception(f"Exception {e}, line : {line}")
+            logger.exception(f"Failed to invoke text embedding model, model: {model}, line: {line}")
 
     def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
         """
@@ -65,6 +65,8 @@ class GTERerankModel(RerankModel):
         )
 
         rerank_documents = []
+        if not response.output:
+            return RerankResult(model=model, docs=rerank_documents)
         for _, result in enumerate(response.output.results):
             # format document
             rerank_document = RerankDocument(
@@ -1,3 +1,6 @@
+from collections.abc import Sequence
+from typing import Any
+
 from core.moderation.base import Moderation, ModerationAction, ModerationInputsResult, ModerationOutputsResult
 
 
@@ -62,5 +65,5 @@ class KeywordsModeration(Moderation):
     def _is_violated(self, inputs: dict, keywords_list: list) -> bool:
         return any(self._check_keywords_in_value(keywords_list, value) for value in inputs.values())
 
-    def _check_keywords_in_value(self, keywords_list, value) -> bool:
-        return any(keyword.lower() in value.lower() for keyword in keywords_list)
+    def _check_keywords_in_value(self, keywords_list: Sequence[str], value: Any) -> bool:
+        return any(keyword.lower() in str(value).lower() for keyword in keywords_list)
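
The str(value) coercion is what makes the check safe for non-string inputs, which would previously raise AttributeError on .lower(). A quick illustration:

keywords = ["forbidden"]
value = 12345  # an int input no longer crashes the check
any(k.lower() in str(value).lower() for k in keywords)  # False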
@@ -126,6 +126,6 @@ class OutputModeration(BaseModel):
             result: ModerationOutputsResult = moderation_factory.moderation_for_outputs(moderation_buffer)
             return result
         except Exception as e:
-            logger.exception("Moderation Output error: %s", e)
+            logger.exception(f"Moderation Output error, app_id: {app_id}")
 
         return None
@@ -49,6 +49,7 @@ class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel):
     reference_example_id: Optional[str] = Field(None, description="Reference example ID associated with the run")
     input_attachments: Optional[dict[str, Any]] = Field(None, description="Input attachments of the run")
     output_attachments: Optional[dict[str, Any]] = Field(None, description="Output attachments of the run")
+    dotted_order: Optional[str] = Field(None, description="Dotted order of the run")
 
     @field_validator("inputs", "outputs")
     @classmethod
@@ -25,7 +25,7 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import (
     LangSmithRunType,
     LangSmithRunUpdateModel,
 )
-from core.ops.utils import filter_none_values
+from core.ops.utils import filter_none_values, generate_dotted_order
 from extensions.ext_database import db
 from models.model import EndUser, MessageFile
 from models.workflow import WorkflowNodeExecution
@@ -62,6 +62,16 @@ class LangSmithDataTrace(BaseTraceInstance):
             self.generate_name_trace(trace_info)
 
     def workflow_trace(self, trace_info: WorkflowTraceInfo):
+        trace_id = trace_info.message_id or trace_info.workflow_app_log_id or trace_info.workflow_run_id
+        message_dotted_order = (
+            generate_dotted_order(trace_info.message_id, trace_info.start_time) if trace_info.message_id else None
+        )
+        workflow_dotted_order = generate_dotted_order(
+            trace_info.workflow_app_log_id or trace_info.workflow_run_id,
+            trace_info.workflow_data.created_at,
+            message_dotted_order,
+        )
+
         if trace_info.message_id:
             message_run = LangSmithRunModel(
                 id=trace_info.message_id,
@@ -76,6 +86,8 @@ class LangSmithDataTrace(BaseTraceInstance):
                 },
                 tags=["message", "workflow"],
                 error=trace_info.error,
+                trace_id=trace_id,
+                dotted_order=message_dotted_order,
             )
             self.add_run(message_run)
 
@@ -95,6 +107,8 @@ class LangSmithDataTrace(BaseTraceInstance):
             error=trace_info.error,
             tags=["workflow"],
             parent_run_id=trace_info.message_id or None,
+            trace_id=trace_id,
+            dotted_order=workflow_dotted_order,
         )
 
         self.add_run(langsmith_run)
@@ -177,6 +191,7 @@ class LangSmithDataTrace(BaseTraceInstance):
             else:
                 run_type = LangSmithRunType.tool
 
+            node_dotted_order = generate_dotted_order(node_execution_id, created_at, workflow_dotted_order)
             langsmith_run = LangSmithRunModel(
                 total_tokens=node_total_tokens,
                 name=node_type,
@@ -191,6 +206,9 @@ class LangSmithDataTrace(BaseTraceInstance):
                 },
                 parent_run_id=trace_info.workflow_app_log_id or trace_info.workflow_run_id,
                 tags=["node_execution"],
+                id=node_execution_id,
+                trace_id=trace_id,
+                dotted_order=node_dotted_order,
             )
 
             self.add_run(langsmith_run)
@ -711,7 +711,7 @@ class TraceQueueManager:
|
|||
trace_task.app_id = self.app_id
|
||||
trace_manager_queue.put(trace_task)
|
||||
except Exception as e:
|
||||
logging.exception(f"Error adding trace task: {e}")
|
||||
logging.exception(f"Error adding trace task, trace_type {trace_task.trace_type}")
|
||||
finally:
|
||||
self.start_timer()
|
||||
|
||||
|
|
@@ -730,7 +730,7 @@ class TraceQueueManager:
            if tasks:
                self.send_to_celery(tasks)
        except Exception as e:
            logging.exception(f"Error processing trace tasks: {e}")
            logging.exception("Error processing trace tasks")

    def start_timer(self):
        global trace_manager_timer
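A recurring cleanup in this commit: `logging.exception` already records the active exception and its traceback, so interpolating `{e}` into the message only duplicates information. A minimal, self-contained sketch of the pattern:

```python
import logging

try:
    1 / 0  # stand-in for the traced work
except Exception:
    # logging.exception logs at ERROR level and appends the active
    # traceback (including the exception text) automatically, so the
    # message only needs to carry the extra context.
    logging.exception("Error processing trace tasks")
```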
@@ -1,5 +1,6 @@
from contextlib import contextmanager
from datetime import datetime
from typing import Optional, Union

from extensions.ext_database import db
from models.model import Message
@@ -43,3 +44,19 @@ def replace_text_with_content(data):
        return [replace_text_with_content(item) for item in data]
    else:
        return data


def generate_dotted_order(
    run_id: str, start_time: Union[str, datetime], parent_dotted_order: Optional[str] = None
) -> str:
    """
    Generate a dotted_order string for LangSmith.
    """
    start_time = datetime.fromisoformat(start_time) if isinstance(start_time, str) else start_time
    timestamp = start_time.strftime("%Y%m%dT%H%M%S%f")[:-3] + "Z"
    current_segment = f"{timestamp}{run_id}"

    if parent_dotted_order is None:
        return current_segment

    return f"{parent_dotted_order}.{current_segment}"
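A quick illustration of how this helper chains segments into LangSmith's parent.child ordering (the run IDs and timestamps below are invented for the example):

```python
from datetime import datetime

message_order = generate_dotted_order("msg-1", datetime(2024, 1, 1, 12, 0, 0))
# -> "20240101T120000000Zmsg-1"

workflow_order = generate_dotted_order("wf-2", datetime(2024, 1, 1, 12, 0, 1), message_order)
# -> "20240101T120000000Zmsg-1.20240101T120001000Zwf-2"
```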
@@ -1,310 +1,62 @@
import json
from typing import Any

from pydantic import BaseModel

_import_err_msg = (
    "`alibabacloud_gpdb20160503` and `alibabacloud_tea_openapi` packages not found, "
    "please run `pip install alibabacloud_gpdb20160503 alibabacloud_tea_openapi`"
)

from configs import dify_config
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi import (
    AnalyticdbVectorOpenAPI,
    AnalyticdbVectorOpenAPIConfig,
)
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySql, AnalyticdbVectorBySqlConfig
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset


class AnalyticdbConfig(BaseModel):
    access_key_id: str
    access_key_secret: str
    region_id: str
    instance_id: str
    account: str
    account_password: str
    namespace: str = ("dify",)
    namespace_password: str = (None,)
    metrics: str = ("cosine",)
    read_timeout: int = 60000

    def to_analyticdb_client_params(self):
        return {
            "access_key_id": self.access_key_id,
            "access_key_secret": self.access_key_secret,
            "region_id": self.region_id,
            "read_timeout": self.read_timeout,
        }


class AnalyticdbVector(BaseVector):
    def __init__(self, collection_name: str, config: AnalyticdbConfig):
        self._collection_name = collection_name.lower()
        try:
            from alibabacloud_gpdb20160503.client import Client
            from alibabacloud_tea_openapi import models as open_api_models
        except:
            raise ImportError(_import_err_msg)
        self.config = config
        self._client_config = open_api_models.Config(user_agent="dify", **config.to_analyticdb_client_params())
        self._client = Client(self._client_config)
        self._initialize()

    def _initialize(self) -> None:
        cache_key = f"vector_indexing_{self.config.instance_id}"
        lock_name = f"{cache_key}_lock"
        with redis_client.lock(lock_name, timeout=20):
            collection_exist_cache_key = f"vector_indexing_{self.config.instance_id}"
            if redis_client.get(collection_exist_cache_key):
                return
            self._initialize_vector_database()
            self._create_namespace_if_not_exists()
            redis_client.set(collection_exist_cache_key, 1, ex=3600)

    def _initialize_vector_database(self) -> None:
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models

        request = gpdb_20160503_models.InitVectorDatabaseRequest(
            dbinstance_id=self.config.instance_id,
            region_id=self.config.region_id,
            manager_account=self.config.account,
            manager_account_password=self.config.account_password,
        )
        self._client.init_vector_database(request)

    def _create_namespace_if_not_exists(self) -> None:
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
        from Tea.exceptions import TeaException

        try:
            request = gpdb_20160503_models.DescribeNamespaceRequest(
                dbinstance_id=self.config.instance_id,
                region_id=self.config.region_id,
                namespace=self.config.namespace,
                manager_account=self.config.account,
                manager_account_password=self.config.account_password,
            )
            self._client.describe_namespace(request)
        except TeaException as e:
            if e.statusCode == 404:
                request = gpdb_20160503_models.CreateNamespaceRequest(
                    dbinstance_id=self.config.instance_id,
                    region_id=self.config.region_id,
                    manager_account=self.config.account,
                    manager_account_password=self.config.account_password,
                    namespace=self.config.namespace,
                    namespace_password=self.config.namespace_password,
                )
                self._client.create_namespace(request)
            else:
                raise ValueError(f"failed to create namespace {self.config.namespace}: {e}")

    def _create_collection_if_not_exists(self, embedding_dimension: int):
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
        from Tea.exceptions import TeaException

        cache_key = f"vector_indexing_{self._collection_name}"
        lock_name = f"{cache_key}_lock"
        with redis_client.lock(lock_name, timeout=20):
            collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
            if redis_client.get(collection_exist_cache_key):
                return
            try:
                request = gpdb_20160503_models.DescribeCollectionRequest(
                    dbinstance_id=self.config.instance_id,
                    region_id=self.config.region_id,
                    namespace=self.config.namespace,
                    namespace_password=self.config.namespace_password,
                    collection=self._collection_name,
                )
                self._client.describe_collection(request)
            except TeaException as e:
                if e.statusCode == 404:
                    metadata = '{"ref_doc_id":"text","page_content":"text","metadata_":"jsonb"}'
                    full_text_retrieval_fields = "page_content"
                    request = gpdb_20160503_models.CreateCollectionRequest(
                        dbinstance_id=self.config.instance_id,
                        region_id=self.config.region_id,
                        manager_account=self.config.account,
                        manager_account_password=self.config.account_password,
                        namespace=self.config.namespace,
                        collection=self._collection_name,
                        dimension=embedding_dimension,
                        metrics=self.config.metrics,
                        metadata=metadata,
                        full_text_retrieval_fields=full_text_retrieval_fields,
                    )
                    self._client.create_collection(request)
                else:
                    raise ValueError(f"failed to create collection {self._collection_name}: {e}")
            redis_client.set(collection_exist_cache_key, 1, ex=3600)
    def __init__(
        self, collection_name: str, api_config: AnalyticdbVectorOpenAPIConfig, sql_config: AnalyticdbVectorBySqlConfig
    ):
        super().__init__(collection_name)
        if api_config is not None:
            self.analyticdb_vector = AnalyticdbVectorOpenAPI(collection_name, api_config)
        else:
            self.analyticdb_vector = AnalyticdbVectorBySql(collection_name, sql_config)

    def get_type(self) -> str:
        return VectorType.ANALYTICDB

    def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
        dimension = len(embeddings[0])
        self._create_collection_if_not_exists(dimension)
        self.add_texts(texts, embeddings)
        self.analyticdb_vector._create_collection_if_not_exists(dimension)
        self.analyticdb_vector.add_texts(texts, embeddings)

    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models

        rows: list[gpdb_20160503_models.UpsertCollectionDataRequestRows] = []
        for doc, embedding in zip(documents, embeddings, strict=True):
            metadata = {
                "ref_doc_id": doc.metadata["doc_id"],
                "page_content": doc.page_content,
                "metadata_": json.dumps(doc.metadata),
            }
            rows.append(
                gpdb_20160503_models.UpsertCollectionDataRequestRows(
                    vector=embedding,
                    metadata=metadata,
                )
            )
        request = gpdb_20160503_models.UpsertCollectionDataRequest(
            dbinstance_id=self.config.instance_id,
            region_id=self.config.region_id,
            namespace=self.config.namespace,
            namespace_password=self.config.namespace_password,
            collection=self._collection_name,
            rows=rows,
        )
        self._client.upsert_collection_data(request)
    def add_texts(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
        self.analyticdb_vector.add_texts(texts, embeddings)

    def text_exists(self, id: str) -> bool:
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models

        request = gpdb_20160503_models.QueryCollectionDataRequest(
            dbinstance_id=self.config.instance_id,
            region_id=self.config.region_id,
            namespace=self.config.namespace,
            namespace_password=self.config.namespace_password,
            collection=self._collection_name,
            metrics=self.config.metrics,
            include_values=True,
            vector=None,
            content=None,
            top_k=1,
            filter=f"ref_doc_id='{id}'",
        )
        response = self._client.query_collection_data(request)
        return len(response.body.matches.match) > 0
        return self.analyticdb_vector.text_exists(id)

    def delete_by_ids(self, ids: list[str]) -> None:
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models

        ids_str = ",".join(f"'{id}'" for id in ids)
        ids_str = f"({ids_str})"
        request = gpdb_20160503_models.DeleteCollectionDataRequest(
            dbinstance_id=self.config.instance_id,
            region_id=self.config.region_id,
            namespace=self.config.namespace,
            namespace_password=self.config.namespace_password,
            collection=self._collection_name,
            collection_data=None,
            collection_data_filter=f"ref_doc_id IN {ids_str}",
        )
        self._client.delete_collection_data(request)
        self.analyticdb_vector.delete_by_ids(ids)

    def delete_by_metadata_field(self, key: str, value: str) -> None:
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models

        request = gpdb_20160503_models.DeleteCollectionDataRequest(
            dbinstance_id=self.config.instance_id,
            region_id=self.config.region_id,
            namespace=self.config.namespace,
            namespace_password=self.config.namespace_password,
            collection=self._collection_name,
            collection_data=None,
            collection_data_filter=f"metadata_ ->> '{key}' = '{value}'",
        )
        self._client.delete_collection_data(request)
        self.analyticdb_vector.delete_by_metadata_field(key, value)

    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models

        score_threshold = kwargs.get("score_threshold") or 0.0
        request = gpdb_20160503_models.QueryCollectionDataRequest(
            dbinstance_id=self.config.instance_id,
            region_id=self.config.region_id,
            namespace=self.config.namespace,
            namespace_password=self.config.namespace_password,
            collection=self._collection_name,
            include_values=kwargs.pop("include_values", True),
            metrics=self.config.metrics,
            vector=query_vector,
            content=None,
            top_k=kwargs.get("top_k", 4),
            filter=None,
        )
        response = self._client.query_collection_data(request)
        documents = []
        for match in response.body.matches.match:
            if match.score > score_threshold:
                metadata = json.loads(match.metadata.get("metadata_"))
                metadata["score"] = match.score
                doc = Document(
                    page_content=match.metadata.get("page_content"),
                    metadata=metadata,
                )
                documents.append(doc)
        documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True)
        return documents
        return self.analyticdb_vector.search_by_vector(query_vector)

    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models

        score_threshold = float(kwargs.get("score_threshold") or 0.0)
        request = gpdb_20160503_models.QueryCollectionDataRequest(
            dbinstance_id=self.config.instance_id,
            region_id=self.config.region_id,
            namespace=self.config.namespace,
            namespace_password=self.config.namespace_password,
            collection=self._collection_name,
            include_values=kwargs.pop("include_values", True),
            metrics=self.config.metrics,
            vector=None,
            content=query,
            top_k=kwargs.get("top_k", 4),
            filter=None,
        )
        response = self._client.query_collection_data(request)
        documents = []
        for match in response.body.matches.match:
            if match.score > score_threshold:
                metadata = json.loads(match.metadata.get("metadata_"))
                metadata["score"] = match.score
                doc = Document(
                    page_content=match.metadata.get("page_content"),
                    vector=match.metadata.get("vector"),
                    metadata=metadata,
                )
                documents.append(doc)
        documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True)
        return documents
        return self.analyticdb_vector.search_by_full_text(query, **kwargs)

    def delete(self) -> None:
        try:
            from alibabacloud_gpdb20160503 import models as gpdb_20160503_models

            request = gpdb_20160503_models.DeleteCollectionRequest(
                collection=self._collection_name,
                dbinstance_id=self.config.instance_id,
                namespace=self.config.namespace,
                namespace_password=self.config.namespace_password,
                region_id=self.config.region_id,
            )
            self._client.delete_collection(request)
        except Exception as e:
            raise e
        self.analyticdb_vector.delete()


class AnalyticdbVectorFactory(AbstractVectorFactory):
    def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings):
    def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> AnalyticdbVector:
        if dataset.index_struct_dict:
            class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
            collection_name = class_prefix.lower()
@@ -313,26 +65,9 @@ class AnalyticdbVectorFactory(AbstractVectorFactory):
            collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
            dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.ANALYTICDB, collection_name))

        # handle optional params
        if dify_config.ANALYTICDB_KEY_ID is None:
            raise ValueError("ANALYTICDB_KEY_ID should not be None")
        if dify_config.ANALYTICDB_KEY_SECRET is None:
            raise ValueError("ANALYTICDB_KEY_SECRET should not be None")
        if dify_config.ANALYTICDB_REGION_ID is None:
            raise ValueError("ANALYTICDB_REGION_ID should not be None")
        if dify_config.ANALYTICDB_INSTANCE_ID is None:
            raise ValueError("ANALYTICDB_INSTANCE_ID should not be None")
        if dify_config.ANALYTICDB_ACCOUNT is None:
            raise ValueError("ANALYTICDB_ACCOUNT should not be None")
        if dify_config.ANALYTICDB_PASSWORD is None:
            raise ValueError("ANALYTICDB_PASSWORD should not be None")
        if dify_config.ANALYTICDB_NAMESPACE is None:
            raise ValueError("ANALYTICDB_NAMESPACE should not be None")
        if dify_config.ANALYTICDB_NAMESPACE_PASSWORD is None:
            raise ValueError("ANALYTICDB_NAMESPACE_PASSWORD should not be None")
        return AnalyticdbVector(
            collection_name,
            AnalyticdbConfig(
        if dify_config.ANALYTICDB_HOST is None:
            # implemented through OpenAPI
            apiConfig = AnalyticdbVectorOpenAPIConfig(
                access_key_id=dify_config.ANALYTICDB_KEY_ID,
                access_key_secret=dify_config.ANALYTICDB_KEY_SECRET,
                region_id=dify_config.ANALYTICDB_REGION_ID,
@@ -341,5 +76,22 @@ class AnalyticdbVectorFactory(AbstractVectorFactory):
                account_password=dify_config.ANALYTICDB_PASSWORD,
                namespace=dify_config.ANALYTICDB_NAMESPACE,
                namespace_password=dify_config.ANALYTICDB_NAMESPACE_PASSWORD,
            ),
        )
            sqlConfig = None
        else:
            # implemented through sql
            sqlConfig = AnalyticdbVectorBySqlConfig(
                host=dify_config.ANALYTICDB_HOST,
                port=dify_config.ANALYTICDB_PORT,
                account=dify_config.ANALYTICDB_ACCOUNT,
                account_password=dify_config.ANALYTICDB_PASSWORD,
                min_connection=dify_config.ANALYTICDB_MIN_CONNECTION,
                max_connection=dify_config.ANALYTICDB_MAX_CONNECTION,
                namespace=dify_config.ANALYTICDB_NAMESPACE,
            )
            apiConfig = None
        return AnalyticdbVector(
            collection_name,
            apiConfig,
            sqlConfig,
        )
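The factory now picks the backend from configuration: with no ANALYTICDB_HOST set, it talks to the managed OpenAPI endpoint; otherwise it connects directly over SQL. A rough sketch of constructing the wrapper in each mode (the config objects here are placeholders, not values from this commit):

```python
# my_api_config / my_sql_config are hypothetical, pre-built config instances.
api_vector = AnalyticdbVector("collection_a", api_config=my_api_config, sql_config=None)
sql_vector = AnalyticdbVector("collection_b", api_config=None, sql_config=my_sql_config)
```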
@@ -0,0 +1,309 @@
import json
from typing import Any, Optional

from pydantic import BaseModel, model_validator

_import_err_msg = (
    "`alibabacloud_gpdb20160503` and `alibabacloud_tea_openapi` packages not found, "
    "please run `pip install alibabacloud_gpdb20160503 alibabacloud_tea_openapi`"
)

from core.rag.models.document import Document
from extensions.ext_redis import redis_client


class AnalyticdbVectorOpenAPIConfig(BaseModel):
    access_key_id: str
    access_key_secret: str
    region_id: str
    instance_id: str
    account: str
    account_password: str
    namespace: str = "dify"
    namespace_password: Optional[str] = None
    metrics: str = "cosine"
    read_timeout: int = 60000

    @model_validator(mode="before")
    @classmethod
    def validate_config(cls, values: dict) -> dict:
        if not values["access_key_id"]:
            raise ValueError("config ANALYTICDB_KEY_ID is required")
        if not values["access_key_secret"]:
            raise ValueError("config ANALYTICDB_KEY_SECRET is required")
        if not values["region_id"]:
            raise ValueError("config ANALYTICDB_REGION_ID is required")
        if not values["instance_id"]:
            raise ValueError("config ANALYTICDB_INSTANCE_ID is required")
        if not values["account"]:
            raise ValueError("config ANALYTICDB_ACCOUNT is required")
        if not values["account_password"]:
            raise ValueError("config ANALYTICDB_PASSWORD is required")
        if not values["namespace_password"]:
            raise ValueError("config ANALYTICDB_NAMESPACE_PASSWORD is required")
        return values

    def to_analyticdb_client_params(self):
        return {
            "access_key_id": self.access_key_id,
            "access_key_secret": self.access_key_secret,
            "region_id": self.region_id,
            "read_timeout": self.read_timeout,
        }


class AnalyticdbVectorOpenAPI:
    def __init__(self, collection_name: str, config: AnalyticdbVectorOpenAPIConfig):
        try:
            from alibabacloud_gpdb20160503.client import Client
            from alibabacloud_tea_openapi import models as open_api_models
        except ImportError:
            raise ImportError(_import_err_msg)
        self._collection_name = collection_name.lower()
        self.config = config
        self._client_config = open_api_models.Config(user_agent="dify", **config.to_analyticdb_client_params())
        self._client = Client(self._client_config)
        self._initialize()

    def _initialize(self) -> None:
        cache_key = f"vector_initialize_{self.config.instance_id}"
        lock_name = f"{cache_key}_lock"
        with redis_client.lock(lock_name, timeout=20):
            database_exist_cache_key = f"vector_initialize_{self.config.instance_id}"
            if redis_client.get(database_exist_cache_key):
                return
            self._initialize_vector_database()
            self._create_namespace_if_not_exists()
            redis_client.set(database_exist_cache_key, 1, ex=3600)

    def _initialize_vector_database(self) -> None:
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models

        request = gpdb_20160503_models.InitVectorDatabaseRequest(
            dbinstance_id=self.config.instance_id,
            region_id=self.config.region_id,
            manager_account=self.config.account,
            manager_account_password=self.config.account_password,
        )
        self._client.init_vector_database(request)

    def _create_namespace_if_not_exists(self) -> None:
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
        from Tea.exceptions import TeaException

        try:
            request = gpdb_20160503_models.DescribeNamespaceRequest(
                dbinstance_id=self.config.instance_id,
                region_id=self.config.region_id,
                namespace=self.config.namespace,
                manager_account=self.config.account,
                manager_account_password=self.config.account_password,
            )
            self._client.describe_namespace(request)
        except TeaException as e:
            if e.statusCode == 404:
                request = gpdb_20160503_models.CreateNamespaceRequest(
                    dbinstance_id=self.config.instance_id,
                    region_id=self.config.region_id,
                    manager_account=self.config.account,
                    manager_account_password=self.config.account_password,
                    namespace=self.config.namespace,
                    namespace_password=self.config.namespace_password,
                )
                self._client.create_namespace(request)
            else:
                raise ValueError(f"failed to create namespace {self.config.namespace}: {e}")

    def _create_collection_if_not_exists(self, embedding_dimension: int):
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
        from Tea.exceptions import TeaException

        cache_key = f"vector_indexing_{self._collection_name}"
        lock_name = f"{cache_key}_lock"
        with redis_client.lock(lock_name, timeout=20):
            collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
            if redis_client.get(collection_exist_cache_key):
                return
            try:
                request = gpdb_20160503_models.DescribeCollectionRequest(
                    dbinstance_id=self.config.instance_id,
                    region_id=self.config.region_id,
                    namespace=self.config.namespace,
                    namespace_password=self.config.namespace_password,
                    collection=self._collection_name,
                )
                self._client.describe_collection(request)
            except TeaException as e:
                if e.statusCode == 404:
                    metadata = '{"ref_doc_id":"text","page_content":"text","metadata_":"jsonb"}'
                    full_text_retrieval_fields = "page_content"
                    request = gpdb_20160503_models.CreateCollectionRequest(
                        dbinstance_id=self.config.instance_id,
                        region_id=self.config.region_id,
                        manager_account=self.config.account,
                        manager_account_password=self.config.account_password,
                        namespace=self.config.namespace,
                        collection=self._collection_name,
                        dimension=embedding_dimension,
                        metrics=self.config.metrics,
                        metadata=metadata,
                        full_text_retrieval_fields=full_text_retrieval_fields,
                    )
                    self._client.create_collection(request)
                else:
                    raise ValueError(f"failed to create collection {self._collection_name}: {e}")
            redis_client.set(collection_exist_cache_key, 1, ex=3600)

    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models

        rows: list[gpdb_20160503_models.UpsertCollectionDataRequestRows] = []
        for doc, embedding in zip(documents, embeddings, strict=True):
            metadata = {
                "ref_doc_id": doc.metadata["doc_id"],
                "page_content": doc.page_content,
                "metadata_": json.dumps(doc.metadata),
            }
            rows.append(
                gpdb_20160503_models.UpsertCollectionDataRequestRows(
                    vector=embedding,
                    metadata=metadata,
                )
            )
        request = gpdb_20160503_models.UpsertCollectionDataRequest(
            dbinstance_id=self.config.instance_id,
            region_id=self.config.region_id,
            namespace=self.config.namespace,
            namespace_password=self.config.namespace_password,
            collection=self._collection_name,
            rows=rows,
        )
        self._client.upsert_collection_data(request)

    def text_exists(self, id: str) -> bool:
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models

        request = gpdb_20160503_models.QueryCollectionDataRequest(
            dbinstance_id=self.config.instance_id,
            region_id=self.config.region_id,
            namespace=self.config.namespace,
            namespace_password=self.config.namespace_password,
            collection=self._collection_name,
            metrics=self.config.metrics,
            include_values=True,
            vector=None,
            content=None,
            top_k=1,
            filter=f"ref_doc_id='{id}'",
        )
        response = self._client.query_collection_data(request)
        return len(response.body.matches.match) > 0

    def delete_by_ids(self, ids: list[str]) -> None:
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models

        ids_str = ",".join(f"'{id}'" for id in ids)
        ids_str = f"({ids_str})"
        request = gpdb_20160503_models.DeleteCollectionDataRequest(
            dbinstance_id=self.config.instance_id,
            region_id=self.config.region_id,
            namespace=self.config.namespace,
            namespace_password=self.config.namespace_password,
            collection=self._collection_name,
            collection_data=None,
            collection_data_filter=f"ref_doc_id IN {ids_str}",
        )
        self._client.delete_collection_data(request)

    def delete_by_metadata_field(self, key: str, value: str) -> None:
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models

        request = gpdb_20160503_models.DeleteCollectionDataRequest(
            dbinstance_id=self.config.instance_id,
            region_id=self.config.region_id,
            namespace=self.config.namespace,
            namespace_password=self.config.namespace_password,
            collection=self._collection_name,
            collection_data=None,
            collection_data_filter=f"metadata_ ->> '{key}' = '{value}'",
        )
        self._client.delete_collection_data(request)

    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models

        score_threshold = kwargs.get("score_threshold") or 0.0
        request = gpdb_20160503_models.QueryCollectionDataRequest(
            dbinstance_id=self.config.instance_id,
            region_id=self.config.region_id,
            namespace=self.config.namespace,
            namespace_password=self.config.namespace_password,
            collection=self._collection_name,
            include_values=kwargs.pop("include_values", True),
            metrics=self.config.metrics,
            vector=query_vector,
            content=None,
            top_k=kwargs.get("top_k", 4),
            filter=None,
        )
        response = self._client.query_collection_data(request)
        documents = []
        for match in response.body.matches.match:
            if match.score > score_threshold:
                metadata = json.loads(match.metadata.get("metadata_"))
                metadata["score"] = match.score
                doc = Document(
                    page_content=match.metadata.get("page_content"),
                    vector=match.values.value,
                    metadata=metadata,
                )
                documents.append(doc)
        documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True)
        return documents

    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models

        score_threshold = float(kwargs.get("score_threshold") or 0.0)
        request = gpdb_20160503_models.QueryCollectionDataRequest(
            dbinstance_id=self.config.instance_id,
            region_id=self.config.region_id,
            namespace=self.config.namespace,
            namespace_password=self.config.namespace_password,
            collection=self._collection_name,
            include_values=kwargs.pop("include_values", True),
            metrics=self.config.metrics,
            vector=None,
            content=query,
            top_k=kwargs.get("top_k", 4),
            filter=None,
        )
        response = self._client.query_collection_data(request)
        documents = []
        for match in response.body.matches.match:
            if match.score > score_threshold:
                metadata = json.loads(match.metadata.get("metadata_"))
                metadata["score"] = match.score
                doc = Document(
                    page_content=match.metadata.get("page_content"),
                    vector=match.values.value,
                    metadata=metadata,
                )
                documents.append(doc)
        documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True)
        return documents

    def delete(self) -> None:
        try:
            from alibabacloud_gpdb20160503 import models as gpdb_20160503_models

            request = gpdb_20160503_models.DeleteCollectionRequest(
                collection=self._collection_name,
                dbinstance_id=self.config.instance_id,
                namespace=self.config.namespace,
                namespace_password=self.config.namespace_password,
                region_id=self.config.region_id,
            )
            self._client.delete_collection(request)
        except Exception as e:
            raise e
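Both new backends guard their one-time setup with the same idiom: a Redis lock to serialize concurrent workers, plus a cached flag so later calls return early. The shape of the pattern, reduced to its essentials (function and key names invented for illustration):

```python
def initialize_once(redis_client, resource_id, do_setup):
    """Run do_setup at most once per hour per resource, across processes."""
    cache_key = f"vector_initialize_{resource_id}"
    with redis_client.lock(f"{cache_key}_lock", timeout=20):
        if redis_client.get(cache_key):
            return  # another worker already initialized this resource
        do_setup()
        redis_client.set(cache_key, 1, ex=3600)  # flag expires after an hour
```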
@@ -0,0 +1,245 @@
import json
import uuid
from contextlib import contextmanager
from typing import Any

import psycopg2.extras
import psycopg2.pool
from pydantic import BaseModel, model_validator

from core.rag.models.document import Document
from extensions.ext_redis import redis_client


class AnalyticdbVectorBySqlConfig(BaseModel):
    host: str
    port: int
    account: str
    account_password: str
    min_connection: int
    max_connection: int
    namespace: str = "dify"
    metrics: str = "cosine"

    @model_validator(mode="before")
    @classmethod
    def validate_config(cls, values: dict) -> dict:
        if not values["host"]:
            raise ValueError("config ANALYTICDB_HOST is required")
        if not values["port"]:
            raise ValueError("config ANALYTICDB_PORT is required")
        if not values["account"]:
            raise ValueError("config ANALYTICDB_ACCOUNT is required")
        if not values["account_password"]:
            raise ValueError("config ANALYTICDB_PASSWORD is required")
        if not values["min_connection"]:
            raise ValueError("config ANALYTICDB_MIN_CONNECTION is required")
        if not values["max_connection"]:
            raise ValueError("config ANALYTICDB_MAX_CONNECTION is required")
        if values["min_connection"] > values["max_connection"]:
            raise ValueError("config ANALYTICDB_MIN_CONNECTION must not be greater than ANALYTICDB_MAX_CONNECTION")
        return values


class AnalyticdbVectorBySql:
    def __init__(self, collection_name: str, config: AnalyticdbVectorBySqlConfig):
        self._collection_name = collection_name.lower()
        self.databaseName = "knowledgebase"
        self.config = config
        self.table_name = f"{self.config.namespace}.{self._collection_name}"
        self.pool = None
        self._initialize()
        if not self.pool:
            self.pool = self._create_connection_pool()

    def _initialize(self) -> None:
        cache_key = f"vector_initialize_{self.config.host}"
        lock_name = f"{cache_key}_lock"
        with redis_client.lock(lock_name, timeout=20):
            database_exist_cache_key = f"vector_initialize_{self.config.host}"
            if redis_client.get(database_exist_cache_key):
                return
            self._initialize_vector_database()
            redis_client.set(database_exist_cache_key, 1, ex=3600)

    def _create_connection_pool(self):
        return psycopg2.pool.SimpleConnectionPool(
            self.config.min_connection,
            self.config.max_connection,
            host=self.config.host,
            port=self.config.port,
            user=self.config.account,
            password=self.config.account_password,
            database=self.databaseName,
        )

    @contextmanager
    def _get_cursor(self):
        conn = self.pool.getconn()
        cur = conn.cursor()
        try:
            yield cur
        finally:
            cur.close()
            conn.commit()
            self.pool.putconn(conn)

    def _initialize_vector_database(self) -> None:
        conn = psycopg2.connect(
            host=self.config.host,
            port=self.config.port,
            user=self.config.account,
            password=self.config.account_password,
            database="postgres",
        )
        conn.autocommit = True
        cur = conn.cursor()
        try:
            cur.execute(f"CREATE DATABASE {self.databaseName}")
        except Exception as e:
            if "already exists" in str(e):
                return
            raise e
        finally:
            cur.close()
            conn.close()
        self.pool = self._create_connection_pool()
        with self._get_cursor() as cur:
            try:
                cur.execute("CREATE TEXT SEARCH CONFIGURATION zh_cn (PARSER = zhparser)")
                cur.execute("ALTER TEXT SEARCH CONFIGURATION zh_cn ADD MAPPING FOR n,v,a,i,e,l,x WITH simple")
            except Exception as e:
                if "already exists" not in str(e):
                    raise e
            cur.execute(
                "CREATE OR REPLACE FUNCTION "
                "public.to_tsquery_from_text(txt text, lang regconfig DEFAULT 'english'::regconfig) "
                "RETURNS tsquery LANGUAGE sql IMMUTABLE STRICT AS $function$ "
                "SELECT to_tsquery(lang, COALESCE(string_agg(split_part(word, ':', 1), ' | '), '')) "
                "FROM (SELECT unnest(string_to_array(to_tsvector(lang, txt)::text, ' ')) AS word) "
                "AS words_only;$function$"
            )
            cur.execute(f"CREATE SCHEMA IF NOT EXISTS {self.config.namespace}")

    def _create_collection_if_not_exists(self, embedding_dimension: int):
        cache_key = f"vector_indexing_{self._collection_name}"
        lock_name = f"{cache_key}_lock"
        with redis_client.lock(lock_name, timeout=20):
            collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
            if redis_client.get(collection_exist_cache_key):
                return
            with self._get_cursor() as cur:
                cur.execute(
                    f"CREATE TABLE IF NOT EXISTS {self.table_name}("
                    f"id text PRIMARY KEY,"
                    f"vector real[], ref_doc_id text, page_content text, metadata_ jsonb, "
                    f"to_tsvector TSVECTOR"
                    f") WITH (fillfactor=70) DISTRIBUTED BY (id);"
                )
                if embedding_dimension is not None:
                    index_name = f"{self._collection_name}_embedding_idx"
                    cur.execute(f"ALTER TABLE {self.table_name} ALTER COLUMN vector SET STORAGE PLAIN")
                    cur.execute(
                        f"CREATE INDEX {index_name} ON {self.table_name} USING ann(vector) "
                        f"WITH(dim='{embedding_dimension}', distancemeasure='{self.config.metrics}', "
                        f"pq_enable=0, external_storage=0)"
                    )
                    cur.execute(f"CREATE INDEX ON {self.table_name} USING gin(to_tsvector)")
            redis_client.set(collection_exist_cache_key, 1, ex=3600)

    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
        values = []
        id_prefix = str(uuid.uuid4()) + "_"
        sql = f"""
        INSERT INTO {self.table_name}
        (id, ref_doc_id, vector, page_content, metadata_, to_tsvector)
        VALUES (%s, %s, %s, %s, %s, to_tsvector('zh_cn', %s));
        """
        for i, doc in enumerate(documents):
            values.append(
                (
                    id_prefix + str(i),
                    doc.metadata.get("doc_id", str(uuid.uuid4())),
                    embeddings[i],
                    doc.page_content,
                    json.dumps(doc.metadata),
                    doc.page_content,
                )
            )
        with self._get_cursor() as cur:
            psycopg2.extras.execute_batch(cur, sql, values)

    def text_exists(self, id: str) -> bool:
        with self._get_cursor() as cur:
            cur.execute(f"SELECT id FROM {self.table_name} WHERE ref_doc_id = %s", (id,))
            return cur.fetchone() is not None

    def delete_by_ids(self, ids: list[str]) -> None:
        with self._get_cursor() as cur:
            try:
                cur.execute(f"DELETE FROM {self.table_name} WHERE ref_doc_id IN %s", (tuple(ids),))
            except Exception as e:
                if "does not exist" not in str(e):
                    raise e

    def delete_by_metadata_field(self, key: str, value: str) -> None:
        with self._get_cursor() as cur:
            try:
                cur.execute(f"DELETE FROM {self.table_name} WHERE metadata_->>%s = %s", (key, value))
            except Exception as e:
                if "does not exist" not in str(e):
                    raise e

    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
        top_k = kwargs.get("top_k", 4)
        score_threshold = float(kwargs.get("score_threshold") or 0.0)
        with self._get_cursor() as cur:
            query_vector_str = json.dumps(query_vector)
            query_vector_str = "{" + query_vector_str[1:-1] + "}"
            cur.execute(
                f"SELECT t.id AS id, t.vector AS vector, (1.0 - t.score) AS score, "
                f"t.page_content as page_content, t.metadata_ AS metadata_ "
                f"FROM (SELECT id, vector, page_content, metadata_, vector <=> %s AS score "
                f"FROM {self.table_name} ORDER BY score LIMIT {top_k} ) t",
                (query_vector_str,),
            )
            documents = []
            for record in cur:
                id, vector, score, page_content, metadata = record
                if score > score_threshold:
                    metadata["score"] = score
                    doc = Document(
                        page_content=page_content,
                        vector=vector,
                        metadata=metadata,
                    )
                    documents.append(doc)
        return documents

    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
        top_k = kwargs.get("top_k", 4)
        with self._get_cursor() as cur:
            cur.execute(
                f"""SELECT id, vector, page_content, metadata_,
                ts_rank(to_tsvector, to_tsquery_from_text(%s, 'zh_cn'), 32) AS score
                FROM {self.table_name}
                WHERE to_tsvector@@to_tsquery_from_text(%s, 'zh_cn')
                ORDER BY score DESC
                LIMIT {top_k}""",
                (f"'{query}'", f"'{query}'"),
            )
            documents = []
            for record in cur:
                id, vector, page_content, metadata, score = record
                metadata["score"] = score
                doc = Document(
                    page_content=page_content,
                    vector=vector,
                    metadata=metadata,
                )
                documents.append(doc)
        return documents

    def delete(self) -> None:
        with self._get_cursor() as cur:
            cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")
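One detail worth noting in search_by_vector: the ann index orders rows by the `<=>` distance, and the outer select converts that into a similarity score as `1.0 - distance` before the threshold is applied. The same arithmetic in miniature (pure Python, values invented):

```python
rows = [("doc-1", 0.12), ("doc-2", 0.55)]  # (id, cosine distance) from the index

score_threshold = 0.5
kept = [(doc_id, 1.0 - distance) for doc_id, distance in rows if 1.0 - distance > score_threshold]
print(kept)  # [('doc-1', 0.88)] -- doc-2 scores 0.45 and is filtered out
```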
@@ -242,7 +242,7 @@ class CouchbaseVector(BaseVector):
        try:
            self._cluster.query(query, named_parameters={"doc_ids": ids}).execute()
        except Exception as e:
            logger.exception(e)
            logger.exception(f"Failed to delete documents, ids: {ids}")

    def delete_by_document_id(self, document_id: str):
        query = f"""
@@ -81,7 +81,7 @@ class LindormVectorStore(BaseVector):
                    "ids": batch_ids}, _source=False)
                return {doc["_id"] for doc in existing_docs["docs"] if doc["found"]}
            except Exception as e:
                logger.exception(f"Error fetching batch {batch_ids}: {e}")
                logger.exception(f"Error fetching batch {batch_ids}")
                return set()

        @retry(stop=stop_after_attempt(3), wait=wait_fixed(60))
@@ -99,7 +99,7 @@ class LindormVectorStore(BaseVector):
                )
                return {doc["_id"] for doc in existing_docs["docs"] if doc["found"]}
            except Exception as e:
                logger.exception(f"Error fetching batch {batch_ids}: {e}")
                logger.exception(f"Error fetching batch ids: {batch_ids}")
                return set()

        if ids is None:
@@ -187,7 +187,7 @@ class LindormVectorStore(BaseVector):
                logger.warning(
                    f"Index '{self._collection_name}' does not exist. No deletion performed.")
        except Exception as e:
            logger.exception(f"Error occurred while deleting the index: {e}")
            logger.exception(f"Error occurred while deleting the index: {self._collection_name}")
            raise e

    def text_exists(self, id: str) -> bool:
@@ -213,7 +213,7 @@ class LindormVectorStore(BaseVector):
            response = self._client.search(
                index=self._collection_name, body=query)
        except Exception as e:
            logger.exception(f"Error executing search: {e}")
            logger.exception(f"Error executing vector search, query: {query}")
            raise

        docs_and_scores = []
@@ -142,7 +142,7 @@ class MyScaleVector(BaseVector):
                for r in self._client.query(sql).named_results()
            ]
        except Exception as e:
            logging.exception(f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m")
            logging.exception(f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m")  # noqa:TRY401
            return []

    def delete(self) -> None:
@@ -158,7 +158,7 @@ class OpenSearchVector(BaseVector):
        try:
            response = self._client.search(index=self._collection_name.lower(), body=query)
        except Exception as e:
            logger.exception(f"Error executing search: {e}")
            logger.exception(f"Error executing vector search, query: {query}")
            raise

        docs = []
@@ -69,7 +69,7 @@ class CacheEmbedding(Embeddings):
            except IntegrityError:
                db.session.rollback()
            except Exception as e:
                logging.exception("Failed transform embedding: %s", e)
                logging.exception("Failed transform embedding")
        cache_embeddings = []
        try:
            for i, embedding in zip(embedding_queue_indices, embedding_queue_embeddings):
@@ -89,7 +89,7 @@ class CacheEmbedding(Embeddings):
            db.session.rollback()
        except Exception as ex:
            db.session.rollback()
            logger.exception("Failed to embed documents: %s", ex)
            logger.exception("Failed to embed documents")
            raise ex

        return text_embeddings
@@ -112,7 +112,7 @@ class CacheEmbedding(Embeddings):
            embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist()
        except Exception as ex:
            if dify_config.DEBUG:
                logging.exception(f"Failed to embed query text: {ex}")
                logging.exception(f"Failed to embed query text '{text[:10]}...({len(text)} chars)'")
            raise ex

        try:
@@ -126,7 +126,7 @@ class CacheEmbedding(Embeddings):
            redis_client.setex(embedding_cache_key, 600, encoded_str)
        except Exception as ex:
            if dify_config.DEBUG:
                logging.exception("Failed to add embedding to redis %s", ex)
                logging.exception(f"Failed to add embedding to redis for the text '{text[:10]}...({len(text)} chars)'")
            raise ex

        return embedding_results
@@ -229,7 +229,7 @@ class WordExtractor(BaseExtractor):
                    for i in url_pattern.findall(x.text):
                        hyperlinks_url = str(i)
                except Exception as e:
                    logger.exception(e)
                    logger.exception("Failed to parse HYPERLINK xml")

        def parse_paragraph(paragraph):
            paragraph_content = []
@@ -11,6 +11,7 @@ from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.models.document import Document
from core.tools.utils.text_processing_utils import remove_leading_symbols
from libs import helper
from models.dataset import Dataset

@@ -43,11 +44,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
                    document_node.metadata["doc_id"] = doc_id
                    document_node.metadata["doc_hash"] = hash
                    # delete Splitter character
                    page_content = document_node.page_content
                    if page_content.startswith(".") or page_content.startswith("。"):
                        page_content = page_content[1:].strip()
                    else:
                        page_content = page_content
                    page_content = remove_leading_symbols(document_node.page_content).strip()
                    if len(page_content) > 0:
                        document_node.page_content = page_content
                        split_documents.append(document_node)
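Both index processors now delegate the leading-punctuation cleanup to remove_leading_symbols instead of hand-checking "." and "。". The exact pattern lives in core.tools.utils.text_processing_utils; a rough approximation of the behavior the refactor relies on (the real helper's regex may differ):

```python
import re

def remove_leading_symbols_sketch(text: str) -> str:
    # Approximation only: drop leading punctuation/symbols and whitespace
    # before the first word character.
    return re.sub(r"^[^\w]+", "", text)

print(remove_leading_symbols_sketch("。Hello"))    # "Hello"
print(remove_leading_symbols_sketch(".. 第一段"))  # "第一段"
```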
@@ -18,6 +18,7 @@ from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.models.document import Document
from core.tools.utils.text_processing_utils import remove_leading_symbols
from libs import helper
from models.dataset import Dataset

@@ -53,11 +54,7 @@ class QAIndexProcessor(BaseIndexProcessor):
                    document_node.metadata["doc_hash"] = hash
                    # delete Splitter character
                    page_content = document_node.page_content
                    if page_content.startswith(".") or page_content.startswith("。"):
                        page_content = page_content[1:]
                    else:
                        page_content = page_content
                    document_node.page_content = page_content
                    document_node.page_content = remove_leading_symbols(page_content)
                    split_documents.append(document_node)
            all_documents.extend(split_documents)
        for i in range(0, len(all_documents), 10):
@@ -159,7 +156,7 @@ class QAIndexProcessor(BaseIndexProcessor):
                    qa_documents.append(qa_document)
                format_documents.extend(qa_documents)
            except Exception as e:
                logging.exception(e)
                logging.exception("Failed to format qa document")

            all_qa_documents.extend(format_documents)

@@ -36,23 +36,21 @@ class WeightRerankRunner(BaseRerankRunner):

        :return:
        """
        docs = []
        doc_id = []
        unique_documents = []
        doc_ids = set()
        for document in documents:
            if document.metadata["doc_id"] not in doc_id:
                doc_id.append(document.metadata["doc_id"])
                docs.append(document.page_content)
            doc_id = document.metadata.get("doc_id")
            if doc_id not in doc_ids:
                doc_ids.add(doc_id)
                unique_documents.append(document)

        documents = unique_documents

        rerank_documents = []
        query_scores = self._calculate_keyword_score(query, documents)

        query_vector_scores = self._calculate_cosine(self.tenant_id, query, documents, self.weights.vector_setting)

        rerank_documents = []
        for document, query_score, query_vector_score in zip(documents, query_scores, query_vector_scores):
            # format document
            score = (
                self.weights.vector_setting.vector_weight * query_vector_score
                + self.weights.keyword_setting.keyword_weight * query_score
@@ -61,7 +59,8 @@ class WeightRerankRunner(BaseRerankRunner):
                continue
            document.metadata["score"] = score
            rerank_documents.append(document)
        rerank_documents = sorted(rerank_documents, key=lambda x: x.metadata["score"], reverse=True)

        rerank_documents.sort(key=lambda x: x.metadata["score"], reverse=True)
        return rerank_documents[:top_n] if top_n else rerank_documents

    def _calculate_keyword_score(self, query: str, documents: list[Document]) -> list[float]:
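The rerank score stays a simple convex combination of the two signals. For instance, with the weights and per-document scores below (numbers invented):

```python
vector_weight, keyword_weight = 0.7, 0.3
query_vector_score, query_score = 0.80, 0.50

score = vector_weight * query_vector_score + keyword_weight * query_score
print(round(score, 2))  # 0.71
```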
@@ -0,0 +1,87 @@
from typing import Any

from duckduckgo_search import DDGS

from core.model_runtime.entities.message_entities import SystemPromptMessage
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool

SUMMARY_PROMPT = """
User's query:
{query}

Here are the news results:
{content}

Please summarize the news in a few sentences.
"""


class DuckDuckGoNewsSearchTool(BuiltinTool):
    """
    Tool for performing a news search using the DuckDuckGo search engine.
    """

    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
        query_dict = {
            "keywords": tool_parameters.get("query"),
            "timelimit": tool_parameters.get("timelimit"),
            "max_results": tool_parameters.get("max_results"),
            "safesearch": "moderate",
            "region": "wt-wt",
        }
        try:
            response = list(DDGS().news(**query_dict))
            if not response:
                return [self.create_text_message("No news found matching your criteria.")]
        except Exception as e:
            return [self.create_text_message(f"Error searching news: {str(e)}")]

        require_summary = tool_parameters.get("require_summary", False)

        if require_summary:
            results = "\n".join([f"{res.get('title')}: {res.get('body')}" for res in response])
            results = self.summary_results(user_id=user_id, content=results, query=query_dict["keywords"])
            return self.create_text_message(text=results)

        # Create rich markdown content for each news item
        markdown_result = "\n\n"
        json_result = []

        for res in response:
            markdown_result += f"### {res.get('title', 'Untitled')}\n\n"
            if res.get("date"):
                markdown_result += f"**Date:** {res.get('date')}\n\n"
            if res.get("body"):
                markdown_result += f"{res.get('body')}\n\n"
            if res.get("source"):
                markdown_result += f"*Source: {res.get('source')}*\n\n"
            if res.get("image"):
                markdown_result += f"![image]({res.get('image')})\n\n"
            markdown_result += f"[Read more]({res.get('url', '')})\n\n---\n\n"

            json_result.append(
                self.create_json_message(
                    {
                        "title": res.get("title", ""),
                        "date": res.get("date", ""),
                        "body": res.get("body", ""),
                        "url": res.get("url", ""),
                        "image": res.get("image", ""),
                        "source": res.get("source", ""),
                    }
                )
            )

        return [self.create_text_message(markdown_result)] + json_result

    def summary_results(self, user_id: str, content: str, query: str) -> str:
        prompt = SUMMARY_PROMPT.format(query=query, content=content)
        summary = self.invoke_model(
            user_id=user_id,
            prompt_messages=[
                SystemPromptMessage(content=prompt),
            ],
            stop=[],
        )
        return summary.message.content
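For reference, the query this tool issues can be reproduced standalone with the duckduckgo_search package (parameter values below are invented; check the library's documentation for the exact timelimit codes it accepts):

```python
from duckduckgo_search import DDGS

results = list(
    DDGS().news(
        keywords="open source LLM",
        region="wt-wt",
        safesearch="moderate",
        max_results=5,
    )
)
for item in results:
    print(item.get("date"), "-", item.get("title"), "-", item.get("url"))
```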
@@ -0,0 +1,71 @@
identity:
  name: ddgo_news
  author: Assistant
  label:
    en_US: DuckDuckGo News Search
    zh_Hans: DuckDuckGo 新闻搜索
description:
  human:
    en_US: Perform news searches on DuckDuckGo and get results.
    zh_Hans: 在 DuckDuckGo 上进行新闻搜索并获取结果。
  llm: Perform news searches on DuckDuckGo and get results.
parameters:
  - name: query
    type: string
    required: true
    label:
      en_US: Query String
      zh_Hans: 查询语句
    human_description:
      en_US: Search Query.
      zh_Hans: 搜索查询语句。
    llm_description: Key words for searching
    form: llm
  - name: max_results
    type: number
    required: true
    default: 5
    label:
      en_US: Max Results
      zh_Hans: 最大结果数量
    human_description:
      en_US: The Max Results
      zh_Hans: 最大结果数量
    form: form
  - name: timelimit
    type: select
    required: false
    options:
      - value: Day
        label:
          en_US: Current Day
          zh_Hans: 当天
      - value: Week
        label:
          en_US: Current Week
          zh_Hans: 本周
      - value: Month
        label:
          en_US: Current Month
          zh_Hans: 当月
      - value: Year
        label:
          en_US: Current Year
          zh_Hans: 今年
    label:
      en_US: Result Time Limit
      zh_Hans: 结果时间限制
    human_description:
      en_US: Use when querying results within a specific time range only.
      zh_Hans: 只查询一定时间范围内的结果时使用
    form: form
  - name: require_summary
    type: boolean
    default: false
    label:
      en_US: Require Summary
      zh_Hans: 是否总结
    human_description:
      en_US: Whether to pass the news results to llm for summarization.
      zh_Hans: 是否需要将新闻结果传给大模型总结
    form: form
@@ -0,0 +1,75 @@
from typing import Any, ClassVar

from duckduckgo_search import DDGS

from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool


class DuckDuckGoVideoSearchTool(BuiltinTool):
    """
    Tool for performing a video search using the DuckDuckGo search engine.
    """

    IFRAME_TEMPLATE: ClassVar[str] = """
<div style="position: relative; padding-bottom: 56.25%; height: 0; overflow: hidden; \
max-width: 100%; border-radius: 8px;">
    <iframe
        style="position: absolute; top: 0; left: 0; width: 100%; height: 100%;"
        src="{src}"
        frameborder="0"
        allowfullscreen>
    </iframe>
</div>"""

    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> list[ToolInvokeMessage]:
        query_dict = {
            "keywords": tool_parameters.get("query"),
            "region": tool_parameters.get("region", "wt-wt"),
            "safesearch": tool_parameters.get("safesearch", "moderate"),
            "timelimit": tool_parameters.get("timelimit"),
            "resolution": tool_parameters.get("resolution"),
            "duration": tool_parameters.get("duration"),
            "license_videos": tool_parameters.get("license_videos"),
            "max_results": tool_parameters.get("max_results"),
        }

        # Remove None values to use API defaults
        query_dict = {k: v for k, v in query_dict.items() if v is not None}

        # Get proxy URL from parameters
        proxy_url = tool_parameters.get("proxy_url", "").strip()

        response = DDGS().videos(**query_dict)

        # Create HTML result with embedded iframes
        markdown_result = "\n\n"
        json_result = []

        for res in response:
            title = res.get("title", "")
            embed_html = res.get("embed_html", "")
            description = res.get("description", "")
            content_url = res.get("content", "")

            # Handle TED.com videos
            if not embed_html and "ted.com/talks" in content_url:
                embed_url = content_url.replace("www.ted.com", "embed.ted.com")
                if proxy_url:
                    embed_url = f"{proxy_url}{embed_url}"
                embed_html = self.IFRAME_TEMPLATE.format(src=embed_url)

            # Original YouTube/other platform handling
            elif embed_html:
                embed_url = res.get("embed_url", "")
                if proxy_url and embed_url:
                    embed_url = f"{proxy_url}{embed_url}"
                embed_html = self.IFRAME_TEMPLATE.format(src=embed_url)

            markdown_result += f"{title}\n\n"
            markdown_result += f"{embed_html}\n\n"
            markdown_result += "---\n\n"

            json_result.append(self.create_json_message(res))

        return [self.create_text_message(markdown_result)] + json_result
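The TED handling rewrites the public talk URL onto the embeddable host and optionally routes it through the configured proxy. The string manipulation on its own (URLs invented):

```python
content_url = "https://www.ted.com/talks/some_talk"
proxy_url = "https://proxy.example.com/"  # hypothetical proxy prefix

embed_url = content_url.replace("www.ted.com", "embed.ted.com")
if proxy_url:
    embed_url = f"{proxy_url}{embed_url}"
print(embed_url)  # https://proxy.example.com/https://embed.ted.com/talks/some_talk
```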
@@ -0,0 +1,97 @@
identity:
  name: ddgo_video
  author: Tao Wang
  label:
    en_US: DuckDuckGo Video Search
    zh_Hans: DuckDuckGo 视频搜索
description:
  human:
    en_US: Search and embed videos.
    zh_Hans: 搜索并嵌入视频
  llm: Search videos on duckduckgo and embed videos in iframe
parameters:
  - name: query
    label:
      en_US: Query String
      zh_Hans: 查询语句
    type: string
    required: true
    human_description:
      en_US: Search Query
      zh_Hans: 搜索查询语句
    llm_description: Key words for searching
    form: llm
  - name: max_results
    label:
      en_US: Max Results
      zh_Hans: 最大结果数量
    type: number
    required: true
    default: 3
    minimum: 1
    maximum: 10
    human_description:
      en_US: The max results (1-10)
      zh_Hans: 最大结果数量(1-10)
    form: form
  - name: timelimit
    label:
      en_US: Result Time Limit
      zh_Hans: 结果时间限制
    type: select
    required: false
    options:
      - value: Day
        label:
          en_US: Current Day
          zh_Hans: 当天
      - value: Week
        label:
          en_US: Current Week
          zh_Hans: 本周
      - value: Month
        label:
          en_US: Current Month
          zh_Hans: 当月
      - value: Year
        label:
          en_US: Current Year
          zh_Hans: 今年
    human_description:
      en_US: Query results within a specific time range only
      zh_Hans: 只查询一定时间范围内的结果时使用
    form: form
  - name: duration
    label:
      en_US: Video Duration
      zh_Hans: 视频时长
    type: select
    required: false
    options:
      - value: short
        label:
          en_US: Short (<4 minutes)
          zh_Hans: 短视频(<4分钟)
      - value: medium
        label:
          en_US: Medium (4-20 minutes)
          zh_Hans: 中等(4-20分钟)
      - value: long
        label:
          en_US: Long (>20 minutes)
          zh_Hans: 长视频(>20分钟)
    human_description:
      en_US: Filter videos by duration
      zh_Hans: 按时长筛选视频
    form: form
  - name: proxy_url
    label:
      en_US: Proxy URL
      zh_Hans: 视频代理地址
    type: string
    required: false
    default: ""
    human_description:
      en_US: Proxy URL
      zh_Hans: 视频代理地址
    form: form
@ -38,7 +38,7 @@ def send_mail(parmas: SendEmailToolParameters):
            server.sendmail(parmas.email_account, parmas.sender_to, msg.as_string())
            return True
        except Exception as e:
            logging.exception("send email failed: %s", e)
            logging.exception("send email failed")
            return False
    else:  # NONE or TLS
        try:
@ -49,5 +49,5 @@ def send_mail(parmas: SendEmailToolParameters):
            server.sendmail(parmas.email_account, parmas.sender_to, msg.as_string())
            return True
        except Exception as e:
            logging.exception("send email failed: %s", e)
            logging.exception("send email failed")
            return False
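Both send_mail hunks above drop the interpolated exception from the log call. For context, logging.exception already appends the active traceback to the message, so passing the exception as an argument duplicated it. A minimal sketch:

import logging

try:
    raise ValueError("boom")  # stand-in failure for illustration
except Exception:
    # The traceback, including "ValueError: boom", is appended automatically,
    # so the message no longer needs to interpolate the exception itself.
    logging.exception("send email failed")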
@ -17,7 +17,7 @@ class SendMailTool(BuiltinTool):
        invoke tools
        """
        sender = self.runtime.credentials.get("email_account", "")
        email_rgx = re.compile(r"^[a-zA-Z0-9_-]+@[a-zA-Z0-9_-]+(\.[a-zA-Z0-9_-]+)+$")
        email_rgx = re.compile(r"^[a-zA-Z0-9._-]+@[a-zA-Z0-9_-]+(\.[a-zA-Z0-9_-]+)+$")
        password = self.runtime.credentials.get("email_password", "")
        smtp_server = self.runtime.credentials.get("smtp_server", "")
        if not smtp_server:
@ -18,7 +18,7 @@ class SendMailTool(BuiltinTool):
        invoke tools
        """
        sender = self.runtime.credentials.get("email_account", "")
        email_rgx = re.compile(r"^[a-zA-Z0-9_-]+@[a-zA-Z0-9_-]+(\.[a-zA-Z0-9_-]+)+$")
        email_rgx = re.compile(r"^[a-zA-Z0-9._-]+@[a-zA-Z0-9_-]+(\.[a-zA-Z0-9_-]+)+$")
        password = self.runtime.credentials.get("email_password", "")
        smtp_server = self.runtime.credentials.get("smtp_server", "")
        if not smtp_server:
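Both SendMailTool hunks make the same one-character change: the local-part character class gains a dot. A quick sketch of the behavioral difference (the address below is hypothetical):

import re

old_rgx = re.compile(r"^[a-zA-Z0-9_-]+@[a-zA-Z0-9_-]+(\.[a-zA-Z0-9_-]+)+$")
new_rgx = re.compile(r"^[a-zA-Z0-9._-]+@[a-zA-Z0-9_-]+(\.[a-zA-Z0-9_-]+)+$")

addr = "first.last@example.com"  # dotted local part, common in corporate mail
print(bool(old_rgx.match(addr)))  # False: '.' was rejected before the '@'
print(bool(new_rgx.match(addr)))  # True: the new pattern allows it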
@ -19,7 +19,7 @@ class WizperTool(BuiltinTool):
        version = tool_parameters.get("version", "3")

        if audio_file.type != FileType.AUDIO:
            return [self.create_text_message("Not a valid audio file.")]
            return self.create_text_message("Not a valid audio file.")

        api_key = self.runtime.credentials["fal_api_key"]
@ -31,9 +31,8 @@ class WizperTool(BuiltinTool):
        try:
            audio_url = fal_client.upload(file_data, mime_type)

        except Exception as e:
            return [self.create_text_message(f"Error uploading audio file: {str(e)}")]
            return self.create_text_message(f"Error uploading audio file: {str(e)}")

        arguments = {
            "audio_url": audio_url,
@ -49,4 +48,9 @@ class WizperTool(BuiltinTool):
            with_logs=False,
        )

        return self.create_json_message(result)
        json_message = self.create_json_message(result)

        text = result.get("text", "")
        text_message = self.create_text_message(text)

        return [json_message, text_message]
@ -0,0 +1,25 @@
from typing import Any, Union

import requests

from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool


class GiteeAIToolEmbedding(BuiltinTool):
    def _invoke(
        self, user_id: str, tool_parameters: dict[str, Any]
    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        headers = {
            "content-type": "application/json",
            "authorization": f"Bearer {self.runtime.credentials['api_key']}",
        }

        payload = {"inputs": tool_parameters.get("inputs")}
        model = tool_parameters.get("model", "bge-m3")
        url = f"https://ai.gitee.com/api/serverless/{model}/embeddings"
        response = requests.post(url, json=payload, headers=headers)
        if response.status_code != 200:
            return self.create_text_message(f"Got Error Response:{response.text}")

        return [self.create_text_message(response.content.decode("utf-8"))]
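A minimal sketch of the same request outside the tool wrapper (the API key and input text below are placeholders):

import requests

API_KEY = "YOUR_GITEE_AI_API_KEY"  # placeholder credential
url = "https://ai.gitee.com/api/serverless/bge-m3/embeddings"
headers = {"content-type": "application/json", "authorization": f"Bearer {API_KEY}"}

response = requests.post(url, json={"inputs": "hello world"}, headers=headers)
print(response.status_code, response.text[:200])  # embedding JSON on success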
@ -0,0 +1,37 @@
identity:
  name: embedding
  author: gitee_ai
  label:
    en_US: embedding
  icon: icon.svg
description:
  human:
    en_US: Generate word embeddings using Serverless-supported models (compatible with OpenAI)
  llm: This tool is used to generate word embeddings from text input.
parameters:
  - name: model
    type: string
    required: true
    in: path
    description:
      en_US: Supported Embedding (compatible with OpenAI) interface models
    enum:
      - bge-m3
      - bge-large-zh-v1.5
      - bge-small-zh-v1.5
    label:
      en_US: Service Model
      zh_Hans: 服务模型
    default: bge-m3
    form: form
  - name: inputs
    type: string
    required: true
    label:
      en_US: Input Text
      zh_Hans: 输入文本
    human_description:
      en_US: The text input used to generate embeddings.
      zh_Hans: 用于生成词向量的输入文本。
    llm_description: This text input will be used to generate embeddings.
    form: llm
@ -6,7 +6,7 @@ from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool


class GiteeAITool(BuiltinTool):
class GiteeAIToolText2Image(BuiltinTool):
    def _invoke(
        self, user_id: str, tool_parameters: dict[str, Any]
    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
@ -1,14 +1,12 @@
identity:
  author: Yash Parmar
  author: Yash Parmar, Kalo Chin
  name: tavily
  label:
    en_US: Tavily
    zh_Hans: Tavily
    pt_BR: Tavily
    en_US: Tavily Search & Extract
    zh_Hans: Tavily 搜索和提取
  description:
    en_US: Tavily
    zh_Hans: Tavily
    pt_BR: Tavily
    en_US: A powerful AI-native search engine and web content extraction tool that provides highly relevant search results and raw content extraction from web pages.
    zh_Hans: 一个强大的原生AI搜索引擎和网页内容提取工具,提供高度相关的搜索结果和网页原始内容提取。
  icon: icon.png
  tags:
    - search
@ -19,13 +17,10 @@ credentials_for_provider:
    label:
      en_US: Tavily API key
      zh_Hans: Tavily API key
      pt_BR: Tavily API key
    placeholder:
      en_US: Please input your Tavily API key
      zh_Hans: 请输入你的 Tavily API key
      pt_BR: Please input your Tavily API key
    help:
      en_US: Get your Tavily API key from Tavily
      zh_Hans: 从 TavilyApi 获取您的 Tavily API key
      pt_BR: Get your Tavily API key from Tavily
    url: https://docs.tavily.com/docs/welcome
    url: https://app.tavily.com/home
@ -0,0 +1,145 @@
from typing import Any

import requests

from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool

TAVILY_API_URL = "https://api.tavily.com"


class TavilyExtract:
    """
    A class for extracting content from web pages using the Tavily Extract API.

    Args:
        api_key (str): The API key for accessing the Tavily Extract API.

    Methods:
        extract_content: Retrieves extracted content from the Tavily Extract API.
    """

    def __init__(self, api_key: str) -> None:
        self.api_key = api_key

    def extract_content(self, params: dict[str, Any]) -> dict:
        """
        Retrieves extracted content from the Tavily Extract API.

        Args:
            params (Dict[str, Any]): The extraction parameters.

        Returns:
            dict: The extracted content.
        """
        # Ensure required parameters are set
        if "api_key" not in params:
            params["api_key"] = self.api_key

        # Process parameters
        processed_params = self._process_params(params)

        response = requests.post(f"{TAVILY_API_URL}/extract", json=processed_params)
        response.raise_for_status()
        return response.json()

    def _process_params(self, params: dict[str, Any]) -> dict:
        """
        Processes and validates the extraction parameters.

        Args:
            params (Dict[str, Any]): The extraction parameters.

        Returns:
            dict: The processed parameters.
        """
        processed_params = {}

        # Process 'urls'
        if "urls" in params:
            urls = params["urls"]
            if isinstance(urls, str):
                processed_params["urls"] = [url.strip() for url in urls.replace(",", " ").split()]
            elif isinstance(urls, list):
                processed_params["urls"] = urls
        else:
            raise ValueError("The 'urls' parameter is required.")

        # Only include 'api_key'
        processed_params["api_key"] = params.get("api_key", self.api_key)

        return processed_params


class TavilyExtractTool(BuiltinTool):
    """
    A tool for extracting content from web pages using Tavily Extract.
    """

    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
        """
        Invokes the Tavily Extract tool with the given user ID and tool parameters.

        Args:
            user_id (str): The ID of the user invoking the tool.
            tool_parameters (Dict[str, Any]): The parameters for the Tavily Extract tool.

        Returns:
            ToolInvokeMessage | list[ToolInvokeMessage]: The result of the Tavily Extract tool invocation.
        """
        urls = tool_parameters.get("urls", "")
        api_key = self.runtime.credentials.get("tavily_api_key")
        if not api_key:
            return self.create_text_message(
                "Tavily API key is missing. Please set the 'tavily_api_key' in credentials."
            )
        if not urls:
            return self.create_text_message("Please input at least one URL to extract.")

        tavily_extract = TavilyExtract(api_key)
        try:
            raw_results = tavily_extract.extract_content(tool_parameters)
        except requests.HTTPError as e:
            return self.create_text_message(f"Error occurred while extracting content: {str(e)}")

        if not raw_results.get("results"):
            return self.create_text_message("No content could be extracted from the provided URLs.")
        else:
            # Always return JSON message with all data
            json_message = self.create_json_message(raw_results)

            # Create text message based on user-selected parameters
            text_message_content = self._format_results_as_text(raw_results)
            text_message = self.create_text_message(text=text_message_content)

            return [json_message, text_message]

    def _format_results_as_text(self, raw_results: dict) -> str:
        """
        Formats the raw extraction results into a markdown text based on user-selected parameters.

        Args:
            raw_results (dict): The raw extraction results.

        Returns:
            str: The formatted markdown text.
        """
        output_lines = []

        for idx, result in enumerate(raw_results.get("results", []), 1):
            url = result.get("url", "")
            raw_content = result.get("raw_content", "")

            output_lines.append(f"## Extracted Content {idx}: {url}\n")
            output_lines.append(f"**Raw Content:**\n{raw_content}\n")
            output_lines.append("---\n")

        if raw_results.get("failed_results"):
            output_lines.append("## Failed URLs:\n")
            for failed in raw_results["failed_results"]:
                url = failed.get("url", "")
                error = failed.get("error", "Unknown error")
                output_lines.append(f"- {url}: {error}\n")

        return "\n".join(output_lines)
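A small sketch of driving the TavilyExtract helper directly; the key and URLs are placeholders, and the comma-separated string is split by _process_params exactly as the 'urls' form field would be:

extractor = TavilyExtract(api_key="YOUR_TAVILY_API_KEY")  # placeholder key
data = extractor.extract_content({"urls": "https://example.com, https://example.org"})

for item in data.get("results", []):
    print(item["url"], len(item.get("raw_content", "")))
for failed in data.get("failed_results", []):
    print("failed:", failed.get("url"), failed.get("error", "Unknown error"))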
@ -0,0 +1,23 @@
identity:
  name: tavily_extract
  author: Kalo Chin
  label:
    en_US: Tavily Extract
    zh_Hans: Tavily Extract
description:
  human:
    en_US: A web extraction tool built specifically for AI agents (LLMs), delivering raw content from web pages.
    zh_Hans: 专为人工智能代理 (LLM) 构建的网页提取工具,提供网页的原始内容。
  llm: A tool for extracting raw content from web pages, designed for AI agents (LLMs).
parameters:
  - name: urls
    type: string
    required: true
    label:
      en_US: URLs
      zh_Hans: URLs
    human_description:
      en_US: A comma-separated list of URLs to extract content from.
      zh_Hans: 要从中提取内容的 URL 的逗号分隔列表。
    llm_description: A comma-separated list of URLs to extract content from.
    form: llm
@ -17,8 +17,6 @@ class TavilySearch:

    Methods:
        raw_results: Retrieves raw search results from the Tavily Search API.
        results: Retrieves cleaned search results from the Tavily Search API.
        clean_results: Cleans the raw search results.
    """

    def __init__(self, api_key: str) -> None:
@ -35,63 +33,62 @@ class TavilySearch:
            dict: The raw search results.

        """
        # Ensure required parameters are set
        params["api_key"] = self.api_key
        if (
            "exclude_domains" in params
            and isinstance(params["exclude_domains"], str)
            and params["exclude_domains"] != "None"
        ):
            params["exclude_domains"] = params["exclude_domains"].split()
        else:
            params["exclude_domains"] = []
        if (
            "include_domains" in params
            and isinstance(params["include_domains"], str)
            and params["include_domains"] != "None"
        ):
            params["include_domains"] = params["include_domains"].split()
        else:
            params["include_domains"] = []

        response = requests.post(f"{TAVILY_API_URL}/search", json=params)
        # Process parameters to ensure correct types
        processed_params = self._process_params(params)

        response = requests.post(f"{TAVILY_API_URL}/search", json=processed_params)
        response.raise_for_status()
        return response.json()

    def results(self, params: dict[str, Any]) -> list[dict]:
    def _process_params(self, params: dict[str, Any]) -> dict:
        """
        Retrieves cleaned search results from the Tavily Search API.
        Processes and validates the search parameters.

        Args:
            params (Dict[str, Any]): The search parameters.

        Returns:
            list: The cleaned search results.
            dict: The processed parameters.
        """
        raw_search_results = self.raw_results(params)
        return self.clean_results(raw_search_results["results"])
        processed_params = {}

    def clean_results(self, results: list[dict]) -> list[dict]:
        """
        Cleans the raw search results.
        for key, value in params.items():
            if value is None or value == "None":
                continue
            if key in ["include_domains", "exclude_domains"]:
                if isinstance(value, str):
                    # Split the string by commas or spaces and strip whitespace
                    processed_params[key] = [domain.strip() for domain in value.replace(",", " ").split()]
            elif key in ["include_images", "include_image_descriptions", "include_answer", "include_raw_content"]:
                # Ensure boolean type
                if isinstance(value, str):
                    processed_params[key] = value.lower() == "true"
                else:
                    processed_params[key] = bool(value)
            elif key in ["max_results", "days"]:
                if isinstance(value, str):
                    processed_params[key] = int(value)
                else:
                    processed_params[key] = value
            elif key in ["search_depth", "topic", "query", "api_key"]:
                processed_params[key] = value
            else:
                # Unrecognized parameter
                pass

        Args:
            results (list): The raw search results.
        # Set defaults if not present
        processed_params.setdefault("search_depth", "basic")
        processed_params.setdefault("topic", "general")
        processed_params.setdefault("max_results", 5)

        Returns:
            list: The cleaned search results.
        # If topic is 'news', ensure 'days' is set
        if processed_params.get("topic") == "news":
            processed_params.setdefault("days", 3)

        """
        clean_results = []
        for result in results:
            clean_results.append(
                {
                    "url": result["url"],
                    "content": result["content"],
                }
            )
        # return clean results as a string
        return "\n".join([f"{res['url']}\n{res['content']}" for res in clean_results])
        return processed_params


class TavilySearchTool(BuiltinTool):
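To make the new normalization concrete, a small sketch of what _process_params returns for typical form inputs (the key, query, and domains are hypothetical):

search = TavilySearch(api_key="YOUR_TAVILY_API_KEY")  # placeholder key
form_params = {
    "query": "dify workflows",
    "include_answer": "true",                       # form string -> bool True
    "max_results": "5",                             # numeric string -> int 5
    "include_domains": "github.com, docs.dify.ai",  # comma list -> ['github.com', 'docs.dify.ai']
    "topic": "news",                                # triggers the days=3 default
}
print(search._process_params(form_params))
# {'query': 'dify workflows', 'include_answer': True, 'max_results': 5,
#  'include_domains': ['github.com', 'docs.dify.ai'], 'topic': 'news',
#  'search_depth': 'basic', 'days': 3}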
@ -111,14 +108,88 @@ class TavilySearchTool(BuiltinTool):
            ToolInvokeMessage | list[ToolInvokeMessage]: The result of the Tavily search tool invocation.
        """
        query = tool_parameters.get("query", "")

        api_key = self.runtime.credentials["tavily_api_key"]
        api_key = self.runtime.credentials.get("tavily_api_key")
        if not api_key:
            return self.create_text_message(
                "Tavily API key is missing. Please set the 'tavily_api_key' in credentials."
            )
        if not query:
            return self.create_text_message("Please input query")
            return self.create_text_message("Please input a query.")

        tavily_search = TavilySearch(api_key)
        results = tavily_search.results(tool_parameters)
        print(results)
        if not results:
            return self.create_text_message(f"No results found for '{query}' in Tavily")
        try:
            raw_results = tavily_search.raw_results(tool_parameters)
        except requests.HTTPError as e:
            return self.create_text_message(f"Error occurred while searching: {str(e)}")

        if not raw_results.get("results"):
            return self.create_text_message(f"No results found for '{query}' in Tavily.")
        else:
            return self.create_text_message(text=results)
            # Always return JSON message with all data
            json_message = self.create_json_message(raw_results)

            # Create text message based on user-selected parameters
            text_message_content = self._format_results_as_text(raw_results, tool_parameters)
            text_message = self.create_text_message(text=text_message_content)

            return [json_message, text_message]

    def _format_results_as_text(self, raw_results: dict, tool_parameters: dict[str, Any]) -> str:
        """
        Formats the raw results into a markdown text based on user-selected parameters.

        Args:
            raw_results (dict): The raw search results.
            tool_parameters (dict): The tool parameters selected by the user.

        Returns:
            str: The formatted markdown text.
        """
        output_lines = []

        # Include answer if requested
        if tool_parameters.get("include_answer", False) and raw_results.get("answer"):
            output_lines.append(f"**Answer:** {raw_results['answer']}\n")

        # Include images if requested
        if tool_parameters.get("include_images", False) and raw_results.get("images"):
            output_lines.append("**Images:**\n")
            for image in raw_results["images"]:
                if tool_parameters.get("include_image_descriptions", False) and "description" in image:
                    output_lines.append(f"![{image['description']}]({image['url']})\n")
                else:
                    output_lines.append(f"![]({image['url']})\n")

        # Process each result
        if "results" in raw_results:
            for idx, result in enumerate(raw_results["results"], 1):
                title = result.get("title", "No Title")
                url = result.get("url", "")
                content = result.get("content", "")
                published_date = result.get("published_date", "")
                score = result.get("score", "")

                output_lines.append(f"### Result {idx}: [{title}]({url})\n")

                # Include published date if available and topic is 'news'
                if tool_parameters.get("topic") == "news" and published_date:
                    output_lines.append(f"**Published Date:** {published_date}\n")

                output_lines.append(f"**URL:** {url}\n")

                # Include score (relevance)
                if score:
                    output_lines.append(f"**Relevance Score:** {score}\n")

                # Include content
                if content:
                    output_lines.append(f"**Content:**\n{content}\n")

                # Include raw content if requested
                if tool_parameters.get("include_raw_content", False) and result.get("raw_content"):
                    output_lines.append(f"**Raw Content:**\n{result['raw_content']}\n")

                # Add a separator
                output_lines.append("---\n")

        return "\n".join(output_lines)
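For one hypothetical result, the formatter above renders markdown roughly like this:

raw_results = {
    "results": [{
        "title": "Dify Docs",
        "url": "https://docs.dify.ai",
        "content": "Dify is an open-source LLM app development platform.",
        "score": 0.97,
    }]
}
# _format_results_as_text(raw_results, {}) yields approximately:
# ### Result 1: [Dify Docs](https://docs.dify.ai)
# **URL:** https://docs.dify.ai
# **Relevance Score:** 0.97
# **Content:**
# Dify is an open-source LLM app development platform.
# ---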
@ -2,28 +2,24 @@ identity:
  name: tavily_search
  author: Yash Parmar
  label:
    en_US: TavilySearch
    zh_Hans: TavilySearch
    pt_BR: TavilySearch
    en_US: Tavily Search
    zh_Hans: Tavily Search
description:
  human:
    en_US: A tool for search engine built specifically for AI agents (LLMs), delivering real-time, accurate, and factual results at speed.
    en_US: A search engine tool built specifically for AI agents (LLMs), delivering real-time, accurate, and factual results at speed.
    zh_Hans: 专为人工智能代理 (LLM) 构建的搜索引擎工具,可快速提供实时、准确和真实的结果。
    pt_BR: A tool for search engine built specifically for AI agents (LLMs), delivering real-time, accurate, and factual results at speed.
  llm: A tool for search engine built specifically for AI agents (LLMs), delivering real-time, accurate, and factual results at speed.
parameters:
  - name: query
    type: string
    required: true
    label:
      en_US: Query string
      zh_Hans: 查询语句
      pt_BR: Query string
      en_US: Query
      zh_Hans: 查询
    human_description:
      en_US: used for searching
      zh_Hans: 用于搜索网页内容
      pt_BR: used for searching
      en_US: The search query you want to execute with Tavily.
      zh_Hans: 您想用 Tavily 执行的搜索查询。
    llm_description: key words for searching
    llm_description: The search query.
    form: llm
  - name: search_depth
    type: select
@ -31,122 +27,118 @@ parameters:
    label:
      en_US: Search Depth
      zh_Hans: 搜索深度
      pt_BR: Search Depth
    human_description:
      en_US: The depth of search results
      zh_Hans: 搜索结果的深度
      pt_BR: The depth of search results
      en_US: The depth of the search.
      zh_Hans: 搜索的深度。
    form: form
    options:
      - value: basic
        label:
          en_US: Basic
          zh_Hans: 基本
          pt_BR: Basic
      - value: advanced
        label:
          en_US: Advanced
          zh_Hans: 高级
          pt_BR: Advanced
    default: basic
  - name: topic
    type: select
    required: false
    label:
      en_US: Topic
      zh_Hans: 主题
    human_description:
      en_US: The category of the search.
      zh_Hans: 搜索的类别。
    form: form
    options:
      - value: general
        label:
          en_US: General
          zh_Hans: 一般
      - value: news
        label:
          en_US: News
          zh_Hans: 新闻
    default: general
  - name: days
    type: number
    required: false
    label:
      en_US: Days
      zh_Hans: 天数
    human_description:
      en_US: The number of days back from the current date to include in the search results (only applicable when "topic" is "news").
      zh_Hans: 从当前日期起向前追溯的天数,以包含在搜索结果中(仅当"topic"为"news"时适用)。
    form: form
    min: 1
    default: 3
  - name: max_results
    type: number
    required: false
    label:
      en_US: Max Results
      zh_Hans: 最大结果数
    human_description:
      en_US: The maximum number of search results to return.
      zh_Hans: 要返回的最大搜索结果数。
    form: form
    min: 1
    max: 20
    default: 5
  - name: include_images
    type: boolean
    required: false
    label:
      en_US: Include Images
      zh_Hans: 包含图片
      pt_BR: Include Images
    human_description:
      en_US: Include images in the search results
      zh_Hans: 在搜索结果中包含图片
      pt_BR: Include images in the search results
      en_US: Include a list of query-related images in the response.
      zh_Hans: 在响应中包含与查询相关的图片列表。
    form: form
    options:
      - value: 'true'
        label:
          en_US: 'Yes'
          zh_Hans: 是
          pt_BR: 'Yes'
      - value: 'false'
        label:
          en_US: 'No'
          zh_Hans: 否
          pt_BR: 'No'
    default: 'false'
    default: false
  - name: include_image_descriptions
    type: boolean
    required: false
    label:
      en_US: Include Image Descriptions
      zh_Hans: 包含图片描述
    human_description:
      en_US: When include_images is True, adds descriptive text for each image.
      zh_Hans: 当 include_images 为 True 时,为每个图像添加描述文本。
    form: form
    default: false
  - name: include_answer
    type: boolean
    required: false
    label:
      en_US: Include Answer
      zh_Hans: 包含答案
      pt_BR: Include Answer
    human_description:
      en_US: Include answers in the search results
      zh_Hans: 在搜索结果中包含答案
      pt_BR: Include answers in the search results
      en_US: Include a short answer to the original query in the response.
      zh_Hans: 在响应中包含对原始查询的简短回答。
    form: form
    options:
      - value: 'true'
        label:
          en_US: 'Yes'
          zh_Hans: 是
          pt_BR: 'Yes'
      - value: 'false'
        label:
          en_US: 'No'
          zh_Hans: 否
          pt_BR: 'No'
    default: 'false'
    default: false
  - name: include_raw_content
    type: boolean
    required: false
    label:
      en_US: Include Raw Content
      zh_Hans: 包含原始内容
      pt_BR: Include Raw Content
    human_description:
      en_US: Include raw content in the search results
      zh_Hans: 在搜索结果中包含原始内容
      pt_BR: Include raw content in the search results
      en_US: Include the cleaned and parsed HTML content of each search result.
      zh_Hans: 包含每个搜索结果的已清理和解析的HTML内容。
    form: form
    options:
      - value: 'true'
        label:
          en_US: 'Yes'
          zh_Hans: 是
          pt_BR: 'Yes'
      - value: 'false'
        label:
          en_US: 'No'
          zh_Hans: 否
          pt_BR: 'No'
    default: 'false'
  - name: max_results
    type: number
    required: false
    label:
      en_US: Max Results
      zh_Hans: 最大结果
      pt_BR: Max Results
    human_description:
      en_US: The number of maximum search results to return
      zh_Hans: 返回的最大搜索结果数
      pt_BR: The number of maximum search results to return
    form: form
    min: 1
    max: 20
    default: 5
    default: false
  - name: include_domains
    type: string
    required: false
    label:
      en_US: Include Domains
      zh_Hans: 包含域
      pt_BR: Include Domains
    human_description:
      en_US: A list of domains to specifically include in the search results
      zh_Hans: 在搜索结果中特别包含的域名列表
      pt_BR: A list of domains to specifically include in the search results
      en_US: A comma-separated list of domains to specifically include in the search results.
      zh_Hans: 要在搜索结果中特别包含的域的逗号分隔列表。
    form: form
  - name: exclude_domains
    type: string
@ -154,9 +146,7 @@ parameters:
    label:
      en_US: Exclude Domains
      zh_Hans: 排除域
      pt_BR: Exclude Domains
    human_description:
      en_US: A list of domains to specifically exclude from the search results
      zh_Hans: 从搜索结果中特别排除的域名列表
      pt_BR: A list of domains to specifically exclude from the search results
      en_US: A comma-separated list of domains to specifically exclude from the search results.
      zh_Hans: 要从搜索结果中特别排除的域的逗号分隔列表。
    form: form
@ -0,0 +1,11 @@
<?xml version="1.0" encoding="UTF-8"?>
<svg width="800px" height="800px" viewBox="0 -38 256 256" version="1.1" xmlns="http://www.w3.org/2000/svg"
     xmlns:xlink="http://www.w3.org/1999/xlink" preserveAspectRatio="xMidYMid">
    <g>
        <path d="M250.346231,28.0746923 C247.358133,17.0320558 238.732098,8.40602109 227.689461,5.41792308 C207.823743,0 127.868333,0 127.868333,0 C127.868333,0 47.9129229,0.164179487 28.0472049,5.58210256 C17.0045684,8.57020058 8.37853373,17.1962353 5.39043571,28.2388718 C-0.618533519,63.5374615 -2.94988224,117.322662 5.5546152,151.209308 C8.54271322,162.251944 17.1687479,170.877979 28.2113844,173.866077 C48.0771024,179.284 128.032513,179.284 128.032513,179.284 C128.032513,179.284 207.987923,179.284 227.853641,173.866077 C238.896277,170.877979 247.522312,162.251944 250.51041,151.209308 C256.847738,115.861464 258.801474,62.1091 250.346231,28.0746923 Z"
              fill="#FF0000">
        </path>
        <polygon fill="#FFFFFF" points="102.420513 128.06 168.749025 89.642 102.420513 51.224">
        </polygon>
    </g>
</svg>

After Width: | Height: | Size: 1.0 KiB
@ -0,0 +1,81 @@
from typing import Any, Union
from urllib.parse import parse_qs, urlparse

from youtube_transcript_api import YouTubeTranscriptApi

from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool


class YouTubeTranscriptTool(BuiltinTool):
    def _invoke(
        self, user_id: str, tool_parameters: dict[str, Any]
    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        """
        Invoke the YouTube transcript tool
        """
        try:
            # Extract parameters with defaults
            video_input = tool_parameters["video_id"]
            language = tool_parameters.get("language")
            output_format = tool_parameters.get("format", "text")
            preserve_formatting = tool_parameters.get("preserve_formatting", False)
            proxy = tool_parameters.get("proxy")
            cookies = tool_parameters.get("cookies")

            # Extract video ID from URL if needed
            video_id = self._extract_video_id(video_input)

            # Common kwargs for API calls
            kwargs = {"proxies": {"https": proxy} if proxy else None, "cookies": cookies}

            try:
                if language:
                    transcript_list = YouTubeTranscriptApi.list_transcripts(video_id, **kwargs)
                    try:
                        transcript = transcript_list.find_transcript([language])
                    except Exception:
                        # If requested language not found, try translating from English
                        transcript = transcript_list.find_transcript(["en"]).translate(language)
                    transcript_data = transcript.fetch()
                else:
                    transcript_data = YouTubeTranscriptApi.get_transcript(
                        video_id, preserve_formatting=preserve_formatting, **kwargs
                    )

                # Format output
                formatter_class = {
                    "json": "JSONFormatter",
                    "pretty": "PrettyPrintFormatter",
                    "srt": "SRTFormatter",
                    "vtt": "WebVTTFormatter",
                }.get(output_format)

                if formatter_class:
                    from youtube_transcript_api import formatters

                    formatter = getattr(formatters, formatter_class)()
                    formatted_transcript = formatter.format_transcript(transcript_data)
                else:
                    formatted_transcript = " ".join(entry["text"] for entry in transcript_data)

                return self.create_text_message(text=formatted_transcript)

            except Exception as e:
                return self.create_text_message(text=f"Error getting transcript: {str(e)}")

        except Exception as e:
            return self.create_text_message(text=f"Error processing request: {str(e)}")

    def _extract_video_id(self, video_input: str) -> str:
        """
        Extract video ID from URL or return as-is if already an ID
        """
        if "youtube.com" in video_input or "youtu.be" in video_input:
            # Parse URL
            parsed_url = urlparse(video_input)
            if "youtube.com" in parsed_url.netloc:
                return parse_qs(parsed_url.query)["v"][0]
            else:  # youtu.be
                return parsed_url.path[1:]
        return video_input  # Assume it's already a video ID
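A standalone sketch of the URL handling in _extract_video_id (the video ID below is illustrative):

from urllib.parse import parse_qs, urlparse

def extract_video_id(video_input: str) -> str:
    # Standalone copy of the parsing logic above
    if "youtube.com" in video_input or "youtu.be" in video_input:
        parsed_url = urlparse(video_input)
        if "youtube.com" in parsed_url.netloc:
            return parse_qs(parsed_url.query)["v"][0]
        return parsed_url.path[1:]  # youtu.be short links carry the ID in the path
    return video_input  # already a bare ID

print(extract_video_id("https://www.youtube.com/watch?v=dQw4w9WgXcQ"))  # dQw4w9WgXcQ
print(extract_video_id("https://youtu.be/dQw4w9WgXcQ"))                 # dQw4w9WgXcQ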
@ -0,0 +1,101 @@
identity:
  name: free_youtube_transcript
  author: Tao Wang
  label:
    en_US: Free YouTube Transcript API
    zh_Hans: 免费获取 YouTube 转录
description:
  human:
    en_US: Get transcript from a YouTube video for free.
    zh_Hans: 免费获取 YouTube 视频的转录文案。
  llm: A tool for retrieving transcript from YouTube videos.
parameters:
  - name: video_id
    type: string
    required: true
    label:
      en_US: Video ID/URL
      zh_Hans: 视频ID
    human_description:
      en_US: Used to define the video from which the transcript will be fetched. You can find the id in the video url. For example - https://www.youtube.com/watch?v=video_id.
      zh_Hans: 您要哪条视频的转录文案?您可以在视频链接中找到id。例如 - https://www.youtube.com/watch?v=video_id。
    llm_description: Used to define the video from which the transcript will be fetched. For example - https://www.youtube.com/watch?v=video_id.
    form: llm
  - name: language
    type: string
    required: false
    label:
      en_US: Language Code
      zh_Hans: 语言
    human_description:
      en_US: Language code (e.g. 'en', 'zh') for the transcript. Leave empty to auto-select.
      zh_Hans: 字幕语言代码(如'en'、'zh')。留空则自动选择。
    llm_description: Used to set the language for transcripts.
    form: form
  - name: format
    type: select
    required: false
    default: text
    options:
      - value: text
        label:
          en_US: Plain Text
          zh_Hans: 纯文本
      - value: json
        label:
          en_US: JSON Format
          zh_Hans: JSON 格式
      - value: pretty
        label:
          en_US: Pretty Print Format
          zh_Hans: 美化格式
      - value: srt
        label:
          en_US: SRT Format
          zh_Hans: SRT 格式
      - value: vtt
        label:
          en_US: WebVTT Format
          zh_Hans: WebVTT 格式
    label:
      en_US: Output Format
      zh_Hans: 输出格式
    human_description:
      en_US: Format of the transcript output
      zh_Hans: 字幕输出格式
    llm_description: The format to output the transcript in. Options are text (plain text), json (raw transcript data), pretty (pretty-printed), srt (SubRip format), or vtt (WebVTT format)
    form: form
  - name: preserve_formatting
    type: boolean
    required: false
    default: false
    label:
      en_US: Preserve Formatting
      zh_Hans: 保留格式
    human_description:
      en_US: Keep HTML formatting elements like <i> (italics) and <b> (bold)
      zh_Hans: 保留HTML格式元素,如<i>(斜体)和<b>(粗体)
    llm_description: Whether to preserve HTML formatting elements in the transcript text
    form: form
  - name: proxy
    type: string
    required: false
    label:
      en_US: HTTPS Proxy
      zh_Hans: HTTPS 代理
    human_description:
      en_US: HTTPS proxy URL (e.g. https://user:pass@domain:port)
      zh_Hans: HTTPS 代理地址(如 https://user:pass@domain:port)
    llm_description: HTTPS proxy to use for the request. Format should be https://user:pass@domain:port
    form: form
  - name: cookies
    type: string
    required: false
    label:
      en_US: Cookies File Path
      zh_Hans: Cookies 文件路径
    human_description:
      en_US: Path to cookies.txt file for accessing age-restricted videos
      zh_Hans: 用于访问年龄限制视频的 cookies.txt 文件路径
    llm_description: Path to a cookies.txt file containing YouTube cookies, needed for accessing age-restricted videos
    form: form