diff --git a/.devcontainer/post_create_command.sh b/.devcontainer/post_create_command.sh index ddf976c47a..b0322dd2b2 100755 --- a/.devcontainer/post_create_command.sh +++ b/.devcontainer/post_create_command.sh @@ -3,8 +3,8 @@ cd web && npm install pipx install poetry -echo 'alias start-api="cd /workspaces/dify/api && flask run --host 0.0.0.0 --port=5001 --debug"' >> ~/.bashrc -echo 'alias start-worker="cd /workspaces/dify/api && celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion"' >> ~/.bashrc +echo 'alias start-api="cd /workspaces/dify/api && poetry run python -m flask run --host 0.0.0.0 --port=5001 --debug"' >> ~/.bashrc +echo 'alias start-worker="cd /workspaces/dify/api && poetry run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion"' >> ~/.bashrc echo 'alias start-web="cd /workspaces/dify/web && npm run dev"' >> ~/.bashrc echo 'alias start-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify up -d"' >> ~/.bashrc diff --git a/.github/DISCUSSION_TEMPLATE/general.yml b/.github/DISCUSSION_TEMPLATE/general.yml index 5af61ea64c..487d719c85 100644 --- a/.github/DISCUSSION_TEMPLATE/general.yml +++ b/.github/DISCUSSION_TEMPLATE/general.yml @@ -9,7 +9,7 @@ body: required: true - label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)). required: true - - label: "请务必使用英文提交 Issue,否则会被关闭。谢谢!:)" + - label: "[FOR CHINESE USERS] 请务必使用英文提交 Issue,否则会被关闭。谢谢!:)" required: true - label: "Please do not modify this template :) and fill in all the required fields." required: true diff --git a/.github/DISCUSSION_TEMPLATE/help.yml b/.github/DISCUSSION_TEMPLATE/help.yml index abebaa9727..86de4057ae 100644 --- a/.github/DISCUSSION_TEMPLATE/help.yml +++ b/.github/DISCUSSION_TEMPLATE/help.yml @@ -9,7 +9,7 @@ body: required: true - label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)). required: true - - label: "请务必使用英文提交 Issue,否则会被关闭。谢谢!:)" + - label: "[FOR CHINESE USERS] 请务必使用英文提交 Issue,否则会被关闭。谢谢!:)" required: true - label: "Please do not modify this template :) and fill in all the required fields." required: true diff --git a/.github/DISCUSSION_TEMPLATE/suggestion.yml b/.github/DISCUSSION_TEMPLATE/suggestion.yml index 0893a10b2d..fa1b4e0251 100644 --- a/.github/DISCUSSION_TEMPLATE/suggestion.yml +++ b/.github/DISCUSSION_TEMPLATE/suggestion.yml @@ -9,7 +9,7 @@ body: required: true - label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)). required: true - - label: "请务必使用英文提交 Issue,否则会被关闭。谢谢!:)" + - label: "[FOR CHINESE USERS] 请务必使用英文提交 Issue,否则会被关闭。谢谢!:)" required: true - label: "Please do not modify this template :) and fill in all the required fields." required: true diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml index 8824c5dba6..f767e8ba32 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.yml +++ b/.github/ISSUE_TEMPLATE/bug_report.yml @@ -14,7 +14,7 @@ body: required: true - label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)). required: true - - label: "请务必使用英文提交 Issue,否则会被关闭。谢谢!:)" + - label: "[FOR CHINESE USERS] 请务必使用英文提交 Issue,否则会被关闭。谢谢!:)" required: true - label: "Please do not modify this template :) and fill in all the required fields." required: true @@ -22,7 +22,6 @@ body: - type: input attributes: label: Dify version - placeholder: 0.6.11 description: See about section in Dify console validations: required: true diff --git a/.github/ISSUE_TEMPLATE/document_issue.yml b/.github/ISSUE_TEMPLATE/document_issue.yml index 45ee37ca39..db8be32d95 100644 --- a/.github/ISSUE_TEMPLATE/document_issue.yml +++ b/.github/ISSUE_TEMPLATE/document_issue.yml @@ -12,7 +12,7 @@ body: required: true - label: I confirm that I am using English to submit report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)). required: true - - label: "请务必使用英文提交 Issue,否则会被关闭。谢谢!:)" + - label: "[FOR CHINESE USERS] 请务必使用英文提交 Issue,否则会被关闭。谢谢!:)" required: true - label: "Please do not modify this template :) and fill in all the required fields." required: true diff --git a/.github/ISSUE_TEMPLATE/feature_request.yml b/.github/ISSUE_TEMPLATE/feature_request.yml index 1b0eaaf4ab..f6764d35ad 100644 --- a/.github/ISSUE_TEMPLATE/feature_request.yml +++ b/.github/ISSUE_TEMPLATE/feature_request.yml @@ -12,7 +12,7 @@ body: required: true - label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)). required: true - - label: "请务必使用英文提交 Issue,否则会被关闭。谢谢!:)" + - label: "[FOR CHINESE USERS] 请务必使用英文提交 Issue,否则会被关闭。谢谢!:)" required: true - label: "Please do not modify this template :) and fill in all the required fields." required: true diff --git a/.github/ISSUE_TEMPLATE/translation_issue.yml b/.github/ISSUE_TEMPLATE/translation_issue.yml index 440ea92616..5d2f020f45 100644 --- a/.github/ISSUE_TEMPLATE/translation_issue.yml +++ b/.github/ISSUE_TEMPLATE/translation_issue.yml @@ -12,14 +12,13 @@ body: required: true - label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)). required: true - - label: "请务必使用英文提交 Issue,否则会被关闭。谢谢!:)" + - label: "[FOR CHINESE USERS] 请务必使用英文提交 Issue,否则会被关闭。谢谢!:)" required: true - label: "Please do not modify this template :) and fill in all the required fields." required: true - type: input attributes: label: Dify version - placeholder: 0.3.21 description: Hover over system tray icon or look at Settings validations: required: true diff --git a/CONTRIBUTING_JA.md b/CONTRIBUTING_JA.md index 1ce8436a78..6d5bfb205c 100644 --- a/CONTRIBUTING_JA.md +++ b/CONTRIBUTING_JA.md @@ -1,7 +1,7 @@ Dify にコントリビュートしたいとお考えなのですね。それは素晴らしいことです。 私たちは、LLM アプリケーションの構築と管理のための最も直感的なワークフローを設計するという壮大な野望を持っています。人数も資金も限られている新興企業として、コミュニティからの支援は本当に重要です。 -私たちは現状を鑑み、機敏かつ迅速に開発をする必要がありますが、同時にあなたのようなコントリビューターの方々に、可能な限りスムーズな貢献体験をしていただきたいと思っています。そのためにこのコントリビュートガイドを作成しました。 +私たちは現状を鑑み、機敏かつ迅速に開発をする必要がありますが、同時にあなた様のようなコントリビューターの方々に、可能な限りスムーズな貢献体験をしていただきたいと思っています。そのためにこのコントリビュートガイドを作成しました。 コードベースやコントリビュータの方々と私たちがどのように仕事をしているのかに慣れていただき、楽しいパートにすぐに飛び込めるようにすることが目的です。 このガイドは Dify そのものと同様に、継続的に改善されています。実際のプロジェクトに遅れをとることがあるかもしれませんが、ご理解のほどよろしくお願いいたします。 @@ -14,13 +14,13 @@ Dify にコントリビュートしたいとお考えなのですね。それは ### 機能リクエスト -* 新しい機能要望を出す場合は、提案する機能が何を実現するものなのかを説明し、可能な限り多くのコンテキストを含めてください。[@perzeusss](https://github.com/perzeuss)は、あなたの要望を書き出すのに役立つ [Feature Request Copilot](https://udify.app/chat/MK2kVSnw1gakVwMX) を作ってくれました。気軽に試してみてください。 +* 新しい機能要望を出す場合は、提案する機能が何を実現するものなのかを説明し、可能な限り多くのコンテキストを含めてください。[@perzeusss](https://github.com/perzeuss)は、あなた様の要望を書き出すのに役立つ [Feature Request Copilot](https://udify.app/chat/MK2kVSnw1gakVwMX) を作ってくれました。気軽に試してみてください。 * 既存の課題から 1 つ選びたい場合は、その下にコメントを書いてください。 - 関連する方向で作業しているチームメンバーが参加します。すべてが良好であれば、コーディングを開始する許可が与えられます。私たちが変更を提案した場合にあなたの作業が無駄になることがないよう、それまでこの機能の作業を控えていただくようお願いいたします。 + 関連する方向で作業しているチームメンバーが参加します。すべてが良好であれば、コーディングを開始する許可が与えられます。私たちが変更を提案した場合にあなた様の作業が無駄になることがないよう、それまでこの機能の作業を控えていただくようお願いいたします。 - 提案された機能がどの分野に属するかによって、あなたは異なるチーム・メンバーと話をするかもしれません。以下は、各チームメンバーが現在取り組んでいる分野の概要です。 + 提案された機能がどの分野に属するかによって、あなた様は異なるチーム・メンバーと話をするかもしれません。以下は、各チームメンバーが現在取り組んでいる分野の概要です。 | Member | Scope | | --------------------------------------------------------------------------------------- | ------------------------------------ | @@ -153,7 +153,7 @@ Dify のバックエンドは[Flask](https://flask.palletsprojects.com/en/3.0.x/ いよいよ、私たちのリポジトリにプルリクエスト (PR) を提出する時が来ました。主要な機能については、まず `deploy/dev` ブランチにマージしてテストしてから `main` ブランチにマージします。 マージ競合などの問題が発生した場合、またはプル リクエストを開く方法がわからない場合は、[GitHub's pull request tutorial](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests) をチェックしてみてください。 -これで完了です!あなたの PR がマージされると、[README](https://github.com/langgenius/dify/blob/main/README.md) にコントリビューターとして紹介されます。 +これで完了です!あなた様の PR がマージされると、[README](https://github.com/langgenius/dify/blob/main/README.md) にコントリビューターとして紹介されます。 ## ヘルプを得る diff --git a/api/.env.example b/api/.env.example index 474798cef7..cf3a0f302d 100644 --- a/api/.env.example +++ b/api/.env.example @@ -183,6 +183,7 @@ UPLOAD_IMAGE_FILE_SIZE_LIMIT=10 # Model Configuration MULTIMODAL_SEND_IMAGE_FORMAT=base64 +PROMPT_GENERATION_MAX_TOKENS=512 # Mail configuration, support: resend, smtp MAIL_TYPE= @@ -216,6 +217,7 @@ UNSTRUCTURED_API_KEY= SSRF_PROXY_HTTP_URL= SSRF_PROXY_HTTPS_URL= +SSRF_DEFAULT_MAX_RETRIES=3 BATCH_UPLOAD_LIMIT=10 KEYWORD_DATA_SOURCE_TYPE=database diff --git a/api/app.py b/api/app.py index 2c484ace85..50441cb81d 100644 --- a/api/app.py +++ b/api/app.py @@ -261,6 +261,7 @@ def after_request(response): @app.route('/health') def health(): return Response(json.dumps({ + 'pid': os.getpid(), 'status': 'ok', 'version': app.config['CURRENT_VERSION'] }), status=200, content_type="application/json") @@ -284,6 +285,7 @@ def threads(): }) return { + 'pid': os.getpid(), 'thread_num': num_threads, 'threads': thread_list } @@ -293,6 +295,7 @@ def threads(): def pool_stat(): engine = db.engine return { + 'pid': os.getpid(), 'pool_size': engine.pool.size(), 'checked_in_connections': engine.pool.checkedin(), 'checked_out_connections': engine.pool.checkedout(), diff --git a/api/commands.py b/api/commands.py index 6719539cc8..c7ffb47b51 100644 --- a/api/commands.py +++ b/api/commands.py @@ -249,8 +249,7 @@ def migrate_knowledge_vector_database(): create_count = 0 skipped_count = 0 total_count = 0 - config = current_app.config - vector_type = config.get('VECTOR_STORE') + vector_type = dify_config.VECTOR_STORE page = 1 while True: try: @@ -484,8 +483,7 @@ def convert_to_agent_apps(): @click.option('--field', default='metadata.doc_id', prompt=False, help='index field , default is metadata.doc_id.') def add_qdrant_doc_id_index(field: str): click.echo(click.style('Start add qdrant doc_id index.', fg='green')) - config = current_app.config - vector_type = config.get('VECTOR_STORE') + vector_type = dify_config.VECTOR_STORE if vector_type != "qdrant": click.echo(click.style('Sorry, only support qdrant vector store.', fg='red')) return @@ -502,13 +500,15 @@ def add_qdrant_doc_id_index(field: str): from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantConfig for binding in bindings: + if dify_config.QDRANT_URL is None: + raise ValueError('Qdrant url is required.') qdrant_config = QdrantConfig( - endpoint=config.get('QDRANT_URL'), - api_key=config.get('QDRANT_API_KEY'), + endpoint=dify_config.QDRANT_URL, + api_key=dify_config.QDRANT_API_KEY, root_path=current_app.root_path, - timeout=config.get('QDRANT_CLIENT_TIMEOUT'), - grpc_port=config.get('QDRANT_GRPC_PORT'), - prefer_grpc=config.get('QDRANT_GRPC_ENABLED') + timeout=dify_config.QDRANT_CLIENT_TIMEOUT, + grpc_port=dify_config.QDRANT_GRPC_PORT, + prefer_grpc=dify_config.QDRANT_GRPC_ENABLED ) try: client = qdrant_client.QdrantClient(**qdrant_config.to_qdrant_params()) diff --git a/api/configs/app_config.py b/api/configs/app_config.py index d1099a9036..a5a4fc788d 100644 --- a/api/configs/app_config.py +++ b/api/configs/app_config.py @@ -64,4 +64,6 @@ class DifyConfig( return f'{self.HTTP_REQUEST_NODE_MAX_TEXT_SIZE / 1024 / 1024:.2f}MB' SSRF_PROXY_HTTP_URL: str | None = None - SSRF_PROXY_HTTPS_URL: str | None = None \ No newline at end of file + SSRF_PROXY_HTTPS_URL: str | None = None + + MODERATION_BUFFER_SIZE: int = Field(default=300, description='The buffer size for moderation.') diff --git a/api/configs/middleware/__init__.py b/api/configs/middleware/__init__.py index a32b70bdc7..07688e9aeb 100644 --- a/api/configs/middleware/__init__.py +++ b/api/configs/middleware/__init__.py @@ -1,4 +1,5 @@ from typing import Any, Optional +from urllib.parse import quote_plus from pydantic import Field, NonNegativeInt, PositiveInt, computed_field from pydantic_settings import BaseSettings @@ -104,7 +105,7 @@ class DatabaseConfig: ).strip("&") db_extras = f"?{db_extras}" if db_extras else "" return (f"{self.SQLALCHEMY_DATABASE_URI_SCHEME}://" - f"{self.DB_USERNAME}:{self.DB_PASSWORD}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_DATABASE}" + f"{quote_plus(self.DB_USERNAME)}:{quote_plus(self.DB_PASSWORD)}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_DATABASE}" f"{db_extras}") SQLALCHEMY_POOL_SIZE: NonNegativeInt = Field( diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py index 4e228a70ff..6803775e20 100644 --- a/api/controllers/console/app/generator.py +++ b/api/controllers/console/app/generator.py @@ -1,3 +1,5 @@ +import os + from flask_login import current_user from flask_restful import Resource, reqparse @@ -28,13 +30,15 @@ class RuleGenerateApi(Resource): args = parser.parse_args() account = current_user + PROMPT_GENERATION_MAX_TOKENS = int(os.getenv('PROMPT_GENERATION_MAX_TOKENS', '512')) try: rules = LLMGenerator.generate_rule_config( tenant_id=account.current_tenant_id, instruction=args['instruction'], model_config=args['model_config'], - no_variable=args['no_variable'] + no_variable=args['no_variable'], + rule_config_max_tokens=PROMPT_GENERATION_MAX_TOKENS ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py index 3a0e5ea94d..c135ece67e 100644 --- a/api/controllers/console/auth/login.py +++ b/api/controllers/console/auth/login.py @@ -71,7 +71,7 @@ class ResetPasswordApi(Resource): # AccountService.update_password(account, new_password) # todo: Send email - # MAILCHIMP_API_KEY = current_app.config['MAILCHIMP_TRANSACTIONAL_API_KEY'] + # MAILCHIMP_API_KEY = dify_config.MAILCHIMP_TRANSACTIONAL_API_KEY # mailchimp = MailchimpTransactional(MAILCHIMP_API_KEY) # message = { @@ -92,7 +92,7 @@ class ResetPasswordApi(Resource): # 'message': message, # # required for transactional email # ' settings': { - # 'sandbox_mode': current_app.config['MAILCHIMP_SANDBOX_MODE'], + # 'sandbox_mode': dify_config.MAILCHIMP_SANDBOX_MODE, # }, # }) diff --git a/api/controllers/service_api/app/workflow.py b/api/controllers/service_api/app/workflow.py index 10484c9027..9446f9d588 100644 --- a/api/controllers/service_api/app/workflow.py +++ b/api/controllers/service_api/app/workflow.py @@ -29,22 +29,21 @@ from services.app_generate_service import AppGenerateService logger = logging.getLogger(__name__) +workflow_run_fields = { + 'id': fields.String, + 'workflow_id': fields.String, + 'status': fields.String, + 'inputs': fields.Raw, + 'outputs': fields.Raw, + 'error': fields.String, + 'total_steps': fields.Integer, + 'total_tokens': fields.Integer, + 'created_at': fields.DateTime, + 'finished_at': fields.DateTime, + 'elapsed_time': fields.Float, +} -class WorkflowRunApi(Resource): - workflow_run_fields = { - 'id': fields.String, - 'workflow_id': fields.String, - 'status': fields.String, - 'inputs': fields.Raw, - 'outputs': fields.Raw, - 'error': fields.String, - 'total_steps': fields.Integer, - 'total_tokens': fields.Integer, - 'created_at': fields.DateTime, - 'finished_at': fields.DateTime, - 'elapsed_time': fields.Float, - } - +class WorkflowRunDetailApi(Resource): @validate_app_token @marshal_with(workflow_run_fields) def get(self, app_model: App, workflow_id: str): @@ -57,7 +56,7 @@ class WorkflowRunApi(Resource): workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_id).first() return workflow_run - +class WorkflowRunApi(Resource): @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) def post(self, app_model: App, end_user: EndUser): """ @@ -117,5 +116,6 @@ class WorkflowTaskStopApi(Resource): } -api.add_resource(WorkflowRunApi, '/workflows/run/', '/workflows/run') +api.add_resource(WorkflowRunApi, '/workflows/run') +api.add_resource(WorkflowRunDetailApi, '/workflows/run/') api.add_resource(WorkflowTaskStopApi, '/workflows/tasks//stop') diff --git a/api/core/app/apps/base_app_queue_manager.py b/api/core/app/apps/base_app_queue_manager.py index dd2343d0b1..f929a979f1 100644 --- a/api/core/app/apps/base_app_queue_manager.py +++ b/api/core/app/apps/base_app_queue_manager.py @@ -5,9 +5,9 @@ from collections.abc import Generator from enum import Enum from typing import Any -from flask import current_app from sqlalchemy.orm import DeclarativeMeta +from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import ( AppQueueEvent, @@ -48,7 +48,7 @@ class AppQueueManager: :return: """ # wait for APP_MAX_EXECUTION_TIME seconds to stop listen - listen_timeout = current_app.config.get("APP_MAX_EXECUTION_TIME") + listen_timeout = dify_config.APP_MAX_EXECUTION_TIME start_time = time.time() last_ping_time = 0 while True: diff --git a/api/core/app/segments/__init__.py b/api/core/app/segments/__init__.py index 0179d28887..d5cd0a589c 100644 --- a/api/core/app/segments/__init__.py +++ b/api/core/app/segments/__init__.py @@ -1,8 +1,21 @@ from .segment_group import SegmentGroup -from .segments import NoneSegment, Segment +from .segments import ( + ArrayAnySegment, + FileSegment, + FloatSegment, + IntegerSegment, + NoneSegment, + ObjectSegment, + Segment, + StringSegment, +) from .types import SegmentType from .variables import ( - ArrayVariable, + ArrayAnyVariable, + ArrayFileVariable, + ArrayNumberVariable, + ArrayObjectVariable, + ArrayStringVariable, FileVariable, FloatVariable, IntegerVariable, @@ -20,11 +33,21 @@ __all__ = [ 'SecretVariable', 'FileVariable', 'StringVariable', - 'ArrayVariable', + 'ArrayAnyVariable', 'Variable', 'SegmentType', 'SegmentGroup', 'Segment', 'NoneSegment', 'NoneVariable', + 'IntegerSegment', + 'FloatSegment', + 'ObjectSegment', + 'ArrayAnySegment', + 'FileSegment', + 'StringSegment', + 'ArrayStringVariable', + 'ArrayNumberVariable', + 'ArrayObjectVariable', + 'ArrayFileVariable', ] diff --git a/api/core/app/segments/factory.py b/api/core/app/segments/factory.py index 187042ec03..f62e44bf07 100644 --- a/api/core/app/segments/factory.py +++ b/api/core/app/segments/factory.py @@ -3,14 +3,25 @@ from typing import Any from core.file.file_obj import FileVar -from .segments import Segment, StringSegment +from .segments import ( + ArrayAnySegment, + FileSegment, + FloatSegment, + IntegerSegment, + NoneSegment, + ObjectSegment, + Segment, + StringSegment, +) from .types import SegmentType from .variables import ( - ArrayVariable, + ArrayFileVariable, + ArrayNumberVariable, + ArrayObjectVariable, + ArrayStringVariable, FileVariable, FloatVariable, IntegerVariable, - NoneVariable, ObjectVariable, SecretVariable, StringVariable, @@ -28,40 +39,48 @@ def build_variable_from_mapping(m: Mapping[str, Any], /) -> Variable: match value_type: case SegmentType.STRING: return StringVariable.model_validate(m) + case SegmentType.SECRET: + return SecretVariable.model_validate(m) case SegmentType.NUMBER if isinstance(value, int): return IntegerVariable.model_validate(m) case SegmentType.NUMBER if isinstance(value, float): return FloatVariable.model_validate(m) - case SegmentType.SECRET: - return SecretVariable.model_validate(m) case SegmentType.NUMBER if not isinstance(value, float | int): raise ValueError(f'invalid number value {value}') + case SegmentType.FILE: + return FileVariable.model_validate(m) + case SegmentType.OBJECT if isinstance(value, dict): + return ObjectVariable.model_validate( + {**m, 'value': {k: build_variable_from_mapping(v) for k, v in value.items()}} + ) + case SegmentType.ARRAY_STRING if isinstance(value, list): + return ArrayStringVariable.model_validate({**m, 'value': [build_variable_from_mapping(v) for v in value]}) + case SegmentType.ARRAY_NUMBER if isinstance(value, list): + return ArrayNumberVariable.model_validate({**m, 'value': [build_variable_from_mapping(v) for v in value]}) + case SegmentType.ARRAY_OBJECT if isinstance(value, list): + return ArrayObjectVariable.model_validate({**m, 'value': [build_variable_from_mapping(v) for v in value]}) + case SegmentType.ARRAY_FILE if isinstance(value, list): + return ArrayFileVariable.model_validate({**m, 'value': [build_variable_from_mapping(v) for v in value]}) raise ValueError(f'not supported value type {value_type}') -def build_anonymous_variable(value: Any, /) -> Variable: - if value is None: - return NoneVariable(name='anonymous') - if isinstance(value, str): - return StringVariable(name='anonymous', value=value) - if isinstance(value, int): - return IntegerVariable(name='anonymous', value=value) - if isinstance(value, float): - return FloatVariable(name='anonymous', value=value) - if isinstance(value, dict): - # TODO: Limit the depth of the object - obj = {k: build_anonymous_variable(v) for k, v in value.items()} - return ObjectVariable(name='anonymous', value=obj) - if isinstance(value, list): - # TODO: Limit the depth of the array - elements = [build_anonymous_variable(v) for v in value] - return ArrayVariable(name='anonymous', value=elements) - if isinstance(value, FileVar): - return FileVariable(name='anonymous', value=value) - raise ValueError(f'not supported value {value}') - - def build_segment(value: Any, /) -> Segment: + if value is None: + return NoneSegment() if isinstance(value, str): return StringSegment(value=value) + if isinstance(value, int): + return IntegerSegment(value=value) + if isinstance(value, float): + return FloatSegment(value=value) + if isinstance(value, dict): + # TODO: Limit the depth of the object + obj = {k: build_segment(v) for k, v in value.items()} + return ObjectSegment(value=obj) + if isinstance(value, list): + # TODO: Limit the depth of the array + elements = [build_segment(v) for v in value] + return ArrayAnySegment(value=elements) + if isinstance(value, FileVar): + return FileSegment(value=value) raise ValueError(f'not supported value {value}') diff --git a/api/core/app/segments/parser.py b/api/core/app/segments/parser.py index 21d1b89541..de6c796652 100644 --- a/api/core/app/segments/parser.py +++ b/api/core/app/segments/parser.py @@ -1,17 +1,18 @@ import re -from core.app.segments import SegmentGroup, factory from core.workflow.entities.variable_pool import VariablePool +from . import SegmentGroup, factory + VARIABLE_PATTERN = re.compile(r'\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}') def convert_template(*, template: str, variable_pool: VariablePool): parts = re.split(VARIABLE_PATTERN, template) segments = [] - for part in parts: + for part in filter(lambda x: x, parts): if '.' in part and (value := variable_pool.get(part.split('.'))): segments.append(value) else: segments.append(factory.build_segment(part)) - return SegmentGroup(segments=segments) + return SegmentGroup(value=segments) diff --git a/api/core/app/segments/segment_group.py b/api/core/app/segments/segment_group.py index 0d5176b885..b4ff09b6d3 100644 --- a/api/core/app/segments/segment_group.py +++ b/api/core/app/segments/segment_group.py @@ -1,19 +1,22 @@ -from pydantic import BaseModel - from .segments import Segment +from .types import SegmentType -class SegmentGroup(BaseModel): - segments: list[Segment] +class SegmentGroup(Segment): + value_type: SegmentType = SegmentType.GROUP + value: list[Segment] @property def text(self): - return ''.join([segment.text for segment in self.segments]) + return ''.join([segment.text for segment in self.value]) @property def log(self): - return ''.join([segment.log for segment in self.segments]) + return ''.join([segment.log for segment in self.value]) @property def markdown(self): - return ''.join([segment.markdown for segment in self.segments]) \ No newline at end of file + return ''.join([segment.markdown for segment in self.value]) + + def to_object(self): + return [segment.to_object() for segment in self.value] diff --git a/api/core/app/segments/segments.py b/api/core/app/segments/segments.py index e6bf6cc3a3..4227f154e6 100644 --- a/api/core/app/segments/segments.py +++ b/api/core/app/segments/segments.py @@ -1,7 +1,11 @@ +import json +from collections.abc import Mapping, Sequence from typing import Any from pydantic import BaseModel, ConfigDict, field_validator +from core.file.file_obj import FileVar + from .types import SegmentType @@ -34,12 +38,6 @@ class Segment(BaseModel): return str(self.value) def to_object(self) -> Any: - if isinstance(self.value, Segment): - return self.value.to_object() - if isinstance(self.value, list): - return [v.to_object() for v in self.value] - if isinstance(self.value, dict): - return {k: v.to_object() for k, v in self.value.items()} return self.value @@ -63,3 +61,80 @@ class NoneSegment(Segment): class StringSegment(Segment): value_type: SegmentType = SegmentType.STRING value: str + + +class FloatSegment(Segment): + value_type: SegmentType = SegmentType.NUMBER + value: float + + +class IntegerSegment(Segment): + value_type: SegmentType = SegmentType.NUMBER + value: int + + +class FileSegment(Segment): + value_type: SegmentType = SegmentType.FILE + # TODO: embed FileVar in this model. + value: FileVar + + @property + def markdown(self) -> str: + return self.value.to_markdown() + + +class ObjectSegment(Segment): + value_type: SegmentType = SegmentType.OBJECT + value: Mapping[str, Segment] + + @property + def text(self) -> str: + # TODO: Process variables. + return json.dumps(self.model_dump()['value'], ensure_ascii=False) + + @property + def log(self) -> str: + # TODO: Process variables. + return json.dumps(self.model_dump()['value'], ensure_ascii=False, indent=2) + + @property + def markdown(self) -> str: + # TODO: Use markdown code block + return json.dumps(self.model_dump()['value'], ensure_ascii=False, indent=2) + + def to_object(self): + return {k: v.to_object() for k, v in self.value.items()} + + +class ArraySegment(Segment): + @property + def markdown(self) -> str: + return '\n'.join(['- ' + item.markdown for item in self.value]) + + def to_object(self): + return [v.to_object() for v in self.value] + + +class ArrayAnySegment(ArraySegment): + value_type: SegmentType = SegmentType.ARRAY_ANY + value: Sequence[Segment] + + +class ArrayStringSegment(ArraySegment): + value_type: SegmentType = SegmentType.ARRAY_STRING + value: Sequence[StringSegment] + + +class ArrayNumberSegment(ArraySegment): + value_type: SegmentType = SegmentType.ARRAY_NUMBER + value: Sequence[FloatSegment | IntegerSegment] + + +class ArrayObjectSegment(ArraySegment): + value_type: SegmentType = SegmentType.ARRAY_OBJECT + value: Sequence[ObjectSegment] + + +class ArrayFileSegment(ArraySegment): + value_type: SegmentType = SegmentType.ARRAY_FILE + value: Sequence[FileSegment] diff --git a/api/core/app/segments/types.py b/api/core/app/segments/types.py index ebcbf507c6..a371058ef5 100644 --- a/api/core/app/segments/types.py +++ b/api/core/app/segments/types.py @@ -6,6 +6,12 @@ class SegmentType(str, Enum): NUMBER = 'number' STRING = 'string' SECRET = 'secret' - ARRAY = 'array' + ARRAY_ANY = 'array[any]' + ARRAY_STRING = 'array[string]' + ARRAY_NUMBER = 'array[number]' + ARRAY_OBJECT = 'array[object]' + ARRAY_FILE = 'array[file]' OBJECT = 'object' FILE = 'file' + + GROUP = 'group' diff --git a/api/core/app/segments/variables.py b/api/core/app/segments/variables.py index b020914d84..ac26e16542 100644 --- a/api/core/app/segments/variables.py +++ b/api/core/app/segments/variables.py @@ -1,12 +1,21 @@ -import json -from collections.abc import Mapping, Sequence - from pydantic import Field -from core.file.file_obj import FileVar from core.helper import encrypter -from .segments import NoneSegment, Segment, StringSegment +from .segments import ( + ArrayAnySegment, + ArrayFileSegment, + ArrayNumberSegment, + ArrayObjectSegment, + ArrayStringSegment, + FileSegment, + FloatSegment, + IntegerSegment, + NoneSegment, + ObjectSegment, + Segment, + StringSegment, +) from .types import SegmentType @@ -27,53 +36,40 @@ class StringVariable(StringSegment, Variable): pass -class FloatVariable(Variable): - value_type: SegmentType = SegmentType.NUMBER - value: float +class FloatVariable(FloatSegment, Variable): + pass -class IntegerVariable(Variable): - value_type: SegmentType = SegmentType.NUMBER - value: int +class IntegerVariable(IntegerSegment, Variable): + pass -class ObjectVariable(Variable): - value_type: SegmentType = SegmentType.OBJECT - value: Mapping[str, Variable] - - @property - def text(self) -> str: - # TODO: Process variables. - return json.dumps(self.model_dump()['value'], ensure_ascii=False) - - @property - def log(self) -> str: - # TODO: Process variables. - return json.dumps(self.model_dump()['value'], ensure_ascii=False, indent=2) - - @property - def markdown(self) -> str: - # TODO: Use markdown code block - return json.dumps(self.model_dump()['value'], ensure_ascii=False, indent=2) +class FileVariable(FileSegment, Variable): + pass -class ArrayVariable(Variable): - value_type: SegmentType = SegmentType.ARRAY - value: Sequence[Variable] - - @property - def markdown(self) -> str: - return '\n'.join(['- ' + item.markdown for item in self.value]) +class ObjectVariable(ObjectSegment, Variable): + pass -class FileVariable(Variable): - value_type: SegmentType = SegmentType.FILE - # TODO: embed FileVar in this model. - value: FileVar +class ArrayAnyVariable(ArrayAnySegment, Variable): + pass - @property - def markdown(self) -> str: - return self.value.to_markdown() + +class ArrayStringVariable(ArrayStringSegment, Variable): + pass + + +class ArrayNumberVariable(ArrayNumberSegment, Variable): + pass + + +class ArrayObjectVariable(ArrayObjectSegment, Variable): + pass + + +class ArrayFileVariable(ArrayFileSegment, Variable): + pass class SecretVariable(StringVariable): diff --git a/api/core/file/upload_file_parser.py b/api/core/file/upload_file_parser.py index 9e454f08d4..737a11e426 100644 --- a/api/core/file/upload_file_parser.py +++ b/api/core/file/upload_file_parser.py @@ -6,8 +6,7 @@ import os import time from typing import Optional -from flask import current_app - +from configs import dify_config from extensions.ext_storage import storage IMAGE_EXTENSIONS = ['jpg', 'jpeg', 'png', 'webp', 'gif', 'svg'] @@ -23,7 +22,7 @@ class UploadFileParser: if upload_file.extension not in IMAGE_EXTENSIONS: return None - if current_app.config['MULTIMODAL_SEND_IMAGE_FORMAT'] == 'url' or force_url: + if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == 'url' or force_url: return cls.get_signed_temp_image_url(upload_file.id) else: # get image file base64 @@ -44,13 +43,13 @@ class UploadFileParser: :param upload_file: UploadFile object :return: """ - base_url = current_app.config.get('FILES_URL') + base_url = dify_config.FILES_URL image_preview_url = f'{base_url}/files/{upload_file_id}/image-preview' timestamp = str(int(time.time())) nonce = os.urandom(16).hex() data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}" - secret_key = current_app.config['SECRET_KEY'].encode() + secret_key = dify_config.SECRET_KEY.encode() sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() encoded_sign = base64.urlsafe_b64encode(sign).decode() @@ -68,7 +67,7 @@ class UploadFileParser: :return: """ data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}" - secret_key = current_app.config['SECRET_KEY'].encode() + secret_key = dify_config.SECRET_KEY.encode() recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode() @@ -77,4 +76,4 @@ class UploadFileParser: return False current_time = int(time.time()) - return current_time - int(timestamp) <= current_app.config.get('FILES_ACCESS_TIMEOUT') + return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT diff --git a/api/core/helper/position_helper.py b/api/core/helper/position_helper.py index 04675d85bb..dd1534c791 100644 --- a/api/core/helper/position_helper.py +++ b/api/core/helper/position_helper.py @@ -13,18 +13,10 @@ def get_position_map(folder_path: str, *, file_name: str = "_position.yaml") -> :param file_name: the YAML file name, default to '_position.yaml' :return: a dict with name as key and index as value """ - position_file_name = os.path.join(folder_path, file_name) - if not position_file_name or not os.path.exists(position_file_name): - return {} - - positions = load_yaml_file(position_file_name, ignore_error=True) - position_map = {} - index = 0 - for _, name in enumerate(positions): - if name and isinstance(name, str): - position_map[name.strip()] = index - index += 1 - return position_map + position_file_path = os.path.join(folder_path, file_name) + yaml_content = load_yaml_file(file_path=position_file_path, default_value=[]) + positions = [item.strip() for item in yaml_content if item and isinstance(item, str) and item.strip()] + return {name: index for index, name in enumerate(positions)} def sort_by_position_map( diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py index 019b27f28a..14ca8e943c 100644 --- a/api/core/helper/ssrf_proxy.py +++ b/api/core/helper/ssrf_proxy.py @@ -1,48 +1,75 @@ """ Proxy requests to avoid SSRF """ +import logging import os +import time import httpx SSRF_PROXY_ALL_URL = os.getenv('SSRF_PROXY_ALL_URL', '') SSRF_PROXY_HTTP_URL = os.getenv('SSRF_PROXY_HTTP_URL', '') SSRF_PROXY_HTTPS_URL = os.getenv('SSRF_PROXY_HTTPS_URL', '') +SSRF_DEFAULT_MAX_RETRIES = int(os.getenv('SSRF_DEFAULT_MAX_RETRIES', '3')) proxies = { 'http://': SSRF_PROXY_HTTP_URL, 'https://': SSRF_PROXY_HTTPS_URL } if SSRF_PROXY_HTTP_URL and SSRF_PROXY_HTTPS_URL else None +BACKOFF_FACTOR = 0.5 +STATUS_FORCELIST = [429, 500, 502, 503, 504] -def make_request(method, url, **kwargs): - if SSRF_PROXY_ALL_URL: - return httpx.request(method=method, url=url, proxy=SSRF_PROXY_ALL_URL, **kwargs) - elif proxies: - return httpx.request(method=method, url=url, proxies=proxies, **kwargs) - else: - return httpx.request(method=method, url=url, **kwargs) +def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): + if "allow_redirects" in kwargs: + allow_redirects = kwargs.pop("allow_redirects") + if "follow_redirects" not in kwargs: + kwargs["follow_redirects"] = allow_redirects + + retries = 0 + while retries <= max_retries: + try: + if SSRF_PROXY_ALL_URL: + response = httpx.request(method=method, url=url, proxy=SSRF_PROXY_ALL_URL, **kwargs) + elif proxies: + response = httpx.request(method=method, url=url, proxies=proxies, **kwargs) + else: + response = httpx.request(method=method, url=url, **kwargs) + + if response.status_code not in STATUS_FORCELIST: + return response + else: + logging.warning(f"Received status code {response.status_code} for URL {url} which is in the force list") + + except httpx.RequestError as e: + logging.warning(f"Request to URL {url} failed on attempt {retries + 1}: {e}") + + retries += 1 + if retries <= max_retries: + time.sleep(BACKOFF_FACTOR * (2 ** (retries - 1))) + + raise Exception(f"Reached maximum retries ({max_retries}) for URL {url}") -def get(url, **kwargs): - return make_request('GET', url, **kwargs) +def get(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): + return make_request('GET', url, max_retries=max_retries, **kwargs) -def post(url, **kwargs): - return make_request('POST', url, **kwargs) +def post(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): + return make_request('POST', url, max_retries=max_retries, **kwargs) -def put(url, **kwargs): - return make_request('PUT', url, **kwargs) +def put(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): + return make_request('PUT', url, max_retries=max_retries, **kwargs) -def patch(url, **kwargs): - return make_request('PATCH', url, **kwargs) +def patch(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): + return make_request('PATCH', url, max_retries=max_retries, **kwargs) -def delete(url, **kwargs): - return make_request('DELETE', url, **kwargs) +def delete(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): + return make_request('DELETE', url, max_retries=max_retries, **kwargs) -def head(url, **kwargs): - return make_request('HEAD', url, **kwargs) +def head(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): + return make_request('HEAD', url, max_retries=max_retries, **kwargs) diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index 83dbacbfcc..b20c6ed187 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -12,6 +12,7 @@ from flask import Flask, current_app from flask_login import current_user from sqlalchemy.orm.exc import ObjectDeletedError +from configs import dify_config from core.errors.error import ProviderTokenNotInitError from core.llm_generator.llm_generator import LLMGenerator from core.model_manager import ModelInstance, ModelManager @@ -224,7 +225,7 @@ class IndexingRunner: features = FeatureService.get_features(tenant_id) if features.billing.enabled: count = len(extract_settings) - batch_upload_limit = int(current_app.config['BATCH_UPLOAD_LIMIT']) + batch_upload_limit = dify_config.BATCH_UPLOAD_LIMIT if count > batch_upload_limit: raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.") @@ -427,7 +428,7 @@ class IndexingRunner: # The user-defined segmentation rule rules = json.loads(processing_rule.rules) segmentation = rules["segmentation"] - max_segmentation_tokens_length = int(current_app.config['INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH']) + max_segmentation_tokens_length = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH if segmentation["max_tokens"] < 50 or segmentation["max_tokens"] > max_segmentation_tokens_length: raise ValueError(f"Custom segment length should be between 50 and {max_segmentation_tokens_length}.") diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py index d6a4399fc7..0b5029460a 100644 --- a/api/core/llm_generator/llm_generator.py +++ b/api/core/llm_generator/llm_generator.py @@ -118,7 +118,7 @@ class LLMGenerator: return questions @classmethod - def generate_rule_config(cls, tenant_id: str, instruction: str, model_config: dict, no_variable: bool) -> dict: + def generate_rule_config(cls, tenant_id: str, instruction: str, model_config: dict, no_variable: bool, rule_config_max_tokens: int = 512) -> dict: output_parser = RuleConfigGeneratorOutputParser() error = "" @@ -130,7 +130,7 @@ class LLMGenerator: "error": "" } model_parameters = { - "max_tokens": 512, + "max_tokens": rule_config_max_tokens, "temperature": 0.01 } diff --git a/api/core/model_runtime/model_providers/__base/ai_model.py b/api/core/model_runtime/model_providers/__base/ai_model.py index 04b539433c..0de216bf89 100644 --- a/api/core/model_runtime/model_providers/__base/ai_model.py +++ b/api/core/model_runtime/model_providers/__base/ai_model.py @@ -162,7 +162,7 @@ class AIModel(ABC): # traverse all model_schema_yaml_paths for model_schema_yaml_path in model_schema_yaml_paths: # read yaml data from yaml file - yaml_data = load_yaml_file(model_schema_yaml_path, ignore_error=True) + yaml_data = load_yaml_file(model_schema_yaml_path) new_parameter_rules = [] for parameter_rule in yaml_data.get('parameter_rules', []): diff --git a/api/core/model_runtime/model_providers/__base/model_provider.py b/api/core/model_runtime/model_providers/__base/model_provider.py index 51dd3b7e28..780460a3f7 100644 --- a/api/core/model_runtime/model_providers/__base/model_provider.py +++ b/api/core/model_runtime/model_providers/__base/model_provider.py @@ -44,7 +44,7 @@ class ModelProvider(ABC): # read provider schema from yaml file yaml_path = os.path.join(current_path, f'{provider_name}.yaml') - yaml_data = load_yaml_file(yaml_path, ignore_error=True) + yaml_data = load_yaml_file(yaml_path) try: # yaml_data to entity diff --git a/api/core/model_runtime/model_providers/azure_openai/llm/llm.py b/api/core/model_runtime/model_providers/azure_openai/llm/llm.py index 34d1f64210..1911caa952 100644 --- a/api/core/model_runtime/model_providers/azure_openai/llm/llm.py +++ b/api/core/model_runtime/model_providers/azure_openai/llm/llm.py @@ -375,6 +375,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): continue delta = chunk.choices[0] + # NOTE: For fix https://github.com/langgenius/dify/issues/5790 + if delta.delta is None: + continue + # extract tool calls from response self._update_tool_calls(tool_calls=tool_calls, tool_calls_response=delta.delta.tool_calls) diff --git a/api/core/model_runtime/model_providers/bedrock/llm/_position.yaml b/api/core/model_runtime/model_providers/bedrock/llm/_position.yaml index 3a79a929ba..86c8061dee 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/_position.yaml +++ b/api/core/model_runtime/model_providers/bedrock/llm/_position.yaml @@ -10,10 +10,14 @@ - cohere.command-text-v14 - cohere.command-r-plus-v1.0 - cohere.command-r-v1.0 +- meta.llama3-1-8b-instruct-v1:0 +- meta.llama3-1-70b-instruct-v1:0 +- meta.llama3-1-405b-instruct-v1:0 - meta.llama3-8b-instruct-v1:0 - meta.llama3-70b-instruct-v1:0 - meta.llama2-13b-chat-v1 - meta.llama2-70b-chat-v1 +- mistral.mistral-large-2407-v1:0 - mistral.mistral-small-2402-v1:0 - mistral.mistral-large-2402-v1:0 - mistral.mixtral-8x7b-instruct-v0:1 diff --git a/api/core/model_runtime/model_providers/bedrock/llm/cohere.command-r-plus-v1.0.yaml b/api/core/model_runtime/model_providers/bedrock/llm/cohere.command-r-plus-v1.0.yaml index 4ecf3dd2fd..3c0bb4e8d5 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/cohere.command-r-plus-v1.0.yaml +++ b/api/core/model_runtime/model_providers/bedrock/llm/cohere.command-r-plus-v1.0.yaml @@ -3,8 +3,7 @@ label: en_US: Command R+ model_type: llm features: - #- multi-tool-call - - agent-thought + - tool-call #- stream-tool-call model_properties: mode: chat diff --git a/api/core/model_runtime/model_providers/bedrock/llm/cohere.command-r-v1.0.yaml b/api/core/model_runtime/model_providers/bedrock/llm/cohere.command-r-v1.0.yaml index b7a12b480a..a34f48319f 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/cohere.command-r-v1.0.yaml +++ b/api/core/model_runtime/model_providers/bedrock/llm/cohere.command-r-v1.0.yaml @@ -3,9 +3,7 @@ label: en_US: Command R model_type: llm features: - #- multi-tool-call - - agent-thought - #- stream-tool-call + - tool-call model_properties: mode: chat context_size: 128000 diff --git a/api/core/model_runtime/model_providers/bedrock/llm/llm.py b/api/core/model_runtime/model_providers/bedrock/llm/llm.py index 882d0b6352..ff34a116c7 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/llm.py +++ b/api/core/model_runtime/model_providers/bedrock/llm/llm.py @@ -17,7 +17,6 @@ from botocore.exceptions import ( ServiceNotInRegionError, UnknownServiceError, ) -from cohere import ChatMessage # local import from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta @@ -42,7 +41,6 @@ from core.model_runtime.errors.invoke import ( ) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.model_runtime.model_providers.cohere.llm.llm import CohereLargeLanguageModel logger = logging.getLogger(__name__) @@ -59,6 +57,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel): {'prefix': 'mistral.mixtral-8x7b-instruct', 'support_system_prompts': False, 'support_tool_use': False}, {'prefix': 'mistral.mistral-large', 'support_system_prompts': True, 'support_tool_use': True}, {'prefix': 'mistral.mistral-small', 'support_system_prompts': True, 'support_tool_use': True}, + {'prefix': 'cohere.command-r', 'support_system_prompts': True, 'support_tool_use': True}, {'prefix': 'amazon.titan', 'support_system_prompts': False, 'support_tool_use': False} ] @@ -94,86 +93,8 @@ class BedrockLargeLanguageModel(LargeLanguageModel): model_info['model'] = model # invoke models via boto3 converse API return self._generate_with_converse(model_info, credentials, prompt_messages, model_parameters, stop, stream, user, tools) - # invoke Cohere models via boto3 client - if "cohere.command-r" in model: - return self._generate_cohere_chat(model, credentials, prompt_messages, model_parameters, stop, stream, user, tools) # invoke other models via boto3 client return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user) - - def _generate_cohere_chat( - self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, - tools: Optional[list[PromptMessageTool]] = None,) -> Union[LLMResult, Generator]: - cohere_llm = CohereLargeLanguageModel() - client_config = Config( - region_name=credentials["aws_region"] - ) - - runtime_client = boto3.client( - service_name='bedrock-runtime', - config=client_config, - aws_access_key_id=credentials["aws_access_key_id"], - aws_secret_access_key=credentials["aws_secret_access_key"] - ) - - extra_model_kwargs = {} - if stop: - extra_model_kwargs['stop_sequences'] = stop - - if tools: - tools = cohere_llm._convert_tools(tools) - model_parameters['tools'] = tools - - message, chat_histories, tool_results \ - = cohere_llm._convert_prompt_messages_to_message_and_chat_histories(prompt_messages) - - if tool_results: - model_parameters['tool_results'] = tool_results - - payload = { - **model_parameters, - "message": message, - "chat_history": chat_histories, - } - - # need workaround for ai21 models which doesn't support streaming - if stream: - invoke = runtime_client.invoke_model_with_response_stream - else: - invoke = runtime_client.invoke_model - - def serialize(obj): - if isinstance(obj, ChatMessage): - return obj.__dict__ - raise TypeError(f"Type {type(obj)} not serializable") - - try: - body_jsonstr=json.dumps(payload, default=serialize) - response = invoke( - modelId=model, - contentType="application/json", - accept="*/*", - body=body_jsonstr - ) - except ClientError as ex: - error_code = ex.response['Error']['Code'] - full_error_msg = f"{error_code}: {ex.response['Error']['Message']}" - raise self._map_client_to_invoke_error(error_code, full_error_msg) - - except (EndpointConnectionError, NoRegionError, ServiceNotInRegionError) as ex: - raise InvokeConnectionError(str(ex)) - - except UnknownServiceError as ex: - raise InvokeServerUnavailableError(str(ex)) - - except Exception as ex: - raise InvokeError(str(ex)) - - if stream: - return self._handle_generate_stream_response(model, credentials, response, prompt_messages) - - return self._handle_generate_response(model, credentials, response, prompt_messages) - def _generate_with_converse(self, model_info: dict, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, tools: Optional[list[PromptMessageTool]] = None,) -> Union[LLMResult, Generator]: @@ -208,14 +129,25 @@ class BedrockLargeLanguageModel(LargeLanguageModel): if model_info['support_tool_use'] and tools: parameters['toolConfig'] = self._convert_converse_tool_config(tools=tools) + try: + if stream: + response = bedrock_client.converse_stream(**parameters) + return self._handle_converse_stream_response(model_info['model'], credentials, response, prompt_messages) + else: + response = bedrock_client.converse(**parameters) + return self._handle_converse_response(model_info['model'], credentials, response, prompt_messages) + except ClientError as ex: + error_code = ex.response['Error']['Code'] + full_error_msg = f"{error_code}: {ex.response['Error']['Message']}" + raise self._map_client_to_invoke_error(error_code, full_error_msg) + except (EndpointConnectionError, NoRegionError, ServiceNotInRegionError) as ex: + raise InvokeConnectionError(str(ex)) - if stream: - response = bedrock_client.converse_stream(**parameters) - return self._handle_converse_stream_response(model_info['model'], credentials, response, prompt_messages) - else: - response = bedrock_client.converse(**parameters) - return self._handle_converse_response(model_info['model'], credentials, response, prompt_messages) + except UnknownServiceError as ex: + raise InvokeServerUnavailableError(str(ex)) + except Exception as ex: + raise InvokeError(str(ex)) def _handle_converse_response(self, model: str, credentials: dict, response: dict, prompt_messages: list[PromptMessage]) -> LLMResult: """ @@ -558,7 +490,6 @@ class BedrockLargeLanguageModel(LargeLanguageModel): except ClientError as ex: error_code = ex.response['Error']['Code'] full_error_msg = f"{error_code}: {ex.response['Error']['Message']}" - raise CredentialsValidateFailedError(str(self._map_client_to_invoke_error(error_code, full_error_msg))) except Exception as ex: @@ -571,38 +502,9 @@ class BedrockLargeLanguageModel(LargeLanguageModel): :param message: PromptMessage to convert. :return: String representation of the message. """ - - if model_prefix == "anthropic": - human_prompt_prefix = "\n\nHuman:" - human_prompt_postfix = "" - ai_prompt = "\n\nAssistant:" - - elif model_prefix == "meta": - # LLAMA3 - if model_name.startswith("llama3"): - human_prompt_prefix = "<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n" - human_prompt_postfix = "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" - ai_prompt = "\n\nAssistant:" - else: - # LLAMA2 - human_prompt_prefix = "\n[INST]" - human_prompt_postfix = "[\\INST]\n" - ai_prompt = "" - - elif model_prefix == "mistral": - human_prompt_prefix = "[INST]" - human_prompt_postfix = "[\\INST]\n" - ai_prompt = "\n\nAssistant:" - - elif model_prefix == "amazon": - human_prompt_prefix = "\n\nUser:" - human_prompt_postfix = "" - ai_prompt = "\n\nBot:" - - else: - human_prompt_prefix = "" - human_prompt_postfix = "" - ai_prompt = "" + human_prompt_prefix = "" + human_prompt_postfix = "" + ai_prompt = "" content = message.content @@ -653,13 +555,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel): model_prefix = model.split('.')[0] model_name = model.split('.')[1] - if model_prefix == "amazon": - payload["textGenerationConfig"] = { **model_parameters } - payload["textGenerationConfig"]["stopSequences"] = ["User:"] - - payload["inputText"] = self._convert_messages_to_prompt(prompt_messages, model_prefix) - - elif model_prefix == "ai21": + if model_prefix == "ai21": payload["temperature"] = model_parameters.get("temperature") payload["topP"] = model_parameters.get("topP") payload["maxTokens"] = model_parameters.get("maxTokens") @@ -671,28 +567,12 @@ class BedrockLargeLanguageModel(LargeLanguageModel): payload["frequencyPenalty"] = {model_parameters.get("frequencyPenalty")} if model_parameters.get("countPenalty"): payload["countPenalty"] = {model_parameters.get("countPenalty")} - - elif model_prefix == "mistral": - payload["temperature"] = model_parameters.get("temperature") - payload["top_p"] = model_parameters.get("top_p") - payload["max_tokens"] = model_parameters.get("max_tokens") - payload["prompt"] = self._convert_messages_to_prompt(prompt_messages, model_prefix) - payload["stop"] = stop[:10] if stop else [] - - elif model_prefix == "anthropic": - payload = { **model_parameters } - payload["prompt"] = self._convert_messages_to_prompt(prompt_messages, model_prefix) - payload["stop_sequences"] = ["\n\nHuman:"] + (stop if stop else []) - + elif model_prefix == "cohere": payload = { **model_parameters } payload["prompt"] = prompt_messages[0].content payload["stream"] = stream - elif model_prefix == "meta": - payload = { **model_parameters } - payload["prompt"] = self._convert_messages_to_prompt(prompt_messages, model_prefix, model_name) - else: raise ValueError(f"Got unknown model prefix {model_prefix}") @@ -783,36 +663,16 @@ class BedrockLargeLanguageModel(LargeLanguageModel): # get output text and calculate num tokens based on model / provider model_prefix = model.split('.')[0] - if model_prefix == "amazon": - output = response_body.get("results")[0].get("outputText").strip('\n') - prompt_tokens = response_body.get("inputTextTokenCount") - completion_tokens = response_body.get("results")[0].get("tokenCount") - - elif model_prefix == "ai21": + if model_prefix == "ai21": output = response_body.get('completions')[0].get('data').get('text') prompt_tokens = len(response_body.get("prompt").get("tokens")) completion_tokens = len(response_body.get('completions')[0].get('data').get('tokens')) - - elif model_prefix == "anthropic": - output = response_body.get("completion") - prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) - completion_tokens = self.get_num_tokens(model, credentials, output if output else '') elif model_prefix == "cohere": output = response_body.get("generations")[0].get("text") prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) completion_tokens = self.get_num_tokens(model, credentials, output if output else '') - - elif model_prefix == "meta": - output = response_body.get("generation").strip('\n') - prompt_tokens = response_body.get("prompt_token_count") - completion_tokens = response_body.get("generation_token_count") - - elif model_prefix == "mistral": - output = response_body.get("outputs")[0].get("text") - prompt_tokens = response.get('ResponseMetadata').get('HTTPHeaders').get('x-amzn-bedrock-input-token-count') - completion_tokens = response.get('ResponseMetadata').get('HTTPHeaders').get('x-amzn-bedrock-output-token-count') - + else: raise ValueError(f"Got unknown model prefix {model_prefix} when handling block response") @@ -883,26 +743,10 @@ class BedrockLargeLanguageModel(LargeLanguageModel): payload = json.loads(chunk.get('bytes').decode()) model_prefix = model.split('.')[0] - if model_prefix == "amazon": - content_delta = payload.get("outputText").strip('\n') - finish_reason = payload.get("completion_reason") - - elif model_prefix == "anthropic": - content_delta = payload.get("completion") - finish_reason = payload.get("stop_reason") - - elif model_prefix == "cohere": + if model_prefix == "cohere": content_delta = payload.get("text") finish_reason = payload.get("finish_reason") - elif model_prefix == "mistral": - content_delta = payload.get('outputs')[0].get("text") - finish_reason = payload.get('outputs')[0].get("stop_reason") - - elif model_prefix == "meta": - content_delta = payload.get("generation").strip('\n') - finish_reason = payload.get("stop_reason") - else: raise ValueError(f"Got unknown model prefix {model_prefix} when handling stream response") diff --git a/api/core/model_runtime/model_providers/bedrock/llm/meta.llama3-1-405b-instruct-v1.0.yaml b/api/core/model_runtime/model_providers/bedrock/llm/meta.llama3-1-405b-instruct-v1.0.yaml new file mode 100644 index 0000000000..401de65f89 --- /dev/null +++ b/api/core/model_runtime/model_providers/bedrock/llm/meta.llama3-1-405b-instruct-v1.0.yaml @@ -0,0 +1,25 @@ +model: meta.llama3-1-405b-instruct-v1:0 +label: + en_US: Llama 3.1 405B Instruct +model_type: llm +model_properties: + mode: completion + context_size: 128000 +parameter_rules: + - name: temperature + use_template: temperature + default: 0.5 + - name: top_p + use_template: top_p + default: 0.9 + - name: max_gen_len + use_template: max_tokens + required: true + default: 512 + min: 1 + max: 2048 +pricing: + input: '0.00532' + output: '0.016' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/bedrock/llm/meta.llama3-1-70b-instruct-v1.0.yaml b/api/core/model_runtime/model_providers/bedrock/llm/meta.llama3-1-70b-instruct-v1.0.yaml new file mode 100644 index 0000000000..10bfa7b1d5 --- /dev/null +++ b/api/core/model_runtime/model_providers/bedrock/llm/meta.llama3-1-70b-instruct-v1.0.yaml @@ -0,0 +1,25 @@ +model: meta.llama3-1-70b-instruct-v1:0 +label: + en_US: Llama 3.1 Instruct 70B +model_type: llm +model_properties: + mode: completion + context_size: 128000 +parameter_rules: + - name: temperature + use_template: temperature + default: 0.5 + - name: top_p + use_template: top_p + default: 0.9 + - name: max_gen_len + use_template: max_tokens + required: true + default: 512 + min: 1 + max: 2048 +pricing: + input: '0.00265' + output: '0.0035' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/bedrock/llm/meta.llama3-1-8b-instruct-v1.0.yaml b/api/core/model_runtime/model_providers/bedrock/llm/meta.llama3-1-8b-instruct-v1.0.yaml new file mode 100644 index 0000000000..81cd53243f --- /dev/null +++ b/api/core/model_runtime/model_providers/bedrock/llm/meta.llama3-1-8b-instruct-v1.0.yaml @@ -0,0 +1,25 @@ +model: meta.llama3-1-8b-instruct-v1:0 +label: + en_US: Llama 3.1 Instruct 8B +model_type: llm +model_properties: + mode: completion + context_size: 128000 +parameter_rules: + - name: temperature + use_template: temperature + default: 0.5 + - name: top_p + use_template: top_p + default: 0.9 + - name: max_gen_len + use_template: max_tokens + required: true + default: 512 + min: 1 + max: 2048 +pricing: + input: '0.0003' + output: '0.0006' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/bedrock/llm/mistral.mistral-large-2407-v1.0.yaml b/api/core/model_runtime/model_providers/bedrock/llm/mistral.mistral-large-2407-v1.0.yaml new file mode 100644 index 0000000000..19d7843a57 --- /dev/null +++ b/api/core/model_runtime/model_providers/bedrock/llm/mistral.mistral-large-2407-v1.0.yaml @@ -0,0 +1,29 @@ +model: mistral.mistral-large-2407-v1:0 +label: + en_US: Mistral Large 2 (24.07) +model_type: llm +features: + - tool-call +model_properties: + mode: completion + context_size: 128000 +parameter_rules: + - name: temperature + use_template: temperature + required: false + default: 0.7 + - name: top_p + use_template: top_p + required: false + default: 1 + - name: max_tokens + use_template: max_tokens + required: true + default: 512 + min: 1 + max: 8192 +pricing: + input: '0.003' + output: '0.009' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/deepseek/llm/deepseek-chat.yaml b/api/core/model_runtime/model_providers/deepseek/llm/deepseek-chat.yaml index d50529926b..6832576524 100644 --- a/api/core/model_runtime/model_providers/deepseek/llm/deepseek-chat.yaml +++ b/api/core/model_runtime/model_providers/deepseek/llm/deepseek-chat.yaml @@ -23,7 +23,7 @@ parameter_rules: type: int default: 4096 min: 1 - max: 4096 + max: 8192 help: zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。 en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. @@ -57,6 +57,18 @@ parameter_rules: help: zh_Hans: 介于 -2.0 和 2.0 之间的数字。如果该值为正,那么新 token 会根据其在已有文本中的出现频率受到相应的惩罚,降低模型重复相同内容的可能性。 en_US: A number between -2.0 and 2.0. If the value is positive, new tokens are penalized based on their frequency of occurrence in existing text, reducing the likelihood that the model will repeat the same content. + - name: response_format + label: + zh_Hans: 回复格式 + en_US: response_format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object pricing: input: '1' output: '2' diff --git a/api/core/model_runtime/model_providers/hunyuan/hunyuan.yaml b/api/core/model_runtime/model_providers/hunyuan/hunyuan.yaml index 835a7716f7..812b51ddcd 100644 --- a/api/core/model_runtime/model_providers/hunyuan/hunyuan.yaml +++ b/api/core/model_runtime/model_providers/hunyuan/hunyuan.yaml @@ -18,6 +18,7 @@ help: en_US: https://console.cloud.tencent.com/cam/capi supported_model_types: - llm + - text-embedding configurate_methods: - predefined-model provider_credential_schema: diff --git a/api/core/model_runtime/model_providers/hunyuan/llm/llm.py b/api/core/model_runtime/model_providers/hunyuan/llm/llm.py index 2b6d8e0047..8859dd72bd 100644 --- a/api/core/model_runtime/model_providers/hunyuan/llm/llm.py +++ b/api/core/model_runtime/model_providers/hunyuan/llm/llm.py @@ -14,6 +14,7 @@ from core.model_runtime.entities.message_entities import ( PromptMessage, PromptMessageTool, SystemPromptMessage, + ToolPromptMessage, UserPromptMessage, ) from core.model_runtime.errors.invoke import InvokeError @@ -44,6 +45,17 @@ class HunyuanLargeLanguageModel(LargeLanguageModel): "Stream": stream, **custom_parameters, } + # add Tools and ToolChoice + if (tools and len(tools) > 0): + params['ToolChoice'] = "auto" + params['Tools'] = [{ + "Type": "function", + "Function": { + "Name": tool.name, + "Description": tool.description, + "Parameters": json.dumps(tool.parameters) + } + } for tool in tools] request.from_json_string(json.dumps(params)) response = client.ChatCompletions(request) @@ -89,9 +101,43 @@ class HunyuanLargeLanguageModel(LargeLanguageModel): def _convert_prompt_messages_to_dicts(self, prompt_messages: list[PromptMessage]) -> list[dict]: """Convert a list of PromptMessage objects to a list of dictionaries with 'Role' and 'Content' keys.""" - return [{"Role": message.role.value, "Content": message.content} for message in prompt_messages] + dict_list = [] + for message in prompt_messages: + if isinstance(message, AssistantPromptMessage): + tool_calls = message.tool_calls + if (tool_calls and len(tool_calls) > 0): + dict_tool_calls = [ + { + "Id": tool_call.id, + "Type": tool_call.type, + "Function": { + "Name": tool_call.function.name, + "Arguments": tool_call.function.arguments if (tool_call.function.arguments == "") else "{}" + } + } for tool_call in tool_calls] + + dict_list.append({ + "Role": message.role.value, + # fix set content = "" while tool_call request + # fix [hunyuan] None, [TencentCloudSDKException] code:InvalidParameter message:Messages Content and Contents not allowed empty at the same time. + "Content": " ", # message.content if (message.content is not None) else "", + "ToolCalls": dict_tool_calls + }) + else: + dict_list.append({ "Role": message.role.value, "Content": message.content }) + elif isinstance(message, ToolPromptMessage): + tool_execute_result = { "result": message.content } + content =json.dumps(tool_execute_result, ensure_ascii=False) + dict_list.append({ "Role": message.role.value, "Content": content, "ToolCallId": message.tool_call_id }) + else: + dict_list.append({ "Role": message.role.value, "Content": message.content }) + return dict_list def _handle_stream_chat_response(self, model, credentials, prompt_messages, resp): + + tool_call = None + tool_calls = [] + for index, event in enumerate(resp): logging.debug("_handle_stream_chat_response, event: %s", event) @@ -109,20 +155,54 @@ class HunyuanLargeLanguageModel(LargeLanguageModel): usage = data.get('Usage', {}) prompt_tokens = usage.get('PromptTokens', 0) completion_tokens = usage.get('CompletionTokens', 0) - usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) + + response_tool_calls = delta.get('ToolCalls') + if (response_tool_calls is not None): + new_tool_calls = self._extract_response_tool_calls(response_tool_calls) + if (len(new_tool_calls) > 0): + new_tool_call = new_tool_calls[0] + if (tool_call is None): tool_call = new_tool_call + elif (tool_call.id != new_tool_call.id): + tool_calls.append(tool_call) + tool_call = new_tool_call + else: + tool_call.function.name += new_tool_call.function.name + tool_call.function.arguments += new_tool_call.function.arguments + if (tool_call is not None and len(tool_call.function.name) > 0 and len(tool_call.function.arguments) > 0): + tool_calls.append(tool_call) + tool_call = None assistant_prompt_message = AssistantPromptMessage( content=message_content, tool_calls=[] ) + # rewrite content = "" while tool_call to avoid show content on web page + if (len(tool_calls) > 0): assistant_prompt_message.content = "" + + # add tool_calls to assistant_prompt_message + if (finish_reason == 'tool_calls'): + assistant_prompt_message.tool_calls = tool_calls + tool_call = None + tool_calls = [] - delta_chunk = LLMResultChunkDelta( - index=index, - role=delta.get('Role', 'assistant'), - message=assistant_prompt_message, - usage=usage, - finish_reason=finish_reason, - ) + if (len(finish_reason) > 0): + usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) + + delta_chunk = LLMResultChunkDelta( + index=index, + role=delta.get('Role', 'assistant'), + message=assistant_prompt_message, + usage=usage, + finish_reason=finish_reason, + ) + tool_call = None + tool_calls = [] + + else: + delta_chunk = LLMResultChunkDelta( + index=index, + message=assistant_prompt_message, + ) yield LLMResultChunk( model=model, @@ -177,12 +257,15 @@ class HunyuanLargeLanguageModel(LargeLanguageModel): """ human_prompt = "\n\nHuman:" ai_prompt = "\n\nAssistant:" + tool_prompt = "\n\nTool:" content = message.content if isinstance(message, UserPromptMessage): message_text = f"{human_prompt} {content}" elif isinstance(message, AssistantPromptMessage): message_text = f"{ai_prompt} {content}" + elif isinstance(message, ToolPromptMessage): + message_text = f"{tool_prompt} {content}" elif isinstance(message, SystemPromptMessage): message_text = content else: @@ -203,3 +286,30 @@ class HunyuanLargeLanguageModel(LargeLanguageModel): return { InvokeError: [TencentCloudSDKException], } + + def _extract_response_tool_calls(self, + response_tool_calls: list[dict]) \ + -> list[AssistantPromptMessage.ToolCall]: + """ + Extract tool calls from response + + :param response_tool_calls: response tool calls + :return: list of tool calls + """ + tool_calls = [] + if response_tool_calls: + for response_tool_call in response_tool_calls: + response_function = response_tool_call.get('Function', {}) + function = AssistantPromptMessage.ToolCall.ToolCallFunction( + name=response_function.get('Name', ''), + arguments=response_function.get('Arguments', '') + ) + + tool_call = AssistantPromptMessage.ToolCall( + id=response_tool_call.get('Id', 0), + type='function', + function=function + ) + tool_calls.append(tool_call) + + return tool_calls \ No newline at end of file diff --git a/api/core/model_runtime/model_providers/hunyuan/text_embedding/__init__.py b/api/core/model_runtime/model_providers/hunyuan/text_embedding/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/hunyuan/text_embedding/hunyuan-text-embedding.yaml b/api/core/model_runtime/model_providers/hunyuan/text_embedding/hunyuan-text-embedding.yaml new file mode 100644 index 0000000000..ab014e4344 --- /dev/null +++ b/api/core/model_runtime/model_providers/hunyuan/text_embedding/hunyuan-text-embedding.yaml @@ -0,0 +1,5 @@ +model: hunyuan-embedding +model_type: text-embedding +model_properties: + context_size: 1024 + max_chunks: 1 diff --git a/api/core/model_runtime/model_providers/hunyuan/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/hunyuan/text_embedding/text_embedding.py new file mode 100644 index 0000000000..64d8dcf795 --- /dev/null +++ b/api/core/model_runtime/model_providers/hunyuan/text_embedding/text_embedding.py @@ -0,0 +1,173 @@ +import json +import logging +import time +from typing import Optional + +from tencentcloud.common import credential +from tencentcloud.common.exception import TencentCloudSDKException +from tencentcloud.common.profile.client_profile import ClientProfile +from tencentcloud.common.profile.http_profile import HttpProfile +from tencentcloud.hunyuan.v20230901 import hunyuan_client, models + +from core.model_runtime.entities.model_entities import PriceType +from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult +from core.model_runtime.errors.invoke import ( + InvokeError, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel + +logger = logging.getLogger(__name__) + +class HunyuanTextEmbeddingModel(TextEmbeddingModel): + """ + Model class for Hunyuan text embedding model. + """ + + def _invoke(self, model: str, credentials: dict, + texts: list[str], user: Optional[str] = None) \ + -> TextEmbeddingResult: + """ + Invoke text embedding model + + :param model: model name + :param credentials: model credentials + :param texts: texts to embed + :param user: unique user id + :return: embeddings result + """ + + if model != 'hunyuan-embedding': + raise ValueError('Invalid model name') + + client = self._setup_hunyuan_client(credentials) + + embeddings = [] + token_usage = 0 + + for input in texts: + request = models.GetEmbeddingRequest() + params = { + "Input": input + } + request.from_json_string(json.dumps(params)) + response = client.GetEmbedding(request) + usage = response.Usage.TotalTokens + + embeddings.extend([data.Embedding for data in response.Data]) + token_usage += usage + + result = TextEmbeddingResult( + model=model, + embeddings=embeddings, + usage=self._calc_response_usage( + model=model, + credentials=credentials, + tokens=token_usage + ) + ) + + return result + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate credentials + """ + try: + client = self._setup_hunyuan_client(credentials) + + req = models.ChatCompletionsRequest() + params = { + "Model": model, + "Messages": [{ + "Role": "user", + "Content": "hello" + }], + "TopP": 1, + "Temperature": 0, + "Stream": False + } + req.from_json_string(json.dumps(params)) + client.ChatCompletions(req) + except Exception as e: + raise CredentialsValidateFailedError(f'Credentials validation failed: {e}') + + def _setup_hunyuan_client(self, credentials): + secret_id = credentials['secret_id'] + secret_key = credentials['secret_key'] + cred = credential.Credential(secret_id, secret_key) + httpProfile = HttpProfile() + httpProfile.endpoint = "hunyuan.tencentcloudapi.com" + clientProfile = ClientProfile() + clientProfile.httpProfile = httpProfile + client = hunyuan_client.HunyuanClient(cred, "", clientProfile) + return client + + def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: + """ + Calculate response usage + + :param model: model name + :param credentials: model credentials + :param tokens: input tokens + :return: usage + """ + # get input price info + input_price_info = self.get_price( + model=model, + credentials=credentials, + price_type=PriceType.INPUT, + tokens=tokens + ) + + # transform usage + usage = EmbeddingUsage( + tokens=tokens, + total_tokens=tokens, + unit_price=input_price_info.unit_price, + price_unit=input_price_info.unit, + total_price=input_price_info.total_amount, + currency=input_price_info.currency, + latency=time.perf_counter() - self.started_at + ) + + return usage + + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + """ + Map model invoke error to unified error + The key is the error type thrown to the caller + The value is the error type thrown by the model, + which needs to be converted into a unified error type for the caller. + + :return: Invoke error mapping + """ + return { + InvokeError: [TencentCloudSDKException], + } + + def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: + """ + Get number of tokens for given prompt messages + + :param model: model name + :param credentials: model credentials + :param texts: texts to embed + :return: + """ + # client = self._setup_hunyuan_client(credentials) + + num_tokens = 0 + for text in texts: + num_tokens += self._get_num_tokens_by_gpt2(text) + # use client.GetTokenCount to get num tokens + # request = models.GetTokenCountRequest() + # params = { + # "Prompt": text + # } + # request.from_json_string(json.dumps(params)) + # response = client.GetTokenCount(request) + # num_tokens += response.TokenCount + + return num_tokens \ No newline at end of file diff --git a/api/core/model_runtime/model_providers/novita/llm/Nous-Hermes-2-Mixtral-8x7B-DPO.yaml b/api/core/model_runtime/model_providers/novita/llm/Nous-Hermes-2-Mixtral-8x7B-DPO.yaml index 8b19316473..7ff30458e2 100644 --- a/api/core/model_runtime/model_providers/novita/llm/Nous-Hermes-2-Mixtral-8x7B-DPO.yaml +++ b/api/core/model_runtime/model_providers/novita/llm/Nous-Hermes-2-Mixtral-8x7B-DPO.yaml @@ -34,3 +34,8 @@ parameter_rules: min: -2 max: 2 default: 0 +pricing: + input: '0.0027' + output: '0.0027' + unit: '0.0001' + currency: USD diff --git a/api/core/model_runtime/model_providers/novita/llm/airoboros-l2-70b.yaml b/api/core/model_runtime/model_providers/novita/llm/airoboros-l2-70b.yaml new file mode 100644 index 0000000000..b599418461 --- /dev/null +++ b/api/core/model_runtime/model_providers/novita/llm/airoboros-l2-70b.yaml @@ -0,0 +1,41 @@ +model: jondurbin/airoboros-l2-70b +label: + zh_Hans: jondurbin/airoboros-l2-70b + en_US: jondurbin/airoboros-l2-70b +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 4096 +parameter_rules: + - name: temperature + use_template: temperature + min: 0 + max: 2 + default: 1 + - name: top_p + use_template: top_p + min: 0 + max: 1 + default: 1 + - name: max_tokens + use_template: max_tokens + min: 1 + max: 2048 + default: 512 + - name: frequency_penalty + use_template: frequency_penalty + min: -2 + max: 2 + default: 0 + - name: presence_penalty + use_template: presence_penalty + min: -2 + max: 2 + default: 0 +pricing: + input: '0.005' + output: '0.005' + unit: '0.0001' + currency: USD diff --git a/api/core/model_runtime/model_providers/novita/llm/dolphin-mixtral-8x22b.yaml b/api/core/model_runtime/model_providers/novita/llm/dolphin-mixtral-8x22b.yaml new file mode 100644 index 0000000000..72a181f5d3 --- /dev/null +++ b/api/core/model_runtime/model_providers/novita/llm/dolphin-mixtral-8x22b.yaml @@ -0,0 +1,41 @@ +model: cognitivecomputations/dolphin-mixtral-8x22b +label: + zh_Hans: cognitivecomputations/dolphin-mixtral-8x22b + en_US: cognitivecomputations/dolphin-mixtral-8x22b +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 16000 +parameter_rules: + - name: temperature + use_template: temperature + min: 0 + max: 2 + default: 1 + - name: top_p + use_template: top_p + min: 0 + max: 1 + default: 1 + - name: max_tokens + use_template: max_tokens + min: 1 + max: 2048 + default: 512 + - name: frequency_penalty + use_template: frequency_penalty + min: -2 + max: 2 + default: 0 + - name: presence_penalty + use_template: presence_penalty + min: -2 + max: 2 + default: 0 +pricing: + input: '0.009' + output: '0.009' + unit: '0.0001' + currency: USD diff --git a/api/core/model_runtime/model_providers/novita/llm/gemma-2-9b-it.yaml b/api/core/model_runtime/model_providers/novita/llm/gemma-2-9b-it.yaml new file mode 100644 index 0000000000..d1749bc882 --- /dev/null +++ b/api/core/model_runtime/model_providers/novita/llm/gemma-2-9b-it.yaml @@ -0,0 +1,41 @@ +model: google/gemma-2-9b-it +label: + zh_Hans: google/gemma-2-9b-it + en_US: google/gemma-2-9b-it +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 8192 +parameter_rules: + - name: temperature + use_template: temperature + min: 0 + max: 2 + default: 1 + - name: top_p + use_template: top_p + min: 0 + max: 1 + default: 1 + - name: max_tokens + use_template: max_tokens + min: 1 + max: 2048 + default: 512 + - name: frequency_penalty + use_template: frequency_penalty + min: -2 + max: 2 + default: 0 + - name: presence_penalty + use_template: presence_penalty + min: -2 + max: 2 + default: 0 +pricing: + input: '0.0008' + output: '0.0008' + unit: '0.0001' + currency: USD diff --git a/api/core/model_runtime/model_providers/novita/llm/hermes-2-pro-llama-3-8b.yaml b/api/core/model_runtime/model_providers/novita/llm/hermes-2-pro-llama-3-8b.yaml new file mode 100644 index 0000000000..8b3228e56a --- /dev/null +++ b/api/core/model_runtime/model_providers/novita/llm/hermes-2-pro-llama-3-8b.yaml @@ -0,0 +1,41 @@ +model: nousresearch/hermes-2-pro-llama-3-8b +label: + zh_Hans: nousresearch/hermes-2-pro-llama-3-8b + en_US: nousresearch/hermes-2-pro-llama-3-8b +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 8192 +parameter_rules: + - name: temperature + use_template: temperature + min: 0 + max: 2 + default: 1 + - name: top_p + use_template: top_p + min: 0 + max: 1 + default: 1 + - name: max_tokens + use_template: max_tokens + min: 1 + max: 2048 + default: 512 + - name: frequency_penalty + use_template: frequency_penalty + min: -2 + max: 2 + default: 0 + - name: presence_penalty + use_template: presence_penalty + min: -2 + max: 2 + default: 0 +pricing: + input: '0.0014' + output: '0.0014' + unit: '0.0001' + currency: USD diff --git a/api/core/model_runtime/model_providers/novita/llm/l3-70b-euryale-v2.1.yaml b/api/core/model_runtime/model_providers/novita/llm/l3-70b-euryale-v2.1.yaml new file mode 100644 index 0000000000..5e27941c52 --- /dev/null +++ b/api/core/model_runtime/model_providers/novita/llm/l3-70b-euryale-v2.1.yaml @@ -0,0 +1,41 @@ +model: sao10k/l3-70b-euryale-v2.1 +label: + zh_Hans: sao10k/l3-70b-euryale-v2.1 + en_US: sao10k/l3-70b-euryale-v2.1 +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 16000 +parameter_rules: + - name: temperature + use_template: temperature + min: 0 + max: 2 + default: 1 + - name: top_p + use_template: top_p + min: 0 + max: 1 + default: 1 + - name: max_tokens + use_template: max_tokens + min: 1 + max: 2048 + default: 512 + - name: frequency_penalty + use_template: frequency_penalty + min: -2 + max: 2 + default: 0 + - name: presence_penalty + use_template: presence_penalty + min: -2 + max: 2 + default: 0 +pricing: + input: '0.0148' + output: '0.0148' + unit: '0.0001' + currency: USD diff --git a/api/core/model_runtime/model_providers/novita/llm/llama-3-70b-instruct.yaml b/api/core/model_runtime/model_providers/novita/llm/llama-3-70b-instruct.yaml index 5298296de3..39709e1063 100644 --- a/api/core/model_runtime/model_providers/novita/llm/llama-3-70b-instruct.yaml +++ b/api/core/model_runtime/model_providers/novita/llm/llama-3-70b-instruct.yaml @@ -34,3 +34,8 @@ parameter_rules: min: -2 max: 2 default: 0 +pricing: + input: '0.0051' + output: '0.0074' + unit: '0.0001' + currency: USD diff --git a/api/core/model_runtime/model_providers/novita/llm/llama-3-8b-instruct.yaml b/api/core/model_runtime/model_providers/novita/llm/llama-3-8b-instruct.yaml index 45e62ee52a..9b5e5df4d0 100644 --- a/api/core/model_runtime/model_providers/novita/llm/llama-3-8b-instruct.yaml +++ b/api/core/model_runtime/model_providers/novita/llm/llama-3-8b-instruct.yaml @@ -34,3 +34,8 @@ parameter_rules: min: -2 max: 2 default: 0 +pricing: + input: '0.00063' + output: '0.00063' + unit: '0.0001' + currency: USD diff --git a/api/core/model_runtime/model_providers/novita/llm/llama-3.1-405b-instruct.yaml b/api/core/model_runtime/model_providers/novita/llm/llama-3.1-405b-instruct.yaml new file mode 100644 index 0000000000..c5a45271ae --- /dev/null +++ b/api/core/model_runtime/model_providers/novita/llm/llama-3.1-405b-instruct.yaml @@ -0,0 +1,41 @@ +model: meta-llama/llama-3.1-405b-instruct +label: + zh_Hans: meta-llama/llama-3.1-405b-instruct + en_US: meta-llama/llama-3.1-405b-instruct +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 32768 +parameter_rules: + - name: temperature + use_template: temperature + min: 0 + max: 2 + default: 1 + - name: top_p + use_template: top_p + min: 0 + max: 1 + default: 1 + - name: max_tokens + use_template: max_tokens + min: 1 + max: 2048 + default: 512 + - name: frequency_penalty + use_template: frequency_penalty + min: -2 + max: 2 + default: 0 + - name: presence_penalty + use_template: presence_penalty + min: -2 + max: 2 + default: 0 +pricing: + input: '0.03' + output: '0.05' + unit: '0.0001' + currency: USD diff --git a/api/core/model_runtime/model_providers/novita/llm/llama-3.1-70b-instruct.yaml b/api/core/model_runtime/model_providers/novita/llm/llama-3.1-70b-instruct.yaml new file mode 100644 index 0000000000..3a5c29c40f --- /dev/null +++ b/api/core/model_runtime/model_providers/novita/llm/llama-3.1-70b-instruct.yaml @@ -0,0 +1,41 @@ +model: meta-llama/llama-3.1-70b-instruct +label: + zh_Hans: meta-llama/llama-3.1-70b-instruct + en_US: meta-llama/llama-3.1-70b-instruct +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 8192 +parameter_rules: + - name: temperature + use_template: temperature + min: 0 + max: 2 + default: 1 + - name: top_p + use_template: top_p + min: 0 + max: 1 + default: 1 + - name: max_tokens + use_template: max_tokens + min: 1 + max: 2048 + default: 512 + - name: frequency_penalty + use_template: frequency_penalty + min: -2 + max: 2 + default: 0 + - name: presence_penalty + use_template: presence_penalty + min: -2 + max: 2 + default: 0 +pricing: + input: '0.0055' + output: '0.0076' + unit: '0.0001' + currency: USD diff --git a/api/core/model_runtime/model_providers/novita/llm/llama-3.1-8b-instruct.yaml b/api/core/model_runtime/model_providers/novita/llm/llama-3.1-8b-instruct.yaml new file mode 100644 index 0000000000..e6ef772a3f --- /dev/null +++ b/api/core/model_runtime/model_providers/novita/llm/llama-3.1-8b-instruct.yaml @@ -0,0 +1,41 @@ +model: meta-llama/llama-3.1-8b-instruct +label: + zh_Hans: meta-llama/llama-3.1-8b-instruct + en_US: meta-llama/llama-3.1-8b-instruct +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 8192 +parameter_rules: + - name: temperature + use_template: temperature + min: 0 + max: 2 + default: 1 + - name: top_p + use_template: top_p + min: 0 + max: 1 + default: 1 + - name: max_tokens + use_template: max_tokens + min: 1 + max: 2048 + default: 512 + - name: frequency_penalty + use_template: frequency_penalty + min: -2 + max: 2 + default: 0 + - name: presence_penalty + use_template: presence_penalty + min: -2 + max: 2 + default: 0 +pricing: + input: '0.001' + output: '0.001' + unit: '0.0001' + currency: USD diff --git a/api/core/model_runtime/model_providers/novita/llm/lzlv_70b.yaml b/api/core/model_runtime/model_providers/novita/llm/lzlv_70b.yaml index 0facc0c112..0cc68a8c45 100644 --- a/api/core/model_runtime/model_providers/novita/llm/lzlv_70b.yaml +++ b/api/core/model_runtime/model_providers/novita/llm/lzlv_70b.yaml @@ -34,3 +34,8 @@ parameter_rules: min: -2 max: 2 default: 0 +pricing: + input: '0.0058' + output: '0.0078' + unit: '0.0001' + currency: USD diff --git a/api/core/model_runtime/model_providers/novita/llm/midnight-rose-70b.yaml b/api/core/model_runtime/model_providers/novita/llm/midnight-rose-70b.yaml new file mode 100644 index 0000000000..19876bee17 --- /dev/null +++ b/api/core/model_runtime/model_providers/novita/llm/midnight-rose-70b.yaml @@ -0,0 +1,41 @@ +model: sophosympatheia/midnight-rose-70b +label: + zh_Hans: sophosympatheia/midnight-rose-70b + en_US: sophosympatheia/midnight-rose-70b +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 4096 +parameter_rules: + - name: temperature + use_template: temperature + min: 0 + max: 2 + default: 1 + - name: top_p + use_template: top_p + min: 0 + max: 1 + default: 1 + - name: max_tokens + use_template: max_tokens + min: 1 + max: 2048 + default: 512 + - name: frequency_penalty + use_template: frequency_penalty + min: -2 + max: 2 + default: 0 + - name: presence_penalty + use_template: presence_penalty + min: -2 + max: 2 + default: 0 +pricing: + input: '0.008' + output: '0.008' + unit: '0.0001' + currency: USD diff --git a/api/core/model_runtime/model_providers/novita/llm/mistral-7b-instruct.yaml b/api/core/model_runtime/model_providers/novita/llm/mistral-7b-instruct.yaml new file mode 100644 index 0000000000..6fba47bcf0 --- /dev/null +++ b/api/core/model_runtime/model_providers/novita/llm/mistral-7b-instruct.yaml @@ -0,0 +1,41 @@ +model: mistralai/mistral-7b-instruct +label: + zh_Hans: mistralai/mistral-7b-instruct + en_US: mistralai/mistral-7b-instruct +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 32768 +parameter_rules: + - name: temperature + use_template: temperature + min: 0 + max: 2 + default: 1 + - name: top_p + use_template: top_p + min: 0 + max: 1 + default: 1 + - name: max_tokens + use_template: max_tokens + min: 1 + max: 2048 + default: 512 + - name: frequency_penalty + use_template: frequency_penalty + min: -2 + max: 2 + default: 0 + - name: presence_penalty + use_template: presence_penalty + min: -2 + max: 2 + default: 0 +pricing: + input: '0.00059' + output: '0.00059' + unit: '0.0001' + currency: USD diff --git a/api/core/model_runtime/model_providers/novita/llm/mythomax-l2-13b.yaml b/api/core/model_runtime/model_providers/novita/llm/mythomax-l2-13b.yaml index 28a8630ff2..7e4ac3ffe0 100644 --- a/api/core/model_runtime/model_providers/novita/llm/mythomax-l2-13b.yaml +++ b/api/core/model_runtime/model_providers/novita/llm/mythomax-l2-13b.yaml @@ -34,3 +34,8 @@ parameter_rules: min: -2 max: 2 default: 0 +pricing: + input: '0.00119' + output: '0.00119' + unit: '0.0001' + currency: USD diff --git a/api/core/model_runtime/model_providers/novita/llm/nous-hermes-llama2-13b.yaml b/api/core/model_runtime/model_providers/novita/llm/nous-hermes-llama2-13b.yaml index ce714a118b..75671c414c 100644 --- a/api/core/model_runtime/model_providers/novita/llm/nous-hermes-llama2-13b.yaml +++ b/api/core/model_runtime/model_providers/novita/llm/nous-hermes-llama2-13b.yaml @@ -34,3 +34,8 @@ parameter_rules: min: -2 max: 2 default: 0 +pricing: + input: '0.0017' + output: '0.0017' + unit: '0.0001' + currency: USD diff --git a/api/core/model_runtime/model_providers/novita/llm/openhermes-2.5-mistral-7b.yaml b/api/core/model_runtime/model_providers/novita/llm/openhermes-2.5-mistral-7b.yaml index 6cef39f847..8b0deba4f7 100644 --- a/api/core/model_runtime/model_providers/novita/llm/openhermes-2.5-mistral-7b.yaml +++ b/api/core/model_runtime/model_providers/novita/llm/openhermes-2.5-mistral-7b.yaml @@ -34,3 +34,8 @@ parameter_rules: min: -2 max: 2 default: 0 +pricing: + input: '0.0017' + output: '0.0017' + unit: '0.0001' + currency: USD diff --git a/api/core/model_runtime/model_providers/novita/llm/wizardlm-2-8x22b.yaml b/api/core/model_runtime/model_providers/novita/llm/wizardlm-2-8x22b.yaml index b3e3a03697..ef42568e8f 100644 --- a/api/core/model_runtime/model_providers/novita/llm/wizardlm-2-8x22b.yaml +++ b/api/core/model_runtime/model_providers/novita/llm/wizardlm-2-8x22b.yaml @@ -34,3 +34,8 @@ parameter_rules: min: -2 max: 2 default: 0 +pricing: + input: '0.0064' + output: '0.0064' + unit: '0.0001' + currency: USD diff --git a/api/core/model_runtime/model_providers/novita/novita.yaml b/api/core/model_runtime/model_providers/novita/novita.yaml index ef6a863569..f634197989 100644 --- a/api/core/model_runtime/model_providers/novita/novita.yaml +++ b/api/core/model_runtime/model_providers/novita/novita.yaml @@ -1,6 +1,9 @@ provider: novita label: en_US: novita.ai +description: + en_US: An LLM API that matches various application scenarios with high cost-effectiveness. + zh_Hans: 适配多种海外应用场景的高性价比 LLM API icon_small: en_US: icon_s_en.svg icon_large: @@ -11,7 +14,7 @@ help: en_US: Get your API key from novita.ai zh_Hans: 从 novita.ai 获取 API Key url: - en_US: https://novita.ai/dashboard/key?utm_source=dify + en_US: https://novita.ai/settings#key-management?utm_source=dify&utm_medium=ch&utm_campaign=api supported_model_types: - llm configurate_methods: diff --git a/api/core/model_runtime/model_providers/openai/tts/tts.py b/api/core/model_runtime/model_providers/openai/tts/tts.py index 608ed897e0..d3fcf731f1 100644 --- a/api/core/model_runtime/model_providers/openai/tts/tts.py +++ b/api/core/model_runtime/model_providers/openai/tts/tts.py @@ -114,7 +114,8 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel): # doc: https://platform.openai.com/docs/guides/text-to-speech credentials_kwargs = self._to_credential_kwargs(credentials) client = OpenAI(**credentials_kwargs) - if not voice or voice not in self.get_tts_model_voices(model=model, credentials=credentials): + model_support_voice = [x.get("value") for x in self.get_tts_model_voices(model=model, credentials=credentials)] + if not voice or voice not in model_support_voice: voice = self._get_model_default_voice(model, credentials) word_limit = self._get_model_word_limit(model, credentials) if len(content_text) > word_limit: diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/deepdeek-coder-v2-instruct.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/deepdeek-coder-v2-instruct.yaml new file mode 100644 index 0000000000..d4431179e5 --- /dev/null +++ b/api/core/model_runtime/model_providers/siliconflow/llm/deepdeek-coder-v2-instruct.yaml @@ -0,0 +1,30 @@ +model: deepseek-ai/DeepSeek-Coder-V2-Instruct +label: + en_US: deepseek-ai/DeepSeek-Coder-V2-Instruct +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 32768 +parameter_rules: + - name: temperature + use_template: temperature + - name: max_tokens + use_template: max_tokens + type: int + default: 512 + min: 1 + max: 4096 + help: + zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。 + en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. + - name: top_p + use_template: top_p + - name: frequency_penalty + use_template: frequency_penalty +pricing: + input: '1.33' + output: '1.33' + unit: '0.000001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/deepseek-v2-chat.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/deepseek-v2-chat.yaml index da58e822f9..3926568db6 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/deepseek-v2-chat.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/deepseek-v2-chat.yaml @@ -1,11 +1,9 @@ model: deepseek-ai/deepseek-v2-chat label: - en_US: deepseek-ai/deepseek-v2-chat + en_US: deepseek-ai/DeepSeek-V2-Chat model_type: llm features: - - multi-tool-call - agent-thought - - stream-tool-call model_properties: mode: chat context_size: 32768 diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/glm4-9b-chat.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/glm4-9b-chat.yaml index 115fc50b94..d6a4b21b66 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/glm4-9b-chat.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/glm4-9b-chat.yaml @@ -1,11 +1,9 @@ model: zhipuai/glm4-9B-chat label: - en_US: zhipuai/glm4-9B-chat + en_US: THUDM/glm-4-9b-chat model_type: llm features: - - multi-tool-call - agent-thought - - stream-tool-call model_properties: mode: chat context_size: 32768 diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-57b-a14b-instruct.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-57b-a14b-instruct.yaml index 75eca7720c..39624dc5b9 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-57b-a14b-instruct.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-57b-a14b-instruct.yaml @@ -1,11 +1,9 @@ model: alibaba/Qwen2-57B-A14B-Instruct label: - en_US: alibaba/Qwen2-57B-A14B-Instruct + en_US: Qwen/Qwen2-57B-A14B-Instruct model_type: llm features: - - multi-tool-call - agent-thought - - stream-tool-call model_properties: mode: chat context_size: 32768 diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-72b-instruct.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-72b-instruct.yaml index fcbc9e0b68..fb7ff6cb14 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-72b-instruct.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-72b-instruct.yaml @@ -1,11 +1,9 @@ model: alibaba/Qwen2-72B-Instruct label: - en_US: alibaba/Qwen2-72B-Instruct + en_US: Qwen/Qwen2-72B-Instruct model_type: llm features: - - multi-tool-call - agent-thought - - stream-tool-call model_properties: mode: chat context_size: 32768 diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-7b-instruct.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-7b-instruct.yaml index eda1d40642..efda4abbd9 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-7b-instruct.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-7b-instruct.yaml @@ -1,11 +1,9 @@ model: alibaba/Qwen2-7B-Instruct label: - en_US: alibaba/Qwen2-7B-Instruct + en_US: Qwen/Qwen2-7B-Instruct model_type: llm features: - - multi-tool-call - agent-thought - - stream-tool-call model_properties: mode: chat context_size: 32768 diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/yi-1.5-34b-chat.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/yi-1.5-34b-chat.yaml index 6656e663e9..864ba46f1a 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/yi-1.5-34b-chat.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/yi-1.5-34b-chat.yaml @@ -1,11 +1,9 @@ model: 01-ai/Yi-1.5-34B-Chat label: - en_US: 01-ai/Yi-1.5-34B-Chat + en_US: 01-ai/Yi-1.5-34B-Chat-16K model_type: llm features: - - multi-tool-call - agent-thought - - stream-tool-call model_properties: mode: chat context_size: 16384 diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/yi-1.5-6b-chat.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/yi-1.5-6b-chat.yaml index ba6e0c5113..38cd4197d4 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/yi-1.5-6b-chat.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/yi-1.5-6b-chat.yaml @@ -3,9 +3,7 @@ label: en_US: 01-ai/Yi-1.5-6B-Chat model_type: llm features: - - multi-tool-call - agent-thought - - stream-tool-call model_properties: mode: chat context_size: 4096 diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/yi-1.5-9b-chat.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/yi-1.5-9b-chat.yaml index 64be8998c5..042eeea81a 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/yi-1.5-9b-chat.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/yi-1.5-9b-chat.yaml @@ -1,11 +1,9 @@ model: 01-ai/Yi-1.5-9B-Chat label: - en_US: 01-ai/Yi-1.5-9B-Chat + en_US: 01-ai/Yi-1.5-9B-Chat-16K model_type: llm features: - - multi-tool-call - agent-thought - - stream-tool-call model_properties: mode: chat context_size: 16384 diff --git a/api/core/model_runtime/model_providers/siliconflow/siliconflow.py b/api/core/model_runtime/model_providers/siliconflow/siliconflow.py index 63f76fa8b5..a53f16c929 100644 --- a/api/core/model_runtime/model_providers/siliconflow/siliconflow.py +++ b/api/core/model_runtime/model_providers/siliconflow/siliconflow.py @@ -19,7 +19,7 @@ class SiliconflowProvider(ModelProvider): model_instance = self.get_model_instance(ModelType.LLM) model_instance.validate_credentials( - model='deepseek-ai/deepseek-v2-chat', + model='deepseek-ai/DeepSeek-V2-Chat', credentials=credentials ) except CredentialsValidateFailedError as ex: diff --git a/api/core/model_runtime/model_providers/tongyi/llm/llm.py b/api/core/model_runtime/model_providers/tongyi/llm/llm.py index 1f018c4078..6f768131fb 100644 --- a/api/core/model_runtime/model_providers/tongyi/llm/llm.py +++ b/api/core/model_runtime/model_providers/tongyi/llm/llm.py @@ -501,8 +501,7 @@ You should also complete the text started with ``` but not tell ``` directly. 'role': 'assistant', 'content': content if not rich_content else [{"text": content}], 'tool_calls': [tool_call.model_dump() for tool_call in - prompt_message.tool_calls] if prompt_message.tool_calls else [] - + prompt_message.tool_calls] if prompt_message.tool_calls else None }) elif isinstance(prompt_message, ToolPromptMessage): tongyi_messages.append({ diff --git a/api/core/moderation/output_moderation.py b/api/core/moderation/output_moderation.py index 3587466952..9a4d8db4e2 100644 --- a/api/core/moderation/output_moderation.py +++ b/api/core/moderation/output_moderation.py @@ -6,6 +6,7 @@ from typing import Any, Optional from flask import Flask, current_app from pydantic import BaseModel, ConfigDict +from configs import dify_config from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.entities.queue_entities import QueueMessageReplaceEvent from core.moderation.base import ModerationAction, ModerationOutputsResult @@ -20,8 +21,6 @@ class ModerationRule(BaseModel): class OutputModeration(BaseModel): - DEFAULT_BUFFER_SIZE: int = 300 - tenant_id: str app_id: str @@ -76,10 +75,10 @@ class OutputModeration(BaseModel): return final_output def start_thread(self) -> threading.Thread: - buffer_size = int(current_app.config.get('MODERATION_BUFFER_SIZE', self.DEFAULT_BUFFER_SIZE)) + buffer_size = dify_config.MODERATION_BUFFER_SIZE thread = threading.Thread(target=self.worker, kwargs={ 'flask_app': current_app._get_current_object(), - 'buffer_size': buffer_size if buffer_size > 0 else self.DEFAULT_BUFFER_SIZE + 'buffer_size': buffer_size if buffer_size > 0 else dify_config.MODERATION_BUFFER_SIZE }) thread.start() diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index ff15aa999b..4f6ab2fb94 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -298,34 +298,29 @@ class TraceTask: self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001") def execute(self): - method_name, trace_info = self.preprocess() - return trace_info + return self.preprocess() def preprocess(self): - if self.trace_type == TraceTaskName.CONVERSATION_TRACE: - return TraceTaskName.CONVERSATION_TRACE, self.conversation_trace(**self.kwargs) - if self.trace_type == TraceTaskName.WORKFLOW_TRACE: - return TraceTaskName.WORKFLOW_TRACE, self.workflow_trace(self.workflow_run, self.conversation_id) - elif self.trace_type == TraceTaskName.MESSAGE_TRACE: - return TraceTaskName.MESSAGE_TRACE, self.message_trace(self.message_id) - elif self.trace_type == TraceTaskName.MODERATION_TRACE: - return TraceTaskName.MODERATION_TRACE, self.moderation_trace(self.message_id, self.timer, **self.kwargs) - elif self.trace_type == TraceTaskName.SUGGESTED_QUESTION_TRACE: - return TraceTaskName.SUGGESTED_QUESTION_TRACE, self.suggested_question_trace( + preprocess_map = { + TraceTaskName.CONVERSATION_TRACE: lambda: self.conversation_trace(**self.kwargs), + TraceTaskName.WORKFLOW_TRACE: lambda: self.workflow_trace(self.workflow_run, self.conversation_id), + TraceTaskName.MESSAGE_TRACE: lambda: self.message_trace(self.message_id), + TraceTaskName.MODERATION_TRACE: lambda: self.moderation_trace( self.message_id, self.timer, **self.kwargs - ) - elif self.trace_type == TraceTaskName.DATASET_RETRIEVAL_TRACE: - return TraceTaskName.DATASET_RETRIEVAL_TRACE, self.dataset_retrieval_trace( + ), + TraceTaskName.SUGGESTED_QUESTION_TRACE: lambda: self.suggested_question_trace( self.message_id, self.timer, **self.kwargs - ) - elif self.trace_type == TraceTaskName.TOOL_TRACE: - return TraceTaskName.TOOL_TRACE, self.tool_trace(self.message_id, self.timer, **self.kwargs) - elif self.trace_type == TraceTaskName.GENERATE_NAME_TRACE: - return TraceTaskName.GENERATE_NAME_TRACE, self.generate_name_trace( + ), + TraceTaskName.DATASET_RETRIEVAL_TRACE: lambda: self.dataset_retrieval_trace( + self.message_id, self.timer, **self.kwargs + ), + TraceTaskName.TOOL_TRACE: lambda: self.tool_trace(self.message_id, self.timer, **self.kwargs), + TraceTaskName.GENERATE_NAME_TRACE: lambda: self.generate_name_trace( self.conversation_id, self.timer, **self.kwargs - ) - else: - return '', {} + ), + } + + return preprocess_map.get(self.trace_type, lambda: None)() # process methods for different trace types def conversation_trace(self, **kwargs): diff --git a/api/core/rag/datasource/keyword/jieba/jieba.py b/api/core/rag/datasource/keyword/jieba/jieba.py index 7f7c46e2dd..a3714c2fd3 100644 --- a/api/core/rag/datasource/keyword/jieba/jieba.py +++ b/api/core/rag/datasource/keyword/jieba/jieba.py @@ -2,9 +2,9 @@ import json from collections import defaultdict from typing import Any, Optional -from flask import current_app from pydantic import BaseModel +from configs import dify_config from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler from core.rag.datasource.keyword.keyword_base import BaseKeyword from core.rag.models.document import Document @@ -139,7 +139,7 @@ class Jieba(BaseKeyword): if keyword_table_dict: return keyword_table_dict['__data__']['table'] else: - keyword_data_source_type = current_app.config['KEYWORD_DATA_SOURCE_TYPE'] + keyword_data_source_type = dify_config.KEYWORD_DATA_SOURCE_TYPE dataset_keyword_table = DatasetKeywordTable( dataset_id=self.dataset.id, keyword_table='', diff --git a/api/core/rag/datasource/vdb/milvus/milvus_vector.py b/api/core/rag/datasource/vdb/milvus/milvus_vector.py index b679577c04..cfc533ed33 100644 --- a/api/core/rag/datasource/vdb/milvus/milvus_vector.py +++ b/api/core/rag/datasource/vdb/milvus/milvus_vector.py @@ -5,6 +5,7 @@ from uuid import uuid4 from pydantic import BaseModel, model_validator from pymilvus import MilvusClient, MilvusException, connections +from pymilvus.milvus_client import IndexParams from configs import dify_config from core.rag.datasource.entity.embedding import Embeddings @@ -250,11 +251,15 @@ class MilvusVector(BaseVector): # Since primary field is auto-id, no need to track it self._fields.remove(Field.PRIMARY_KEY.value) + # Create Index params for the collection + index_params_obj = IndexParams() + index_params_obj.add_index(field_name=Field.VECTOR.value, **index_params) + # Create the collection collection_name = self._collection_name - self._client.create_collection_with_schema(collection_name=collection_name, - schema=schema, index_param=index_params, - consistency_level=self._consistency_level) + self._client.create_collection(collection_name=collection_name, + schema=schema, index_params=index_params_obj, + consistency_level=self._consistency_level) redis_client.set(collection_exist_cache_key, 1, ex=3600) def _init_client(self, config) -> MilvusClient: diff --git a/api/core/rag/datasource/vdb/oracle/oraclevector.py b/api/core/rag/datasource/vdb/oracle/oraclevector.py index f65f57da60..4bd09b331d 100644 --- a/api/core/rag/datasource/vdb/oracle/oraclevector.py +++ b/api/core/rag/datasource/vdb/oracle/oraclevector.py @@ -55,7 +55,7 @@ CREATE TABLE IF NOT EXISTS {table_name} ( ) """ SQL_CREATE_INDEX = """ -CREATE INDEX idx_docs_{table_name} ON {table_name}(text) +CREATE INDEX IF NOT EXISTS idx_docs_{table_name} ON {table_name}(text) INDEXTYPE IS CTXSYS.CONTEXT PARAMETERS ('FILTER CTXSYS.NULL_FILTER SECTION GROUP CTXSYS.HTML_SECTION_GROUP LEXER sys.my_chinese_vgram_lexer') """ @@ -248,7 +248,7 @@ class OracleVector(BaseVector): def delete(self) -> None: with self._get_cursor() as cur: - cur.execute(f"DROP TABLE IF EXISTS {self.table_name}") + cur.execute(f"DROP TABLE IF EXISTS {self.table_name} cascade constraints") def _create_collection(self, dimension: int): cache_key = f"vector_indexing_{self._collection_name}" diff --git a/api/core/rag/extractor/excel_extractor.py b/api/core/rag/extractor/excel_extractor.py index 2b16275dc8..f0c302a619 100644 --- a/api/core/rag/extractor/excel_extractor.py +++ b/api/core/rag/extractor/excel_extractor.py @@ -3,6 +3,7 @@ import os from typing import Optional import pandas as pd +from openpyxl import load_workbook from core.rag.extractor.extractor_base import BaseExtractor from core.rag.models.document import Document @@ -28,26 +29,48 @@ class ExcelExtractor(BaseExtractor): self._autodetect_encoding = autodetect_encoding def extract(self) -> list[Document]: - """ Load from Excel file in xls or xlsx format using Pandas.""" + """ Load from Excel file in xls or xlsx format using Pandas and openpyxl.""" documents = [] - # Determine the file extension file_extension = os.path.splitext(self._file_path)[-1].lower() - # Read each worksheet of an Excel file using Pandas + if file_extension == '.xlsx': - excel_file = pd.ExcelFile(self._file_path, engine='openpyxl') + wb = load_workbook(self._file_path, data_only=True) + for sheet_name in wb.sheetnames: + sheet = wb[sheet_name] + data = sheet.values + cols = next(data) + df = pd.DataFrame(data, columns=cols) + + df.dropna(how='all', inplace=True) + + for index, row in df.iterrows(): + page_content = [] + for col_index, (k, v) in enumerate(row.items()): + if pd.notna(v): + cell = sheet.cell(row=index + 2, + column=col_index + 1) # +2 to account for header and 1-based index + if cell.hyperlink: + value = f"[{v}]({cell.hyperlink.target})" + page_content.append(f'"{k}":"{value}"') + else: + page_content.append(f'"{k}":"{v}"') + documents.append(Document(page_content=';'.join(page_content), + metadata={'source': self._file_path})) + elif file_extension == '.xls': excel_file = pd.ExcelFile(self._file_path, engine='xlrd') + for sheet_name in excel_file.sheet_names: + df = excel_file.parse(sheet_name=sheet_name) + df.dropna(how='all', inplace=True) + + for _, row in df.iterrows(): + page_content = [] + for k, v in row.items(): + if pd.notna(v): + page_content.append(f'"{k}":"{v}"') + documents.append(Document(page_content=';'.join(page_content), + metadata={'source': self._file_path})) else: raise ValueError(f"Unsupported file extension: {file_extension}") - for sheet_name in excel_file.sheet_names: - df: pd.DataFrame = excel_file.parse(sheet_name=sheet_name) - - # filter out rows with all NaN values - df.dropna(how='all', inplace=True) - - # transform each row into a Document - documents += [Document(page_content=';'.join(f'"{k}":"{v}"' for k, v in row.items() if pd.notna(v)), - metadata={'source': self._file_path}, - ) for _, row in df.iterrows()] return documents diff --git a/api/core/rag/extractor/extract_processor.py b/api/core/rag/extractor/extract_processor.py index d01cf48fac..f7a08135f5 100644 --- a/api/core/rag/extractor/extract_processor.py +++ b/api/core/rag/extractor/extract_processor.py @@ -4,9 +4,8 @@ from pathlib import Path from typing import Union from urllib.parse import unquote -import requests - from configs import dify_config +from core.helper import ssrf_proxy from core.rag.extractor.csv_extractor import CSVExtractor from core.rag.extractor.entity.datasource_type import DatasourceType from core.rag.extractor.entity.extract_setting import ExtractSetting @@ -51,7 +50,7 @@ class ExtractProcessor: @classmethod def load_from_url(cls, url: str, return_text: bool = False) -> Union[list[Document], str]: - response = requests.get(url, headers={ + response = ssrf_proxy.get(url, headers={ "User-Agent": USER_AGENT }) diff --git a/api/core/rag/extractor/markdown_extractor.py b/api/core/rag/extractor/markdown_extractor.py index faa1e64057..b24cf2e170 100644 --- a/api/core/rag/extractor/markdown_extractor.py +++ b/api/core/rag/extractor/markdown_extractor.py @@ -54,8 +54,16 @@ class MarkdownExtractor(BaseExtractor): current_header = None current_text = "" + code_block_flag = False for line in lines: + if line.startswith("```"): + code_block_flag = not code_block_flag + current_text += line + "\n" + continue + if code_block_flag: + current_text += line + "\n" + continue header_match = re.match(r"^#+\s", line) if header_match: if current_header is not None: diff --git a/api/core/tools/provider/builtin/cogview/tools/cogview3.py b/api/core/tools/provider/builtin/cogview/tools/cogview3.py index bb2720196f..89ffcf3347 100644 --- a/api/core/tools/provider/builtin/cogview/tools/cogview3.py +++ b/api/core/tools/provider/builtin/cogview/tools/cogview3.py @@ -30,7 +30,6 @@ class CogView3Tool(BuiltinTool): if not prompt: return self.create_text_message('Please input prompt') # get size - print(tool_parameters.get('prompt', 'square')) size = size_mapping[tool_parameters.get('size', 'square')] # get n n = tool_parameters.get('n', 1) @@ -58,8 +57,9 @@ class CogView3Tool(BuiltinTool): result = [] for image in response.data: result.append(self.create_image_message(image=image.url)) - result.append(self.create_text_message( - f'\nGenerate image source to Seed ID: {seed_id}')) + result.append(self.create_json_message({ + "url": image.url, + })) return result @staticmethod diff --git a/api/core/tools/provider/builtin/firecrawl/firecrawl.py b/api/core/tools/provider/builtin/firecrawl/firecrawl.py index adcb7ebdd6..24dc35759d 100644 --- a/api/core/tools/provider/builtin/firecrawl/firecrawl.py +++ b/api/core/tools/provider/builtin/firecrawl/firecrawl.py @@ -1,22 +1,19 @@ from core.tools.errors import ToolProviderCredentialValidationError -from core.tools.provider.builtin.firecrawl.tools.crawl import CrawlTool +from core.tools.provider.builtin.firecrawl.tools.scrape import ScrapeTool from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController class FirecrawlProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict) -> None: try: - # Example validation using the Crawl tool - CrawlTool().fork_tool_runtime( + # Example validation using the ScrapeTool, only scraping title for minimize content + ScrapeTool().fork_tool_runtime( runtime={"credentials": credentials} ).invoke( user_id='', tool_parameters={ - "url": "https://example.com", - "includes": '', - "excludes": '', - "limit": 1, - "onlyMainContent": True, + "url": "https://google.com", + "onlyIncludeTags": 'title' } ) except Exception as e: diff --git a/api/core/tools/provider/builtin/firecrawl/firecrawl.yaml b/api/core/tools/provider/builtin/firecrawl/firecrawl.yaml index 613a0e4679..a48b9d9f54 100644 --- a/api/core/tools/provider/builtin/firecrawl/firecrawl.yaml +++ b/api/core/tools/provider/builtin/firecrawl/firecrawl.yaml @@ -31,8 +31,5 @@ credentials_for_provider: label: en_US: Firecrawl server's Base URL zh_Hans: Firecrawl服务器的API URL - pt_BR: Firecrawl server's Base URL placeholder: - en_US: https://www.firecrawl.dev - zh_HansL: https://www.firecrawl.dev - pt_BR: https://www.firecrawl.dev + en_US: https://api.firecrawl.dev diff --git a/api/core/tools/provider/builtin/firecrawl/firecrawl_appx.py b/api/core/tools/provider/builtin/firecrawl/firecrawl_appx.py index bfe3e7999d..3b3f78731b 100644 --- a/api/core/tools/provider/builtin/firecrawl/firecrawl_appx.py +++ b/api/core/tools/provider/builtin/firecrawl/firecrawl_appx.py @@ -1,3 +1,4 @@ +import json import logging import time from collections.abc import Mapping @@ -8,6 +9,7 @@ from requests.exceptions import HTTPError logger = logging.getLogger(__name__) + class FirecrawlApp: def __init__(self, api_key: str | None = None, base_url: str | None = None): self.api_key = api_key @@ -25,14 +27,16 @@ class FirecrawlApp: return headers def _request( - self, - method: str, - url: str, - data: Mapping[str, Any] | None = None, - headers: Mapping[str, str] | None = None, - retries: int = 3, - backoff_factor: float = 0.3, + self, + method: str, + url: str, + data: Mapping[str, Any] | None = None, + headers: Mapping[str, str] | None = None, + retries: int = 3, + backoff_factor: float = 0.3, ) -> Mapping[str, Any] | None: + if not headers: + headers = self._prepare_headers() for i in range(retries): try: response = requests.request(method, url, json=data, headers=headers) @@ -47,47 +51,51 @@ class FirecrawlApp: def scrape_url(self, url: str, **kwargs): endpoint = f'{self.base_url}/v0/scrape' - headers = self._prepare_headers() data = {'url': url, **kwargs} - response = self._request('POST', endpoint, data, headers) logger.debug(f"Sent request to {endpoint=} body={data}") + response = self._request('POST', endpoint, data) if response is None: raise HTTPError("Failed to scrape URL after multiple retries") return response def search(self, query: str, **kwargs): endpoint = f'{self.base_url}/v0/search' - headers = self._prepare_headers() data = {'query': query, **kwargs} - response = self._request('POST', endpoint, data, headers) logger.debug(f"Sent request to {endpoint=} body={data}") + response = self._request('POST', endpoint, data) if response is None: raise HTTPError("Failed to perform search after multiple retries") return response def crawl_url( - self, url: str, wait: bool = False, poll_interval: int = 5, idempotency_key: str | None = None, **kwargs + self, url: str, wait: bool = True, poll_interval: int = 5, idempotency_key: str | None = None, **kwargs ): endpoint = f'{self.base_url}/v0/crawl' headers = self._prepare_headers(idempotency_key) - data = {'url': url, **kwargs['params']} - response = self._request('POST', endpoint, data, headers) + data = {'url': url, **kwargs} logger.debug(f"Sent request to {endpoint=} body={data}") + response = self._request('POST', endpoint, data, headers) if response is None: raise HTTPError("Failed to initiate crawl after multiple retries") job_id: str = response['jobId'] if wait: return self._monitor_job_status(job_id=job_id, poll_interval=poll_interval) - return job_id + return response def check_crawl_status(self, job_id: str): endpoint = f'{self.base_url}/v0/crawl/status/{job_id}' - headers = self._prepare_headers() - response = self._request('GET', endpoint, headers=headers) + response = self._request('GET', endpoint) if response is None: raise HTTPError(f"Failed to check status for job {job_id} after multiple retries") return response + def cancel_crawl_job(self, job_id: str): + endpoint = f'{self.base_url}/v0/crawl/cancel/{job_id}' + response = self._request('DELETE', endpoint) + if response is None: + raise HTTPError(f"Failed to cancel job {job_id} after multiple retries") + return response + def _monitor_job_status(self, job_id: str, poll_interval: int): while True: status = self.check_crawl_status(job_id) @@ -96,3 +104,21 @@ class FirecrawlApp: elif status['status'] == 'failed': raise HTTPError(f'Job {job_id} failed: {status["error"]}') time.sleep(poll_interval) + + +def get_array_params(tool_parameters: dict[str, Any], key): + param = tool_parameters.get(key) + if param: + return param.split(',') + + +def get_json_params(tool_parameters: dict[str, Any], key): + param = tool_parameters.get(key) + if param: + try: + # support both single quotes and double quotes + param = param.replace("'", '"') + param = json.loads(param) + except: + raise ValueError(f"Invalid {key} format.") + return param diff --git a/api/core/tools/provider/builtin/firecrawl/tools/crawl.py b/api/core/tools/provider/builtin/firecrawl/tools/crawl.py index b000c1c6ce..08c40a4064 100644 --- a/api/core/tools/provider/builtin/firecrawl/tools/crawl.py +++ b/api/core/tools/provider/builtin/firecrawl/tools/crawl.py @@ -1,36 +1,48 @@ -import json -from typing import Any, Union +from typing import Any from core.tools.entities.tool_entities import ToolInvokeMessage -from core.tools.provider.builtin.firecrawl.firecrawl_appx import FirecrawlApp +from core.tools.provider.builtin.firecrawl.firecrawl_appx import FirecrawlApp, get_array_params, get_json_params from core.tools.tool.builtin_tool import BuiltinTool class CrawlTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - app = FirecrawlApp(api_key=self.runtime.credentials['firecrawl_api_key'], base_url=self.runtime.credentials['base_url']) + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + """ + the crawlerOptions and pageOptions comes from doc here: + https://docs.firecrawl.dev/api-reference/endpoint/crawl + """ + app = FirecrawlApp(api_key=self.runtime.credentials['firecrawl_api_key'], + base_url=self.runtime.credentials['base_url']) + crawlerOptions = {} + pageOptions = {} - options = { - 'crawlerOptions': { - 'excludes': tool_parameters.get('excludes', '').split(',') if tool_parameters.get('excludes') else [], - 'includes': tool_parameters.get('includes', '').split(',') if tool_parameters.get('includes') else [], - 'limit': tool_parameters.get('limit', 5) - }, - 'pageOptions': { - 'onlyMainContent': tool_parameters.get('onlyMainContent', False) - } - } + wait_for_results = tool_parameters.get('wait_for_results', True) + + crawlerOptions['excludes'] = get_array_params(tool_parameters, 'excludes') + crawlerOptions['includes'] = get_array_params(tool_parameters, 'includes') + crawlerOptions['returnOnlyUrls'] = tool_parameters.get('returnOnlyUrls', False) + crawlerOptions['maxDepth'] = tool_parameters.get('maxDepth') + crawlerOptions['mode'] = tool_parameters.get('mode') + crawlerOptions['ignoreSitemap'] = tool_parameters.get('ignoreSitemap', False) + crawlerOptions['limit'] = tool_parameters.get('limit', 5) + crawlerOptions['allowBackwardCrawling'] = tool_parameters.get('allowBackwardCrawling', False) + crawlerOptions['allowExternalContentLinks'] = tool_parameters.get('allowExternalContentLinks', False) + + pageOptions['headers'] = get_json_params(tool_parameters, 'headers') + pageOptions['includeHtml'] = tool_parameters.get('includeHtml', False) + pageOptions['includeRawHtml'] = tool_parameters.get('includeRawHtml', False) + pageOptions['onlyIncludeTags'] = get_array_params(tool_parameters, 'onlyIncludeTags') + pageOptions['removeTags'] = get_array_params(tool_parameters, 'removeTags') + pageOptions['onlyMainContent'] = tool_parameters.get('onlyMainContent', False) + pageOptions['replaceAllPathsWithAbsolutePaths'] = tool_parameters.get('replaceAllPathsWithAbsolutePaths', False) + pageOptions['screenshot'] = tool_parameters.get('screenshot', False) + pageOptions['waitFor'] = tool_parameters.get('waitFor', 0) crawl_result = app.crawl_url( - url=tool_parameters['url'], - params=options, - wait=True + url=tool_parameters['url'], + wait=wait_for_results, + crawlerOptions=crawlerOptions, + pageOptions=pageOptions ) - if not isinstance(crawl_result, str): - crawl_result = json.dumps(crawl_result, ensure_ascii=False, indent=4) - - if not crawl_result: - return self.create_text_message("Crawl request failed.") - - return self.create_text_message(crawl_result) + return self.create_json_message(crawl_result) diff --git a/api/core/tools/provider/builtin/firecrawl/tools/crawl.yaml b/api/core/tools/provider/builtin/firecrawl/tools/crawl.yaml index 3861670140..0c5399f973 100644 --- a/api/core/tools/provider/builtin/firecrawl/tools/crawl.yaml +++ b/api/core/tools/provider/builtin/firecrawl/tools/crawl.yaml @@ -3,76 +3,243 @@ identity: author: Richards Tu label: en_US: Crawl - zh_Hans: 爬取 + zh_Hans: 深度爬取 description: human: - en_US: Extract data from a website by crawling through a URL. - zh_Hans: 通过URL从网站中提取数据。 + en_US: Recursively search through a urls subdomains, and gather the content. + zh_Hans: 递归爬取一个网址的子域名,并收集内容。 llm: This tool initiates a web crawl to extract data from a specified URL. It allows configuring crawler options such as including or excluding URL patterns, generating alt text for images using LLMs (paid plan required), limiting the maximum number of pages to crawl, and returning only the main content of the page. The tool can return either a list of crawled documents or a list of URLs based on the provided options. parameters: - name: url type: string required: true label: - en_US: URL to crawl - zh_Hans: 要爬取的URL + en_US: Start URL + zh_Hans: 起始URL human_description: - en_US: The URL of the website to crawl and extract data from. - zh_Hans: 要爬取并提取数据的网站URL。 + en_US: The base URL to start crawling from. + zh_Hans: 要爬取网站的起始URL。 llm_description: The URL of the website that needs to be crawled. This is a required parameter. form: llm + - name: wait_for_results + type: boolean + default: true + label: + en_US: Wait For Results + zh_Hans: 等待爬取结果 + human_description: + en_US: If you choose not to wait, it will directly return a job ID. You can use this job ID to check the crawling results or cancel the crawling task, which is usually very useful for a large-scale crawling task. + zh_Hans: 如果选择不等待,则会直接返回一个job_id,可以通过job_id查询爬取结果或取消爬取任务,这通常对于一个大型爬取任务来说非常有用。 + form: form +############## Crawl Options ####################### - name: includes type: string required: false label: en_US: URL patterns to include zh_Hans: 要包含的URL模式 + placeholder: + en_US: Use commas to separate multiple tags + zh_Hans: 多个标签时使用半角逗号分隔 human_description: - en_US: Specify URL patterns to include during the crawl. Only pages matching these patterns will be crawled, you can use ',' to separate multiple patterns. - zh_Hans: 指定爬取过程中要包含的URL模式。只有与这些模式匹配的页面才会被爬取。 + en_US: | + Only pages matching these patterns will be crawled. Example: blog/*, about/* + zh_Hans: 只有与这些模式匹配的页面才会被爬取。示例:blog/*, about/* form: form - default: '' - name: excludes type: string - required: false label: en_US: URL patterns to exclude zh_Hans: 要排除的URL模式 + placeholder: + en_US: Use commas to separate multiple tags + zh_Hans: 多个标签时使用半角逗号分隔 human_description: - en_US: Specify URL patterns to exclude during the crawl. Pages matching these patterns will be skipped, you can use ',' to separate multiple patterns. - zh_Hans: 指定爬取过程中要排除的URL模式。匹配这些模式的页面将被跳过。 + en_US: | + Pages matching these patterns will be skipped. Example: blog/*, about/* + zh_Hans: 匹配这些模式的页面将被跳过。示例:blog/*, about/* + form: form + - name: returnOnlyUrls + type: boolean + default: false + label: + en_US: return Only Urls + zh_Hans: 仅返回URL + human_description: + en_US: | + If true, returns only the URLs as a list on the crawl status. Attention: the return response will be a list of URLs inside the data, not a list of documents. + zh_Hans: 只返回爬取到的网页链接,而不是网页内容本身。 + form: form + - name: maxDepth + type: number + label: + en_US: Maximum crawl depth + zh_Hans: 爬取深度 + human_description: + en_US: Maximum depth to crawl relative to the entered URL. A maxDepth of 0 scrapes only the entered URL. A maxDepth of 1 scrapes the entered URL and all pages one level deep. A maxDepth of 2 scrapes the entered URL and all pages up to two levels deep. Higher values follow the same pattern. + zh_Hans: 相对于输入的URL,爬取的最大深度。maxDepth为0时,仅抓取输入的URL。maxDepth为1时,抓取输入的URL以及所有一级深层页面。maxDepth为2时,抓取输入的URL以及所有两级深层页面。更高值遵循相同模式。 + form: form + min: 0 + - name: mode + type: select + required: false + form: form + options: + - value: default + label: + en_US: default + - value: fast + label: + en_US: fast + default: default + label: + en_US: Crawl Mode + zh_Hans: 爬取模式 + human_description: + en_US: The crawling mode to use. Fast mode crawls 4x faster websites without sitemap, but may not be as accurate and shouldn't be used in heavy js-rendered websites. + zh_Hans: 使用fast模式将不会使用其站点地图,比普通模式快4倍,但是可能不够准确,也不适用于大量js渲染的网站。 + - name: ignoreSitemap + type: boolean + default: false + label: + en_US: ignore Sitemap + zh_Hans: 忽略站点地图 + human_description: + en_US: Ignore the website sitemap when crawling. + zh_Hans: 爬取时忽略网站站点地图。 form: form - default: 'blog/*' - name: limit type: number required: false label: - en_US: Maximum number of pages to crawl + en_US: Maximum pages to crawl zh_Hans: 最大爬取页面数 human_description: en_US: Specify the maximum number of pages to crawl. The crawler will stop after reaching this limit. zh_Hans: 指定要爬取的最大页面数。爬虫将在达到此限制后停止。 form: form min: 1 - max: 20 default: 5 + - name: allowBackwardCrawling + type: boolean + default: false + label: + en_US: allow Backward Crawling + zh_Hans: 允许向后爬取 + human_description: + en_US: Enables the crawler to navigate from a specific URL to previously linked pages. For instance, from 'example.com/product/123' back to 'example.com/product' + zh_Hans: 使爬虫能够从特定URL导航到之前链接的页面。例如,从'example.com/product/123'返回到'example.com/product' + form: form + - name: allowExternalContentLinks + type: boolean + default: false + label: + en_US: allow External Content Links + zh_Hans: 允许爬取外链 + human_description: + en_US: Allows the crawler to follow links to external websites. + zh_Hans: + form: form +############## Page Options ####################### + - name: headers + type: string + label: + en_US: headers + zh_Hans: 请求头 + human_description: + en_US: | + Headers to send with the request. Can be used to send cookies, user-agent, etc. Example: {"cookies": "testcookies"} + zh_Hans: | + 随请求发送的头部。可以用来发送cookies、用户代理等。示例:{"cookies": "testcookies"} + placeholder: + en_US: Please enter an object that can be serialized in JSON + zh_Hans: 请输入可以json序列化的对象 + form: form + - name: includeHtml + type: boolean + default: false + label: + en_US: include Html + zh_Hans: 包含HTML + human_description: + en_US: Include the HTML version of the content on page. Will output a html key in the response. + zh_Hans: 返回中包含一个HTML版本的内容,将以html键返回。 + form: form + - name: includeRawHtml + type: boolean + default: false + label: + en_US: include Raw Html + zh_Hans: 包含原始HTML + human_description: + en_US: Include the raw HTML content of the page. Will output a rawHtml key in the response. + zh_Hans: 返回中包含一个原始HTML版本的内容,将以rawHtml键返回。 + form: form + - name: onlyIncludeTags + type: string + label: + en_US: only Include Tags + zh_Hans: 仅抓取这些标签 + placeholder: + en_US: Use commas to separate multiple tags + zh_Hans: 多个标签时使用半角逗号分隔 + human_description: + en_US: | + Only include tags, classes and ids from the page in the final output. Use comma separated values. Example: script, .ad, #footer + zh_Hans: | + 仅在最终输出中包含HTML页面的这些标签,可以通过标签名、类或ID来设定,使用逗号分隔值。示例:script, .ad, #footer + form: form - name: onlyMainContent type: boolean - required: false + default: false label: - en_US: Only return the main content of the page - zh_Hans: 仅返回页面的主要内容 + en_US: only Main Content + zh_Hans: 仅抓取主要内容 human_description: - en_US: If enabled, the crawler will only return the main content of the page, excluding headers, navigation, footers, etc. - zh_Hans: 如果启用,爬虫将仅返回页面的主要内容,不包括标题、导航、页脚等。 + en_US: Only return the main content of the page excluding headers, navs, footers, etc. + zh_Hans: 只返回页面的主要内容,不包括头部、导航栏、尾部等。 + form: form + - name: removeTags + type: string + label: + en_US: remove Tags + zh_Hans: 要移除这些标签 + human_description: + en_US: | + Tags, classes and ids to remove from the page. Use comma separated values. Example: script, .ad, #footer + zh_Hans: | + 要在最终输出中移除HTML页面的这些标签,可以通过标签名、类或ID来设定,使用逗号分隔值。示例:script, .ad, #footer + placeholder: + en_US: Use commas to separate multiple tags + zh_Hans: 多个标签时使用半角逗号分隔 + form: form + - name: replaceAllPathsWithAbsolutePaths + type: boolean + default: false + label: + en_US: All AbsolutePaths + zh_Hans: 使用绝对路径 + human_description: + en_US: Replace all relative paths with absolute paths for images and links. + zh_Hans: 将所有图片和链接的相对路径替换为绝对路径。 + form: form + - name: screenshot + type: boolean + default: false + label: + en_US: screenshot + zh_Hans: 截图 + human_description: + en_US: Include a screenshot of the top of the page that you are scraping. + zh_Hans: 提供正在抓取的页面的顶部的截图。 + form: form + - name: waitFor + type: number + min: 0 + label: + en_US: wait For + zh_Hans: 等待时间 + human_description: + en_US: Wait x amount of milliseconds for the page to load to fetch content. + zh_Hans: 等待x毫秒以使页面加载并获取内容。 form: form - options: - - value: 'true' - label: - en_US: 'Yes' - zh_Hans: 是 - - value: 'false' - label: - en_US: 'No' - zh_Hans: 否 - default: 'false' diff --git a/api/core/tools/provider/builtin/firecrawl/tools/crawl_job.py b/api/core/tools/provider/builtin/firecrawl/tools/crawl_job.py new file mode 100644 index 0000000000..fa6c1f87ee --- /dev/null +++ b/api/core/tools/provider/builtin/firecrawl/tools/crawl_job.py @@ -0,0 +1,20 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.provider.builtin.firecrawl.firecrawl_appx import FirecrawlApp +from core.tools.tool.builtin_tool import BuiltinTool + + +class CrawlJobTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app = FirecrawlApp(api_key=self.runtime.credentials['firecrawl_api_key'], + base_url=self.runtime.credentials['base_url']) + operation = tool_parameters.get('operation', 'get') + if operation == 'get': + result = app.check_crawl_status(job_id=tool_parameters['job_id']) + elif operation == 'cancel': + result = app.cancel_crawl_job(job_id=tool_parameters['job_id']) + else: + raise ValueError(f'Invalid operation: {operation}') + + return self.create_json_message(result) diff --git a/api/core/tools/provider/builtin/firecrawl/tools/crawl_job.yaml b/api/core/tools/provider/builtin/firecrawl/tools/crawl_job.yaml new file mode 100644 index 0000000000..78008e4ad4 --- /dev/null +++ b/api/core/tools/provider/builtin/firecrawl/tools/crawl_job.yaml @@ -0,0 +1,37 @@ +identity: + name: crawl_job + author: hjlarry + label: + en_US: Crawl Job + zh_Hans: 爬取任务处理 +description: + human: + en_US: Retrieve the scraping results based on the job ID, or cancel the scraping task. + zh_Hans: 根据爬取任务ID获取爬取结果,或者取消爬取任务 + llm: Retrieve the scraping results based on the job ID, or cancel the scraping task. +parameters: + - name: job_id + type: string + required: true + label: + en_US: Job ID + human_description: + en_US: Set wait_for_results to false in the Crawl tool can get the job ID. + zh_Hans: 在深度爬取工具中将等待爬取结果设置为否可以获取Job ID。 + llm_description: Set wait_for_results to false in the Crawl tool can get the job ID. + form: llm + - name: operation + type: select + required: true + options: + - value: get + label: + en_US: get crawl status + - value: cancel + label: + en_US: cancel crawl job + label: + en_US: operation + zh_Hans: 操作 + llm_description: choose the operation to perform. `get` is for getting the crawl status, `cancel` is for cancelling the crawl job. + form: llm diff --git a/api/core/tools/provider/builtin/firecrawl/tools/scrape.py b/api/core/tools/provider/builtin/firecrawl/tools/scrape.py index 3a78dce8d0..91412da548 100644 --- a/api/core/tools/provider/builtin/firecrawl/tools/scrape.py +++ b/api/core/tools/provider/builtin/firecrawl/tools/scrape.py @@ -1,26 +1,39 @@ -import json -from typing import Any, Union +from typing import Any from core.tools.entities.tool_entities import ToolInvokeMessage -from core.tools.provider.builtin.firecrawl.firecrawl_appx import FirecrawlApp +from core.tools.provider.builtin.firecrawl.firecrawl_appx import FirecrawlApp, get_array_params, get_json_params from core.tools.tool.builtin_tool import BuiltinTool class ScrapeTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - app = FirecrawlApp(api_key=self.runtime.credentials['firecrawl_api_key'], base_url=self.runtime.credentials['base_url']) - crawl_result = app.scrape_url( - url=tool_parameters['url'], - wait=True - ) + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + """ + the pageOptions and extractorOptions comes from doc here: + https://docs.firecrawl.dev/api-reference/endpoint/scrape + """ + app = FirecrawlApp(api_key=self.runtime.credentials['firecrawl_api_key'], + base_url=self.runtime.credentials['base_url']) - if isinstance(crawl_result, dict): - result_message = json.dumps(crawl_result, ensure_ascii=False, indent=4) - else: - result_message = str(crawl_result) + pageOptions = {} + extractorOptions = {} - if not crawl_result: - return self.create_text_message("Scrape request failed.") + pageOptions['headers'] = get_json_params(tool_parameters, 'headers') + pageOptions['includeHtml'] = tool_parameters.get('includeHtml', False) + pageOptions['includeRawHtml'] = tool_parameters.get('includeRawHtml', False) + pageOptions['onlyIncludeTags'] = get_array_params(tool_parameters, 'onlyIncludeTags') + pageOptions['removeTags'] = get_array_params(tool_parameters, 'removeTags') + pageOptions['onlyMainContent'] = tool_parameters.get('onlyMainContent', False) + pageOptions['replaceAllPathsWithAbsolutePaths'] = tool_parameters.get('replaceAllPathsWithAbsolutePaths', False) + pageOptions['screenshot'] = tool_parameters.get('screenshot', False) + pageOptions['waitFor'] = tool_parameters.get('waitFor', 0) - return self.create_text_message(result_message) + extractorOptions['mode'] = tool_parameters.get('mode', '') + extractorOptions['extractionPrompt'] = tool_parameters.get('extractionPrompt', '') + extractorOptions['extractionSchema'] = get_json_params(tool_parameters, 'extractionSchema') + + crawl_result = app.scrape_url(url=tool_parameters['url'], + pageOptions=pageOptions, + extractorOptions=extractorOptions) + + return self.create_json_message(crawl_result) diff --git a/api/core/tools/provider/builtin/firecrawl/tools/scrape.yaml b/api/core/tools/provider/builtin/firecrawl/tools/scrape.yaml index 29aa5991aa..598429de5e 100644 --- a/api/core/tools/provider/builtin/firecrawl/tools/scrape.yaml +++ b/api/core/tools/provider/builtin/firecrawl/tools/scrape.yaml @@ -3,7 +3,7 @@ identity: author: ahasasjeb label: en_US: Scrape - zh_Hans: 抓取 + zh_Hans: 单页面抓取 description: human: en_US: Extract data from a single URL. @@ -21,3 +21,160 @@ parameters: zh_Hans: 要抓取并提取数据的网站URL。 llm_description: The URL of the website that needs to be crawled. This is a required parameter. form: llm +############## Page Options ####################### + - name: headers + type: string + label: + en_US: headers + zh_Hans: 请求头 + human_description: + en_US: | + Headers to send with the request. Can be used to send cookies, user-agent, etc. Example: {"cookies": "testcookies"} + zh_Hans: | + 随请求发送的头部。可以用来发送cookies、用户代理等。示例:{"cookies": "testcookies"} + placeholder: + en_US: Please enter an object that can be serialized in JSON + zh_Hans: 请输入可以json序列化的对象 + form: form + - name: includeHtml + type: boolean + default: false + label: + en_US: include Html + zh_Hans: 包含HTML + human_description: + en_US: Include the HTML version of the content on page. Will output a html key in the response. + zh_Hans: 返回中包含一个HTML版本的内容,将以html键返回。 + form: form + - name: includeRawHtml + type: boolean + default: false + label: + en_US: include Raw Html + zh_Hans: 包含原始HTML + human_description: + en_US: Include the raw HTML content of the page. Will output a rawHtml key in the response. + zh_Hans: 返回中包含一个原始HTML版本的内容,将以rawHtml键返回。 + form: form + - name: onlyIncludeTags + type: string + label: + en_US: only Include Tags + zh_Hans: 仅抓取这些标签 + placeholder: + en_US: Use commas to separate multiple tags + zh_Hans: 多个标签时使用半角逗号分隔 + human_description: + en_US: | + Only include tags, classes and ids from the page in the final output. Use comma separated values. Example: script, .ad, #footer + zh_Hans: | + 仅在最终输出中包含HTML页面的这些标签,可以通过标签名、类或ID来设定,使用逗号分隔值。示例:script, .ad, #footer + form: form + - name: onlyMainContent + type: boolean + default: false + label: + en_US: only Main Content + zh_Hans: 仅抓取主要内容 + human_description: + en_US: Only return the main content of the page excluding headers, navs, footers, etc. + zh_Hans: 只返回页面的主要内容,不包括头部、导航栏、尾部等。 + form: form + - name: removeTags + type: string + label: + en_US: remove Tags + zh_Hans: 要移除这些标签 + human_description: + en_US: | + Tags, classes and ids to remove from the page. Use comma separated values. Example: script, .ad, #footer + zh_Hans: | + 要在最终输出中移除HTML页面的这些标签,可以通过标签名、类或ID来设定,使用逗号分隔值。示例:script, .ad, #footer + placeholder: + en_US: Use commas to separate multiple tags + zh_Hans: 多个标签时使用半角逗号分隔 + form: form + - name: replaceAllPathsWithAbsolutePaths + type: boolean + default: false + label: + en_US: All AbsolutePaths + zh_Hans: 使用绝对路径 + human_description: + en_US: Replace all relative paths with absolute paths for images and links. + zh_Hans: 将所有图片和链接的相对路径替换为绝对路径。 + form: form + - name: screenshot + type: boolean + default: false + label: + en_US: screenshot + zh_Hans: 截图 + human_description: + en_US: Include a screenshot of the top of the page that you are scraping. + zh_Hans: 提供正在抓取的页面的顶部的截图。 + form: form + - name: waitFor + type: number + min: 0 + label: + en_US: wait For + zh_Hans: 等待时间 + human_description: + en_US: Wait x amount of milliseconds for the page to load to fetch content. + zh_Hans: 等待x毫秒以使页面加载并获取内容。 + form: form +############## Extractor Options ####################### + - name: mode + type: select + options: + - value: markdown + label: + en_US: markdown + - value: llm-extraction + label: + en_US: llm-extraction + - value: llm-extraction-from-raw-html + label: + en_US: llm-extraction-from-raw-html + - value: llm-extraction-from-markdown + label: + en_US: llm-extraction-from-markdown + label: + en_US: Extractor Mode + zh_Hans: 提取模式 + human_description: + en_US: | + The extraction mode to use. 'markdown': Returns the scraped markdown content, does not perform LLM extraction. 'llm-extraction': Extracts information from the cleaned and parsed content using LLM. + zh_Hans: 使用的提取模式。“markdown”:返回抓取的markdown内容,不执行LLM提取。“llm-extractioin”:使用LLM按Extractor Schema从内容中提取信息。 + form: form + - name: extractionPrompt + type: string + label: + en_US: Extractor Prompt + zh_Hans: 提取时的提示词 + human_description: + en_US: A prompt describing what information to extract from the page, applicable for LLM extraction modes. + zh_Hans: 当使用LLM提取模式时,用于给LLM描述提取规则。 + form: form + - name: extractionSchema + type: string + label: + en_US: Extractor Schema + zh_Hans: 提取时的结构 + placeholder: + en_US: Please enter an object that can be serialized in JSON + human_description: + en_US: | + The schema for the data to be extracted, required only for LLM extraction modes. Example: { + "type": "object", + "properties": {"company_mission": {"type": "string"}}, + "required": ["company_mission"] + } + zh_Hans: | + 当使用LLM提取模式时,使用该结构去提取,示例:{ + "type": "object", + "properties": {"company_mission": {"type": "string"}}, + "required": ["company_mission"] + } + form: form diff --git a/api/core/tools/provider/builtin/firecrawl/tools/search.py b/api/core/tools/provider/builtin/firecrawl/tools/search.py index 0b118aa5f1..e2b2ac6b4d 100644 --- a/api/core/tools/provider/builtin/firecrawl/tools/search.py +++ b/api/core/tools/provider/builtin/firecrawl/tools/search.py @@ -1,5 +1,4 @@ -import json -from typing import Any, Union +from typing import Any from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.provider.builtin.firecrawl.firecrawl_appx import FirecrawlApp @@ -7,20 +6,23 @@ from core.tools.tool.builtin_tool import BuiltinTool class SearchTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - app = FirecrawlApp(api_key=self.runtime.credentials['firecrawl_api_key'], base_url=self.runtime.credentials['base_url']) - - crawl_result = app.search( + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + """ + the pageOptions and searchOptions comes from doc here: + https://docs.firecrawl.dev/api-reference/endpoint/search + """ + app = FirecrawlApp(api_key=self.runtime.credentials['firecrawl_api_key'], + base_url=self.runtime.credentials['base_url']) + pageOptions = {} + pageOptions['onlyMainContent'] = tool_parameters.get('onlyMainContent', False) + pageOptions['fetchPageContent'] = tool_parameters.get('fetchPageContent', True) + pageOptions['includeHtml'] = tool_parameters.get('includeHtml', False) + pageOptions['includeRawHtml'] = tool_parameters.get('includeRawHtml', False) + searchOptions = {'limit': tool_parameters.get('limit')} + search_result = app.search( query=tool_parameters['keyword'], - wait=True + pageOptions=pageOptions, + searchOptions=searchOptions ) - if isinstance(crawl_result, dict): - result_message = json.dumps(crawl_result, ensure_ascii=False, indent=4) - else: - result_message = str(crawl_result) - - if not crawl_result: - return self.create_text_message("Search request failed.") - - return self.create_text_message(result_message) + return self.create_json_message(search_result) diff --git a/api/core/tools/provider/builtin/firecrawl/tools/search.yaml b/api/core/tools/provider/builtin/firecrawl/tools/search.yaml index b1513c914e..29df0cfaaa 100644 --- a/api/core/tools/provider/builtin/firecrawl/tools/search.yaml +++ b/api/core/tools/provider/builtin/firecrawl/tools/search.yaml @@ -21,3 +21,55 @@ parameters: zh_Hans: 输入关键词即可使用Firecrawl API进行搜索。 llm_description: Efficiently extract keywords from user text. form: llm +############## Page Options ####################### + - name: onlyMainContent + type: boolean + default: false + label: + en_US: only Main Content + zh_Hans: 仅抓取主要内容 + human_description: + en_US: Only return the main content of the page excluding headers, navs, footers, etc. + zh_Hans: 只返回页面的主要内容,不包括头部、导航栏、尾部等。 + form: form + - name: fetchPageContent + type: boolean + default: true + label: + en_US: fetch Page Content + zh_Hans: 抓取页面内容 + human_description: + en_US: Fetch the content of each page. If false, defaults to a basic fast serp API. + zh_Hans: 获取每个页面的内容。如果为否,则使用基本的快速搜索结果页面API。 + form: form + - name: includeHtml + type: boolean + default: false + label: + en_US: include Html + zh_Hans: 包含HTML + human_description: + en_US: Include the HTML version of the content on page. Will output a html key in the response. + zh_Hans: 返回中包含一个HTML版本的内容,将以html键返回。 + form: form + - name: includeRawHtml + type: boolean + default: false + label: + en_US: include Raw Html + zh_Hans: 包含原始HTML + human_description: + en_US: Include the raw HTML content of the page. Will output a rawHtml key in the response. + zh_Hans: 返回中包含一个原始HTML版本的内容,将以rawHtml键返回。 + form: form +############## Search Options ####################### + - name: limit + type: number + min: 0 + label: + en_US: Maximum results + zh_Hans: 最大结果数量 + human_description: + en_US: Maximum number of results. Max is 20 during beta. + zh_Hans: 最大结果数量。在测试阶段,最大为20。 + form: form diff --git a/api/core/tools/provider/builtin/jina/tools/jina_reader.py b/api/core/tools/provider/builtin/jina/tools/jina_reader.py index 8409129833..cee46cee23 100644 --- a/api/core/tools/provider/builtin/jina/tools/jina_reader.py +++ b/api/core/tools/provider/builtin/jina/tools/jina_reader.py @@ -60,11 +60,13 @@ class JinaReaderTool(BuiltinTool): if tool_parameters.get('no_cache', False): headers['X-No-Cache'] = 'true' + max_retries = tool_parameters.get('max_retries', 3) response = ssrf_proxy.get( str(URL(self._jina_reader_endpoint + url)), headers=headers, params=request_params, timeout=(10, 60), + max_retries=max_retries ) if tool_parameters.get('summary', False): diff --git a/api/core/tools/provider/builtin/jina/tools/jina_reader.yaml b/api/core/tools/provider/builtin/jina/tools/jina_reader.yaml index 072e7f0528..58ad6d8694 100644 --- a/api/core/tools/provider/builtin/jina/tools/jina_reader.yaml +++ b/api/core/tools/provider/builtin/jina/tools/jina_reader.yaml @@ -150,3 +150,17 @@ parameters: pt_BR: Habilitar resumo para a saída llm_description: enable summary form: form + - name: max_retries + type: number + required: false + default: 3 + label: + en_US: Retry + zh_Hans: 重试 + pt_BR: Repetir + human_description: + en_US: Number of times to retry the request if it fails + zh_Hans: 请求失败时重试的次数 + pt_BR: Número de vezes para repetir a solicitação se falhar + llm_description: Number of times to retry the request if it fails + form: form diff --git a/api/core/tools/provider/builtin/jina/tools/jina_search.py b/api/core/tools/provider/builtin/jina/tools/jina_search.py index e6bc08147f..d4a81cd096 100644 --- a/api/core/tools/provider/builtin/jina/tools/jina_search.py +++ b/api/core/tools/provider/builtin/jina/tools/jina_search.py @@ -40,10 +40,12 @@ class JinaSearchTool(BuiltinTool): if tool_parameters.get('no_cache', False): headers['X-No-Cache'] = 'true' + max_retries = tool_parameters.get('max_retries', 3) response = ssrf_proxy.get( str(URL(self._jina_search_endpoint + query)), headers=headers, - timeout=(10, 60) + timeout=(10, 60), + max_retries=max_retries ) return self.create_text_message(response.text) diff --git a/api/core/tools/provider/builtin/jina/tools/jina_search.yaml b/api/core/tools/provider/builtin/jina/tools/jina_search.yaml index da0a300c6c..2bc70e1be1 100644 --- a/api/core/tools/provider/builtin/jina/tools/jina_search.yaml +++ b/api/core/tools/provider/builtin/jina/tools/jina_search.yaml @@ -91,3 +91,17 @@ parameters: pt_BR: Ignorar o cache llm_description: bypass the cache form: form + - name: max_retries + type: number + required: false + default: 3 + label: + en_US: Retry + zh_Hans: 重试 + pt_BR: Repetir + human_description: + en_US: Number of times to retry the request if it fails + zh_Hans: 请求失败时重试的次数 + pt_BR: Número de vezes para repetir a solicitação se falhar + llm_description: Number of times to retry the request if it fails + form: form diff --git a/api/core/tools/provider/builtin/spider/spiderApp.py b/api/core/tools/provider/builtin/spider/spiderApp.py index 82c0df19ca..f0ed64867a 100644 --- a/api/core/tools/provider/builtin/spider/spiderApp.py +++ b/api/core/tools/provider/builtin/spider/spiderApp.py @@ -116,6 +116,7 @@ class Spider: :param params: Optional dictionary of additional parameters for the scrape request. :return: JSON response containing the scraping results. """ + params = params or {} # Add { "return_format": "markdown" } to the params if not already present if "return_format" not in params: @@ -143,6 +144,7 @@ class Spider: :param stream: Boolean indicating if the response should be streamed. Defaults to False. :return: JSON response or the raw response stream if streaming enabled. """ + params = params or {} # Add { "return_format": "markdown" } to the params if not already present if "return_format" not in params: diff --git a/api/core/tools/provider/builtin_tool_provider.py b/api/core/tools/provider/builtin_tool_provider.py index 47e33b70c9..bcf41c90ed 100644 --- a/api/core/tools/provider/builtin_tool_provider.py +++ b/api/core/tools/provider/builtin_tool_provider.py @@ -27,7 +27,7 @@ class BuiltinToolProviderController(ToolProviderController): provider = self.__class__.__module__.split('.')[-1] yaml_path = path.join(path.dirname(path.realpath(__file__)), 'builtin', provider, f'{provider}.yaml') try: - provider_yaml = load_yaml_file(yaml_path) + provider_yaml = load_yaml_file(yaml_path, ignore_error=False) except Exception as e: raise ToolProviderNotFoundError(f'can not load provider yaml for {provider}: {e}') @@ -58,7 +58,7 @@ class BuiltinToolProviderController(ToolProviderController): for tool_file in tool_files: # get tool name tool_name = tool_file.split(".")[0] - tool = load_yaml_file(path.join(tool_path, tool_file)) + tool = load_yaml_file(path.join(tool_path, tool_file), ignore_error=False) # get tool class, import the module assistant_tool_class = load_single_subclass_from_source( diff --git a/api/core/tools/utils/web_reader_tool.py b/api/core/tools/utils/web_reader_tool.py index e52082541a..a461328ae6 100644 --- a/api/core/tools/utils/web_reader_tool.py +++ b/api/core/tools/utils/web_reader_tool.py @@ -10,12 +10,12 @@ import unicodedata from contextlib import contextmanager from urllib.parse import unquote +import chardet import cloudscraper -import requests from bs4 import BeautifulSoup, CData, Comment, NavigableString -from newspaper import Article from regex import regex +from core.helper import ssrf_proxy from core.rag.extractor import extract_processor from core.rag.extractor.extract_processor import ExtractProcessor @@ -45,7 +45,7 @@ def get_url(url: str, user_agent: str = None) -> str: main_content_type = None supported_content_types = extract_processor.SUPPORT_URL_CONTENT_TYPES + ["text/html"] - response = requests.head(url, headers=headers, allow_redirects=True, timeout=(5, 10)) + response = ssrf_proxy.head(url, headers=headers, follow_redirects=True, timeout=(5, 10)) if response.status_code == 200: # check content-type @@ -67,18 +67,30 @@ def get_url(url: str, user_agent: str = None) -> str: if main_content_type in extract_processor.SUPPORT_URL_CONTENT_TYPES: return ExtractProcessor.load_from_url(url, return_text=True) - response = requests.get(url, headers=headers, allow_redirects=True, timeout=(120, 300)) + response = ssrf_proxy.get(url, headers=headers, follow_redirects=True, timeout=(120, 300)) elif response.status_code == 403: scraper = cloudscraper.create_scraper() - response = scraper.get(url, headers=headers, allow_redirects=True, timeout=(120, 300)) + scraper.perform_request = ssrf_proxy.make_request + response = scraper.get(url, headers=headers, follow_redirects=True, timeout=(120, 300)) if response.status_code != 200: return "URL returned status code {}.".format(response.status_code) - a = extract_using_readabilipy(response.text) + # Detect encoding using chardet + detected_encoding = chardet.detect(response.content) + encoding = detected_encoding['encoding'] + if encoding: + try: + content = response.content.decode(encoding) + except (UnicodeDecodeError, TypeError): + content = response.text + else: + content = response.text + + a = extract_using_readabilipy(content) if not a['plain_text'] or not a['plain_text'].strip(): - return get_url_from_newspaper3k(url) + return '' res = FULL_TEMPLATE.format( title=a['title'], @@ -91,23 +103,6 @@ def get_url(url: str, user_agent: str = None) -> str: return res -def get_url_from_newspaper3k(url: str) -> str: - - a = Article(url) - a.download() - a.parse() - - res = FULL_TEMPLATE.format( - title=a.title, - authors=a.authors, - publish_date=a.publish_date, - top_image=a.top_image, - text=a.text, - ) - - return res - - def extract_using_readabilipy(html): with tempfile.NamedTemporaryFile(delete=False, mode='w+') as f_html: f_html.write(html) diff --git a/api/core/tools/utils/yaml_utils.py b/api/core/tools/utils/yaml_utils.py index 3526647b4f..21155a6960 100644 --- a/api/core/tools/utils/yaml_utils.py +++ b/api/core/tools/utils/yaml_utils.py @@ -1,35 +1,32 @@ import logging -import os +from typing import Any import yaml from yaml import YAMLError logger = logging.getLogger(__name__) -def load_yaml_file(file_path: str, ignore_error: bool = False) -> dict: + +def load_yaml_file(file_path: str, ignore_error: bool = True, default_value: Any = {}) -> Any: """ - Safe loading a YAML file to a dict + Safe loading a YAML file :param file_path: the path of the YAML file :param ignore_error: - if True, return empty dict if error occurs and the error will be logged in warning level + if True, return default_value if error occurs and the error will be logged in debug level if False, raise error if error occurs - :return: a dict of the YAML content + :param default_value: the value returned when errors ignored + :return: an object of the YAML content """ try: - if not file_path or not os.path.exists(file_path): - raise FileNotFoundError(f'Failed to load YAML file {file_path}: file not found') - - with open(file_path, encoding='utf-8') as file: + with open(file_path, encoding='utf-8') as yaml_file: try: - return yaml.safe_load(file) + yaml_content = yaml.safe_load(yaml_file) + return yaml_content if yaml_content else default_value except Exception as e: raise YAMLError(f'Failed to load YAML file {file_path}: {e}') - except FileNotFoundError as e: - logger.debug(f'Failed to load YAML file {file_path}: {e}') - return {} except Exception as e: if ignore_error: - logger.warning(f'Failed to load YAML file {file_path}: {e}') - return {} + logger.debug(f'Failed to load YAML file {file_path}: {e}') + return default_value else: raise e diff --git a/api/core/workflow/entities/variable_pool.py b/api/core/workflow/entities/variable_pool.py index 957e744aac..270e104e37 100644 --- a/api/core/workflow/entities/variable_pool.py +++ b/api/core/workflow/entities/variable_pool.py @@ -5,7 +5,7 @@ from typing import Any, Union from pydantic import BaseModel, Field, model_validator from typing_extensions import deprecated -from core.app.segments import Variable, factory +from core.app.segments import Segment, Variable, factory from core.file.file_obj import FileVar from core.workflow.entities.node_entities import SystemVariable @@ -21,7 +21,7 @@ class VariablePool(BaseModel): # The first element of the selector is the node id, it's the first-level key in the dictionary. # Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the # elements of the selector except the first one. - variable_dictionary: dict[str, dict[int, Variable]] = Field( + variable_dictionary: dict[str, dict[int, Segment]] = Field( description='Variables mapping', default=defaultdict(dict) ) @@ -76,15 +76,15 @@ class VariablePool(BaseModel): if value is None: return - if not isinstance(value, Variable): - v = factory.build_anonymous_variable(value) - else: + if isinstance(value, Segment): v = value + else: + v = factory.build_segment(value) hash_key = hash(tuple(selector[1:])) self.variable_dictionary[selector[0]][hash_key] = v - def get(self, selector: Sequence[str], /) -> Variable | None: + def get(self, selector: Sequence[str], /) -> Segment | None: """ Retrieves the value from the variable pool based on the given selector. diff --git a/api/core/workflow/nodes/base_node.py b/api/core/workflow/nodes/base_node.py index d5623bc81c..55902069c2 100644 --- a/api/core/workflow/nodes/base_node.py +++ b/api/core/workflow/nodes/base_node.py @@ -65,7 +65,7 @@ class BaseNode(ABC): yield from result @classmethod - def extract_variable_selector_to_variable_mapping(cls, config: dict) -> dict[str, list[str]]: + def extract_variable_selector_to_variable_mapping(cls, config: dict): """ Extract variable selector to variable mapping :param config: node config @@ -75,14 +75,13 @@ class BaseNode(ABC): return cls._extract_variable_selector_to_variable_mapping(node_data) @classmethod - @abstractmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: + def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping :param node_data: node data :return: """ - raise NotImplementedError + return {} @classmethod def get_default_config(cls, filters: Optional[dict] = None) -> dict: diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index 40792a5fff..1179869b03 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -260,6 +260,7 @@ class IterationNode(BaseNode): }, error=str(e), ) + yield RunCompletedEvent( run_result=NodeRunResult( @@ -271,6 +272,61 @@ class IterationNode(BaseNode): # remove iteration variable (item, index) from variable pool after iteration run completed variable_pool.remove([self.node_id, 'index']) variable_pool.remove([self.node_id, 'item']) + + def _set_current_iteration_variable(self, variable_pool: VariablePool, state: IterationState): + """ + Set current iteration variable. + :variable_pool: variable pool + """ + node_data = cast(IterationNodeData, self.node_data) + + variable_pool.add((self.node_id, 'index'), state.index) + # get the iterator value + iterator = variable_pool.get_any(node_data.iterator_selector) + + if iterator is None or not isinstance(iterator, list): + return + + if state.index < len(iterator): + variable_pool.add((self.node_id, 'item'), iterator[state.index]) + + def _next_iteration(self, variable_pool: VariablePool, state: IterationState): + """ + Move to next iteration. + :param variable_pool: variable pool + """ + state.index += 1 + self._set_current_iteration_variable(variable_pool, state) + + def _reached_iteration_limit(self, variable_pool: VariablePool, state: IterationState): + """ + Check if iteration limit is reached. + :return: True if iteration limit is reached, False otherwise + """ + node_data = cast(IterationNodeData, self.node_data) + iterator = variable_pool.get_any(node_data.iterator_selector) + + if iterator is None or not isinstance(iterator, list): + return True + + return state.index >= len(iterator) + + def _resolve_current_output(self, variable_pool: VariablePool, state: IterationState): + """ + Resolve current output. + :param variable_pool: variable pool + """ + output_selector = cast(IterationNodeData, self.node_data).output_selector + output = variable_pool.get_any(output_selector) + # clear the output for this iteration + variable_pool.remove([self.node_id] + output_selector[1:]) + state.current_output = output + if output is not None: + # NOTE: This is a temporary patch to process double nested list (for example, DALL-E output in iteration). + if isinstance(output, list): + state.outputs.extend(output) + else: + state.outputs.append(output) @classmethod def _extract_variable_selector_to_variable_mapping(cls, node_data: IterationNodeData) -> dict[str, list[str]]: diff --git a/api/core/workflow/nodes/knowledge_retrieval/entities.py b/api/core/workflow/nodes/knowledge_retrieval/entities.py index 5a9a4a9009..5758b895f3 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/entities.py +++ b/api/core/workflow/nodes/knowledge_retrieval/entities.py @@ -46,8 +46,8 @@ class MultipleRetrievalConfig(BaseModel): score_threshold: Optional[float] = None reranking_mode: str = 'reranking_model' reranking_enable: bool = True - reranking_model: RerankingModelConfig - weights: WeightedScoreConfig + reranking_model: Optional[RerankingModelConfig] = None + weights: Optional[WeightedScoreConfig] = None class ModelConfig(BaseModel): diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 00a46b8a60..196aa169a0 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -139,8 +139,8 @@ class KnowledgeRetrievalNode(BaseNode): elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value: if node_data.multiple_retrieval_config.reranking_mode == 'reranking_model': reranking_model = { - 'reranking_provider_name': node_data.multiple_retrieval_config.reranking_model['provider'], - 'reranking_model_name': node_data.multiple_retrieval_config.reranking_model['name'] + 'reranking_provider_name': node_data.multiple_retrieval_config.reranking_model.provider, + 'reranking_model_name': node_data.multiple_retrieval_config.reranking_model.model } weights = None elif node_data.multiple_retrieval_config.reranking_mode == 'weighted_score': diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index b946eb5816..1141417c55 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -125,11 +125,16 @@ class ToolNode(BaseNode): ] else: tool_input = node_data.tool_parameters[parameter_name] - segment_group = parser.convert_template( - template=str(tool_input.value), - variable_pool=variable_pool, - ) - result[parameter_name] = segment_group.log if for_log else segment_group.text + if tool_input.type == 'variable': + # TODO: check if the variable exists in the variable pool + parameter_value = variable_pool.get(tool_input.value).value + else: + segment_group = parser.convert_template( + template=str(tool_input.value), + variable_pool=variable_pool, + ) + parameter_value = segment_group.log if for_log else segment_group.text + result[parameter_name] = parameter_value return result diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/extensions/storage/tencent_storage.py b/api/extensions/storage/tencent_storage.py index 6d9fb80f5e..e2c1ca55e3 100644 --- a/api/extensions/storage/tencent_storage.py +++ b/api/extensions/storage/tencent_storage.py @@ -32,8 +32,7 @@ class TencentStorage(BaseStorage): def load_stream(self, filename: str) -> Generator: def generate(filename: str = filename) -> Generator: response = self.client.get_object(Bucket=self.bucket_name, Key=filename) - while chunk := response['Body'].get_stream(chunk_size=4096): - yield chunk + yield from response['Body'].get_stream(chunk_size=4096) return generate() diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py index c98c332021..ff33a97ff2 100644 --- a/api/fields/workflow_fields.py +++ b/api/fields/workflow_fields.py @@ -1,10 +1,12 @@ from flask_restful import fields -from core.app.segments import SecretVariable, Variable +from core.app.segments import SecretVariable, SegmentType, Variable from core.helper import encrypter from fields.member_fields import simple_account_fields from libs.helper import TimestampField +ENVIRONMENT_VARIABLE_SUPPORTED_TYPES = (SegmentType.STRING, SegmentType.NUMBER, SegmentType.SECRET) + class EnvironmentVariableField(fields.Raw): def format(self, value): @@ -16,14 +18,18 @@ class EnvironmentVariableField(fields.Raw): 'value': encrypter.obfuscated_token(value.value), 'value_type': value.value_type.value, } - elif isinstance(value, Variable): + if isinstance(value, Variable): return { 'id': value.id, 'name': value.name, 'value': value.value, 'value_type': value.value_type.value, } - return value + if isinstance(value, dict): + value_type = value.get('value_type') + if value_type not in ENVIRONMENT_VARIABLE_SUPPORTED_TYPES: + raise ValueError(f'Unsupported environment variable value type: {value_type}') + return value environment_variable_fields = { diff --git a/api/libs/passport.py b/api/libs/passport.py index 530709f18c..34bdc55997 100644 --- a/api/libs/passport.py +++ b/api/libs/passport.py @@ -1,15 +1,16 @@ import jwt -from flask import current_app from werkzeug.exceptions import Unauthorized +from configs import dify_config + class PassportService: def __init__(self): - self.sk = current_app.config.get('SECRET_KEY') - + self.sk = dify_config.SECRET_KEY + def issue(self, payload): return jwt.encode(payload, self.sk, algorithm='HS256') - + def verify(self, token): try: return jwt.decode(token, self.sk, algorithms=['HS256']) diff --git a/api/migrations/versions/eeb2e349e6ac_increase_max_model_name_length.py b/api/migrations/versions/eeb2e349e6ac_increase_max_model_name_length.py new file mode 100644 index 0000000000..434531b6c8 --- /dev/null +++ b/api/migrations/versions/eeb2e349e6ac_increase_max_model_name_length.py @@ -0,0 +1,53 @@ +"""increase max model_name length + +Revision ID: eeb2e349e6ac +Revises: 53bf8af60645 +Create Date: 2024-07-26 12:02:00.750358 + +""" +import sqlalchemy as sa +from alembic import op + +import models as models + +# revision identifiers, used by Alembic. +revision = 'eeb2e349e6ac' +down_revision = '53bf8af60645' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('dataset_collection_bindings', schema=None) as batch_op: + batch_op.alter_column('model_name', + existing_type=sa.VARCHAR(length=40), + type_=sa.String(length=255), + existing_nullable=False) + + with op.batch_alter_table('embeddings', schema=None) as batch_op: + batch_op.alter_column('model_name', + existing_type=sa.VARCHAR(length=40), + type_=sa.String(length=255), + existing_nullable=False, + existing_server_default=sa.text("'text-embedding-ada-002'::character varying")) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('embeddings', schema=None) as batch_op: + batch_op.alter_column('model_name', + existing_type=sa.String(length=255), + type_=sa.VARCHAR(length=40), + existing_nullable=False, + existing_server_default=sa.text("'text-embedding-ada-002'::character varying")) + + with op.batch_alter_table('dataset_collection_bindings', schema=None) as batch_op: + batch_op.alter_column('model_name', + existing_type=sa.String(length=255), + type_=sa.VARCHAR(length=40), + existing_nullable=False) + + # ### end Alembic commands ### diff --git a/api/models/dataset.py b/api/models/dataset.py index ca1c4583a6..40f9f4cf83 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -635,7 +635,7 @@ class Embedding(db.Model): ) id = db.Column(StringUUID, primary_key=True, server_default=db.text('uuid_generate_v4()')) - model_name = db.Column(db.String(40), nullable=False, + model_name = db.Column(db.String(255), nullable=False, server_default=db.text("'text-embedding-ada-002'::character varying")) hash = db.Column(db.String(64), nullable=False) embedding = db.Column(db.LargeBinary, nullable=False) @@ -660,7 +660,7 @@ class DatasetCollectionBinding(db.Model): id = db.Column(StringUUID, primary_key=True, server_default=db.text('uuid_generate_v4()')) provider_name = db.Column(db.String(40), nullable=False) - model_name = db.Column(db.String(40), nullable=False) + model_name = db.Column(db.String(255), nullable=False) type = db.Column(db.String(40), server_default=db.text("'dataset'::character varying"), nullable=False) collection_name = db.Column(db.String(64), nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) diff --git a/api/poetry.lock b/api/poetry.lock index 2a277dac2d..abde108a7a 100644 --- a/api/poetry.lock +++ b/api/poetry.lock @@ -448,63 +448,6 @@ doc = ["Sphinx (>=7)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphin test = ["anyio[trio]", "coverage[toml] (>=7)", "exceptiongroup (>=1.2.0)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "uvloop (>=0.17)"] trio = ["trio (>=0.23)"] -[[package]] -name = "argon2-cffi" -version = "23.1.0" -description = "Argon2 for Python" -optional = false -python-versions = ">=3.7" -files = [ - {file = "argon2_cffi-23.1.0-py3-none-any.whl", hash = "sha256:c670642b78ba29641818ab2e68bd4e6a78ba53b7eff7b4c3815ae16abf91c7ea"}, - {file = "argon2_cffi-23.1.0.tar.gz", hash = "sha256:879c3e79a2729ce768ebb7d36d4609e3a78a4ca2ec3a9f12286ca057e3d0db08"}, -] - -[package.dependencies] -argon2-cffi-bindings = "*" - -[package.extras] -dev = ["argon2-cffi[tests,typing]", "tox (>4)"] -docs = ["furo", "myst-parser", "sphinx", "sphinx-copybutton", "sphinx-notfound-page"] -tests = ["hypothesis", "pytest"] -typing = ["mypy"] - -[[package]] -name = "argon2-cffi-bindings" -version = "21.2.0" -description = "Low-level CFFI bindings for Argon2" -optional = false -python-versions = ">=3.6" -files = [ - {file = "argon2-cffi-bindings-21.2.0.tar.gz", hash = "sha256:bb89ceffa6c791807d1305ceb77dbfacc5aa499891d2c55661c6459651fc39e3"}, - {file = "argon2_cffi_bindings-21.2.0-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:ccb949252cb2ab3a08c02024acb77cfb179492d5701c7cbdbfd776124d4d2367"}, - {file = "argon2_cffi_bindings-21.2.0-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9524464572e12979364b7d600abf96181d3541da11e23ddf565a32e70bd4dc0d"}, - {file = "argon2_cffi_bindings-21.2.0-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b746dba803a79238e925d9046a63aa26bf86ab2a2fe74ce6b009a1c3f5c8f2ae"}, - {file = "argon2_cffi_bindings-21.2.0-cp36-abi3-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:58ed19212051f49a523abb1dbe954337dc82d947fb6e5a0da60f7c8471a8476c"}, - {file = "argon2_cffi_bindings-21.2.0-cp36-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:bd46088725ef7f58b5a1ef7ca06647ebaf0eb4baff7d1d0d177c6cc8744abd86"}, - {file = "argon2_cffi_bindings-21.2.0-cp36-abi3-musllinux_1_1_i686.whl", hash = "sha256:8cd69c07dd875537a824deec19f978e0f2078fdda07fd5c42ac29668dda5f40f"}, - {file = "argon2_cffi_bindings-21.2.0-cp36-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:f1152ac548bd5b8bcecfb0b0371f082037e47128653df2e8ba6e914d384f3c3e"}, - {file = "argon2_cffi_bindings-21.2.0-cp36-abi3-win32.whl", hash = "sha256:603ca0aba86b1349b147cab91ae970c63118a0f30444d4bc80355937c950c082"}, - {file = "argon2_cffi_bindings-21.2.0-cp36-abi3-win_amd64.whl", hash = "sha256:b2ef1c30440dbbcba7a5dc3e319408b59676e2e039e2ae11a8775ecf482b192f"}, - {file = "argon2_cffi_bindings-21.2.0-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:e415e3f62c8d124ee16018e491a009937f8cf7ebf5eb430ffc5de21b900dad93"}, - {file = "argon2_cffi_bindings-21.2.0-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:3e385d1c39c520c08b53d63300c3ecc28622f076f4c2b0e6d7e796e9f6502194"}, - {file = "argon2_cffi_bindings-21.2.0-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2c3e3cc67fdb7d82c4718f19b4e7a87123caf8a93fde7e23cf66ac0337d3cb3f"}, - {file = "argon2_cffi_bindings-21.2.0-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6a22ad9800121b71099d0fb0a65323810a15f2e292f2ba450810a7316e128ee5"}, - {file = "argon2_cffi_bindings-21.2.0-pp37-pypy37_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f9f8b450ed0547e3d473fdc8612083fd08dd2120d6ac8f73828df9b7d45bb351"}, - {file = "argon2_cffi_bindings-21.2.0-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:93f9bf70084f97245ba10ee36575f0c3f1e7d7724d67d8e5b08e61787c320ed7"}, - {file = "argon2_cffi_bindings-21.2.0-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:3b9ef65804859d335dc6b31582cad2c5166f0c3e7975f324d9ffaa34ee7e6583"}, - {file = "argon2_cffi_bindings-21.2.0-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d4966ef5848d820776f5f562a7d45fdd70c2f330c961d0d745b784034bd9f48d"}, - {file = "argon2_cffi_bindings-21.2.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:20ef543a89dee4db46a1a6e206cd015360e5a75822f76df533845c3cbaf72670"}, - {file = "argon2_cffi_bindings-21.2.0-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ed2937d286e2ad0cc79a7087d3c272832865f779430e0cc2b4f3718d3159b0cb"}, - {file = "argon2_cffi_bindings-21.2.0-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:5e00316dabdaea0b2dd82d141cc66889ced0cdcbfa599e8b471cf22c620c329a"}, -] - -[package.dependencies] -cffi = ">=1.0.1" - -[package.extras] -dev = ["cogapp", "pre-commit", "pytest", "wheel"] -tests = ["pytest"] - [[package]] name = "arxiv" version = "2.1.0" @@ -4616,22 +4559,20 @@ files = [ ] [[package]] -name = "minio" -version = "7.2.7" -description = "MinIO Python SDK for Amazon S3 Compatible Cloud Storage" +name = "milvus-lite" +version = "2.4.8" +description = "A lightweight version of Milvus wrapped with Python." optional = false -python-versions = "*" +python-versions = ">=3.7" files = [ - {file = "minio-7.2.7-py3-none-any.whl", hash = "sha256:59d1f255d852fe7104018db75b3bebbd987e538690e680f7c5de835e422de837"}, - {file = "minio-7.2.7.tar.gz", hash = "sha256:473d5d53d79f340f3cd632054d0c82d2f93177ce1af2eac34a235bea55708d98"}, + {file = "milvus_lite-2.4.8-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:b7e90b34b214884cd44cdc112ab243d4cb197b775498355e2437b6cafea025fe"}, + {file = "milvus_lite-2.4.8-py3-none-macosx_11_0_arm64.whl", hash = "sha256:519dfc62709d8f642d98a1c5b1dcde7080d107e6e312d677fef5a3412a40ac08"}, + {file = "milvus_lite-2.4.8-py3-none-manylinux2014_aarch64.whl", hash = "sha256:b21f36d24cbb0e920b4faad607019bb28c1b2c88b4d04680ac8c7697a4ae8a4d"}, + {file = "milvus_lite-2.4.8-py3-none-manylinux2014_x86_64.whl", hash = "sha256:08332a2b9abfe7c4e1d7926068937e46f8fb81f2707928b7bc02c9dc99cebe41"}, ] [package.dependencies] -argon2-cffi = "*" -certifi = "*" -pycryptodome = "*" -typing-extensions = "*" -urllib3 = "*" +tqdm = "*" [[package]] name = "mmh3" @@ -6078,6 +6019,19 @@ files = [ {file = "pyarrow-17.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:392bc9feabc647338e6c89267635e111d71edad5fcffba204425a7c8d13610d7"}, {file = "pyarrow-17.0.0-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:af5ff82a04b2171415f1410cff7ebb79861afc5dae50be73ce06d6e870615204"}, {file = "pyarrow-17.0.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:edca18eaca89cd6382dfbcff3dd2d87633433043650c07375d095cd3517561d8"}, + {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7c7916bff914ac5d4a8fe25b7a25e432ff921e72f6f2b7547d1e325c1ad9d155"}, + {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f553ca691b9e94b202ff741bdd40f6ccb70cdd5fbf65c187af132f1317de6145"}, + {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:0cdb0e627c86c373205a2f94a510ac4376fdc523f8bb36beab2e7f204416163c"}, + {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:d7d192305d9d8bc9082d10f361fc70a73590a4c65cf31c3e6926cd72b76bc35c"}, + {file = "pyarrow-17.0.0-cp38-cp38-win_amd64.whl", hash = "sha256:02dae06ce212d8b3244dd3e7d12d9c4d3046945a5933d28026598e9dbbda1fca"}, + {file = "pyarrow-17.0.0-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:13d7a460b412f31e4c0efa1148e1d29bdf18ad1411eb6757d38f8fbdcc8645fb"}, + {file = "pyarrow-17.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9b564a51fbccfab5a04a80453e5ac6c9954a9c5ef2890d1bcf63741909c3f8df"}, + {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:32503827abbc5aadedfa235f5ece8c4f8f8b0a3cf01066bc8d29de7539532687"}, + {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a155acc7f154b9ffcc85497509bcd0d43efb80d6f733b0dc3bb14e281f131c8b"}, + {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:dec8d129254d0188a49f8a1fc99e0560dc1b85f60af729f47de4046015f9b0a5"}, + {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:a48ddf5c3c6a6c505904545c25a4ae13646ae1f8ba703c4df4a1bfe4f4006bda"}, + {file = "pyarrow-17.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:42bf93249a083aca230ba7e2786c5f673507fa97bbd9725a1e2754715151a204"}, + {file = "pyarrow-17.0.0.tar.gz", hash = "sha256:4beca9521ed2c0921c1023e68d097d0299b62c362639ea315572a58f3f50fd28"}, ] [package.dependencies] @@ -6374,24 +6328,29 @@ tests = ["coverage[toml] (==5.0.4)", "pytest (>=6.0.0,<7.0.0)"] [[package]] name = "pymilvus" -version = "2.3.1" +version = "2.4.4" description = "Python Sdk for Milvus" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "pymilvus-2.3.1-py3-none-any.whl", hash = "sha256:ce65e1de8700f33bd9aade20f013291629702e25b05726773208f1f0b22548ff"}, - {file = "pymilvus-2.3.1.tar.gz", hash = "sha256:d460f6204d7deb2cff93716bd65670c1b440694b77701fb0ab0ead791aa582c6"}, + {file = "pymilvus-2.4.4-py3-none-any.whl", hash = "sha256:073b76bc36f6f4e70f0f0a0023a53324f0ba8ef9a60883f87cd30a44b6c6f2b5"}, + {file = "pymilvus-2.4.4.tar.gz", hash = "sha256:50c53eb103e034fbffe936fe942751ea3dbd2452e18cf79acc52360ed4987fb7"}, ] [package.dependencies] environs = "<=9.5.0" -grpcio = ">=1.49.1,<=1.58.0" -minio = "*" +grpcio = ">=1.49.1,<=1.63.0" +milvus-lite = {version = ">=2.4.0,<2.5.0", markers = "sys_platform != \"win32\""} pandas = ">=1.2.4" protobuf = ">=3.20.0" -requests = "*" +setuptools = ">=67" ujson = ">=2.0.0" +[package.extras] +bulk-writer = ["azure-storage-blob", "minio (>=7.0.0)", "pyarrow (>=12.0.0)", "requests"] +dev = ["black", "grpcio (==1.62.2)", "grpcio-testing (==1.62.2)", "grpcio-tools (==1.62.2)", "pytest (>=5.3.4)", "pytest-cov (>=2.8.1)", "pytest-timeout (>=1.3.4)", "ruff (>0.4.0)"] +model = ["milvus-model (>=0.1.0)"] + [[package]] name = "pymysql" version = "1.1.1" @@ -9543,4 +9502,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "6b7d8b1333ae9c71ba2e1c5800eecf1535ed3945cd55ebb1e253b7a29ba09559" +content-hash = "a8b61d74d9322302b7447b6f8728ad606abc160202a8a122a05a8ef3cec7055b" diff --git a/api/pyproject.toml b/api/pyproject.toml index 7be3c7af64..25778f323d 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -177,6 +177,7 @@ xinference-client = "0.9.4" yarl = "~1.9.4" zhipuai = "1.0.7" rank-bm25 = "~0.2.2" +openpyxl = "^3.1.5" ############################################################ # Tool dependencies required by tool implementations ############################################################ @@ -205,7 +206,7 @@ chromadb = "0.5.1" oracledb = "~2.2.1" pgvecto-rs = "0.1.4" pgvector = "0.2.5" -pymilvus = "2.3.1" +pymilvus = "~2.4.4" pymysql = "1.1.1" tcvectordb = "1.3.2" tidb-vector = "0.0.9" @@ -215,18 +216,6 @@ alibabacloud_gpdb20160503 = "~3.8.0" alibabacloud_tea_openapi = "~0.3.9" clickhouse-connect = "~0.7.16" -############################################################ -# Transparent dependencies required by main dependencies -# for pinning versions -############################################################ - -[tool.poetry.group.transparent.dependencies] -kaleido = "0.2.1" -lxml = "5.1.0" -sympy = "1.12" -tenacity = "~8.3.0" -xlrd = "~2.0.1" - ############################################################ # Dev dependencies for running tests ############################################################ diff --git a/api/tests/integration_tests/model_runtime/hunyuan/test_text_embedding.py b/api/tests/integration_tests/model_runtime/hunyuan/test_text_embedding.py new file mode 100644 index 0000000000..7ae6c0e456 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/hunyuan/test_text_embedding.py @@ -0,0 +1,104 @@ +import os + +import pytest + +from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.hunyuan.text_embedding.text_embedding import HunyuanTextEmbeddingModel + + +def test_validate_credentials(): + model = HunyuanTextEmbeddingModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model='hunyuan-embedding', + credentials={ + 'secret_id': 'invalid_key', + 'secret_key': 'invalid_key' + } + ) + + model.validate_credentials( + model='hunyuan-embedding', + credentials={ + 'secret_id': os.environ.get('HUNYUAN_SECRET_ID'), + 'secret_key': os.environ.get('HUNYUAN_SECRET_KEY') + } + ) + + +def test_invoke_model(): + model = HunyuanTextEmbeddingModel() + + result = model.invoke( + model='hunyuan-embedding', + credentials={ + 'secret_id': os.environ.get('HUNYUAN_SECRET_ID'), + 'secret_key': os.environ.get('HUNYUAN_SECRET_KEY') + }, + texts=[ + "hello", + "world" + ], + user="abc-123" + ) + + assert isinstance(result, TextEmbeddingResult) + assert len(result.embeddings) == 2 + assert result.usage.total_tokens == 6 + +def test_get_num_tokens(): + model = HunyuanTextEmbeddingModel() + + num_tokens = model.get_num_tokens( + model='hunyuan-embedding', + credentials={ + 'secret_id': os.environ.get('HUNYUAN_SECRET_ID'), + 'secret_key': os.environ.get('HUNYUAN_SECRET_KEY') + }, + texts=[ + "hello", + "world" + ] + ) + + assert num_tokens == 2 + +def test_max_chunks(): + model = HunyuanTextEmbeddingModel() + + result = model.invoke( + model='hunyuan-embedding', + credentials={ + 'secret_id': os.environ.get('HUNYUAN_SECRET_ID'), + 'secret_key': os.environ.get('HUNYUAN_SECRET_KEY') + }, + texts=[ + "hello", + "world", + "hello", + "world", + "hello", + "world", + "hello", + "world", + "hello", + "world", + "hello", + "world", + "hello", + "world", + "hello", + "world", + "hello", + "world", + "hello", + "world", + "hello", + "world", + ] + ) + + assert isinstance(result, TextEmbeddingResult) + assert len(result.embeddings) == 22 \ No newline at end of file diff --git a/api/tests/integration_tests/model_runtime/siliconflow/__init__.py b/api/tests/integration_tests/model_runtime/siliconflow/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/integration_tests/model_runtime/siliconflow/test_llm.py b/api/tests/integration_tests/model_runtime/siliconflow/test_llm.py new file mode 100644 index 0000000000..befdd82352 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/siliconflow/test_llm.py @@ -0,0 +1,106 @@ +import os +from collections.abc import Generator + +import pytest + +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.siliconflow.llm.llm import SiliconflowLargeLanguageModel + + +def test_validate_credentials(): + model = SiliconflowLargeLanguageModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model='deepseek-ai/DeepSeek-V2-Chat', + credentials={ + 'api_key': 'invalid_key' + } + ) + + model.validate_credentials( + model='deepseek-ai/DeepSeek-V2-Chat', + credentials={ + 'api_key': os.environ.get('API_KEY') + } + ) + + +def test_invoke_model(): + model = SiliconflowLargeLanguageModel() + + response = model.invoke( + model='deepseek-ai/DeepSeek-V2-Chat', + credentials={ + 'api_key': os.environ.get('API_KEY') + }, + prompt_messages=[ + UserPromptMessage( + content='Who are you?' + ) + ], + model_parameters={ + 'temperature': 0.5, + 'max_tokens': 10 + }, + stop=['How'], + stream=False, + user="abc-123" + ) + + assert isinstance(response, LLMResult) + assert len(response.message.content) > 0 + + +def test_invoke_stream_model(): + model = SiliconflowLargeLanguageModel() + + response = model.invoke( + model='deepseek-ai/DeepSeek-V2-Chat', + credentials={ + 'api_key': os.environ.get('API_KEY') + }, + prompt_messages=[ + UserPromptMessage( + content='Hello World!' + ) + ], + model_parameters={ + 'temperature': 0.5, + 'max_tokens': 100, + 'seed': 1234 + }, + stream=True, + user="abc-123" + ) + + assert isinstance(response, Generator) + + for chunk in response: + assert isinstance(chunk, LLMResultChunk) + assert isinstance(chunk.delta, LLMResultChunkDelta) + assert isinstance(chunk.delta.message, AssistantPromptMessage) + assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + + +def test_get_num_tokens(): + model = SiliconflowLargeLanguageModel() + + num_tokens = model.get_num_tokens( + model='deepseek-ai/DeepSeek-V2-Chat', + credentials={ + 'api_key': os.environ.get('API_KEY') + }, + prompt_messages=[ + SystemPromptMessage( + content='You are a helpful AI assistant.', + ), + UserPromptMessage( + content='Hello World!' + ) + ] + ) + + assert num_tokens == 12 diff --git a/api/tests/integration_tests/model_runtime/siliconflow/test_provider.py b/api/tests/integration_tests/model_runtime/siliconflow/test_provider.py new file mode 100644 index 0000000000..7b9211a5db --- /dev/null +++ b/api/tests/integration_tests/model_runtime/siliconflow/test_provider.py @@ -0,0 +1,21 @@ +import os + +import pytest + +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.siliconflow.siliconflow import SiliconflowProvider + + +def test_validate_provider_credentials(): + provider = SiliconflowProvider() + + with pytest.raises(CredentialsValidateFailedError): + provider.validate_provider_credentials( + credentials={} + ) + + provider.validate_provider_credentials( + credentials={ + 'api_key': os.environ.get('API_KEY') + } + ) diff --git a/api/tests/unit_tests/core/app/segments/test_factory.py b/api/tests/unit_tests/core/app/segments/test_factory.py new file mode 100644 index 0000000000..85321ee374 --- /dev/null +++ b/api/tests/unit_tests/core/app/segments/test_factory.py @@ -0,0 +1,307 @@ +from uuid import uuid4 + +import pytest + +from core.app.segments import ( + ArrayFileVariable, + ArrayNumberVariable, + ArrayObjectVariable, + ArrayStringVariable, + FileVariable, + FloatVariable, + IntegerVariable, + NoneSegment, + ObjectSegment, + SecretVariable, + StringVariable, + factory, +) + + +def test_string_variable(): + test_data = {'value_type': 'string', 'name': 'test_text', 'value': 'Hello, World!'} + result = factory.build_variable_from_mapping(test_data) + assert isinstance(result, StringVariable) + + +def test_integer_variable(): + test_data = {'value_type': 'number', 'name': 'test_int', 'value': 42} + result = factory.build_variable_from_mapping(test_data) + assert isinstance(result, IntegerVariable) + + +def test_float_variable(): + test_data = {'value_type': 'number', 'name': 'test_float', 'value': 3.14} + result = factory.build_variable_from_mapping(test_data) + assert isinstance(result, FloatVariable) + + +def test_secret_variable(): + test_data = {'value_type': 'secret', 'name': 'test_secret', 'value': 'secret_value'} + result = factory.build_variable_from_mapping(test_data) + assert isinstance(result, SecretVariable) + + +def test_invalid_value_type(): + test_data = {'value_type': 'unknown', 'name': 'test_invalid', 'value': 'value'} + with pytest.raises(ValueError): + factory.build_variable_from_mapping(test_data) + + +def test_build_a_blank_string(): + result = factory.build_variable_from_mapping( + { + 'value_type': 'string', + 'name': 'blank', + 'value': '', + } + ) + assert isinstance(result, StringVariable) + assert result.value == '' + + +def test_build_a_object_variable_with_none_value(): + var = factory.build_segment( + { + 'key1': None, + } + ) + assert isinstance(var, ObjectSegment) + assert isinstance(var.value['key1'], NoneSegment) + + +def test_object_variable(): + mapping = { + 'id': str(uuid4()), + 'value_type': 'object', + 'name': 'test_object', + 'description': 'Description of the variable.', + 'value': { + 'key1': { + 'id': str(uuid4()), + 'value_type': 'string', + 'name': 'text', + 'value': 'text', + 'description': 'Description of the variable.', + }, + 'key2': { + 'id': str(uuid4()), + 'value_type': 'number', + 'name': 'number', + 'value': 1, + 'description': 'Description of the variable.', + }, + }, + } + variable = factory.build_variable_from_mapping(mapping) + assert isinstance(variable, ObjectSegment) + assert isinstance(variable.value['key1'], StringVariable) + assert isinstance(variable.value['key2'], IntegerVariable) + + +def test_array_string_variable(): + mapping = { + 'id': str(uuid4()), + 'value_type': 'array[string]', + 'name': 'test_array', + 'description': 'Description of the variable.', + 'value': [ + { + 'id': str(uuid4()), + 'value_type': 'string', + 'name': 'text', + 'value': 'text', + 'description': 'Description of the variable.', + }, + { + 'id': str(uuid4()), + 'value_type': 'string', + 'name': 'text', + 'value': 'text', + 'description': 'Description of the variable.', + }, + ], + } + variable = factory.build_variable_from_mapping(mapping) + assert isinstance(variable, ArrayStringVariable) + assert isinstance(variable.value[0], StringVariable) + assert isinstance(variable.value[1], StringVariable) + + +def test_array_number_variable(): + mapping = { + 'id': str(uuid4()), + 'value_type': 'array[number]', + 'name': 'test_array', + 'description': 'Description of the variable.', + 'value': [ + { + 'id': str(uuid4()), + 'value_type': 'number', + 'name': 'number', + 'value': 1, + 'description': 'Description of the variable.', + }, + { + 'id': str(uuid4()), + 'value_type': 'number', + 'name': 'number', + 'value': 2.0, + 'description': 'Description of the variable.', + }, + ], + } + variable = factory.build_variable_from_mapping(mapping) + assert isinstance(variable, ArrayNumberVariable) + assert isinstance(variable.value[0], IntegerVariable) + assert isinstance(variable.value[1], FloatVariable) + + +def test_array_object_variable(): + mapping = { + 'id': str(uuid4()), + 'value_type': 'array[object]', + 'name': 'test_array', + 'description': 'Description of the variable.', + 'value': [ + { + 'id': str(uuid4()), + 'value_type': 'object', + 'name': 'object', + 'description': 'Description of the variable.', + 'value': { + 'key1': { + 'id': str(uuid4()), + 'value_type': 'string', + 'name': 'text', + 'value': 'text', + 'description': 'Description of the variable.', + }, + 'key2': { + 'id': str(uuid4()), + 'value_type': 'number', + 'name': 'number', + 'value': 1, + 'description': 'Description of the variable.', + }, + }, + }, + { + 'id': str(uuid4()), + 'value_type': 'object', + 'name': 'object', + 'description': 'Description of the variable.', + 'value': { + 'key1': { + 'id': str(uuid4()), + 'value_type': 'string', + 'name': 'text', + 'value': 'text', + 'description': 'Description of the variable.', + }, + 'key2': { + 'id': str(uuid4()), + 'value_type': 'number', + 'name': 'number', + 'value': 1, + 'description': 'Description of the variable.', + }, + }, + }, + ], + } + variable = factory.build_variable_from_mapping(mapping) + assert isinstance(variable, ArrayObjectVariable) + assert isinstance(variable.value[0], ObjectSegment) + assert isinstance(variable.value[1], ObjectSegment) + assert isinstance(variable.value[0].value['key1'], StringVariable) + assert isinstance(variable.value[0].value['key2'], IntegerVariable) + assert isinstance(variable.value[1].value['key1'], StringVariable) + assert isinstance(variable.value[1].value['key2'], IntegerVariable) + + +def test_file_variable(): + mapping = { + 'id': str(uuid4()), + 'value_type': 'file', + 'name': 'test_file', + 'description': 'Description of the variable.', + 'value': { + 'id': str(uuid4()), + 'tenant_id': 'tenant_id', + 'type': 'image', + 'transfer_method': 'local_file', + 'url': 'url', + 'related_id': 'related_id', + 'extra_config': { + 'image_config': { + 'width': 100, + 'height': 100, + }, + }, + 'filename': 'filename', + 'extension': 'extension', + 'mime_type': 'mime_type', + }, + } + variable = factory.build_variable_from_mapping(mapping) + assert isinstance(variable, FileVariable) + + +def test_array_file_variable(): + mapping = { + 'id': str(uuid4()), + 'value_type': 'array[file]', + 'name': 'test_array_file', + 'description': 'Description of the variable.', + 'value': [ + { + 'id': str(uuid4()), + 'name': 'file', + 'value_type': 'file', + 'value': { + 'id': str(uuid4()), + 'tenant_id': 'tenant_id', + 'type': 'image', + 'transfer_method': 'local_file', + 'url': 'url', + 'related_id': 'related_id', + 'extra_config': { + 'image_config': { + 'width': 100, + 'height': 100, + }, + }, + 'filename': 'filename', + 'extension': 'extension', + 'mime_type': 'mime_type', + }, + }, + { + 'id': str(uuid4()), + 'name': 'file', + 'value_type': 'file', + 'value': { + 'id': str(uuid4()), + 'tenant_id': 'tenant_id', + 'type': 'image', + 'transfer_method': 'local_file', + 'url': 'url', + 'related_id': 'related_id', + 'extra_config': { + 'image_config': { + 'width': 100, + 'height': 100, + }, + }, + 'filename': 'filename', + 'extension': 'extension', + 'mime_type': 'mime_type', + }, + }, + ], + } + variable = factory.build_variable_from_mapping(mapping) + assert isinstance(variable, ArrayFileVariable) + assert isinstance(variable.value[0], FileVariable) + assert isinstance(variable.value[1], FileVariable) diff --git a/api/tests/unit_tests/app/test_segment.py b/api/tests/unit_tests/core/app/segments/test_segment.py similarity index 93% rename from api/tests/unit_tests/app/test_segment.py rename to api/tests/unit_tests/core/app/segments/test_segment.py index 7ef37ff646..414404b7d0 100644 --- a/api/tests/unit_tests/app/test_segment.py +++ b/api/tests/unit_tests/core/app/segments/test_segment.py @@ -1,4 +1,4 @@ -from core.app.segments import SecretVariable, parser +from core.app.segments import SecretVariable, StringSegment, parser from core.helper import encrypter from core.workflow.entities.node_entities import SystemVariable from core.workflow.entities.variable_pool import VariablePool @@ -51,3 +51,4 @@ def test_convert_variable_to_segment_group(): segments_group = parser.convert_template(template=template, variable_pool=variable_pool) assert segments_group.text == 'fake-user-id' assert segments_group.log == 'fake-user-id' + assert segments_group.value == [StringSegment(value='fake-user-id')] diff --git a/api/tests/unit_tests/app/test_variables.py b/api/tests/unit_tests/core/app/segments/test_variables.py similarity index 54% rename from api/tests/unit_tests/app/test_variables.py rename to api/tests/unit_tests/core/app/segments/test_variables.py index 40872c8d53..e3f513971a 100644 --- a/api/tests/unit_tests/app/test_variables.py +++ b/api/tests/unit_tests/core/app/segments/test_variables.py @@ -2,48 +2,16 @@ import pytest from pydantic import ValidationError from core.app.segments import ( - ArrayVariable, + ArrayAnyVariable, FloatVariable, IntegerVariable, - NoneVariable, ObjectVariable, SecretVariable, SegmentType, StringVariable, - factory, ) -def test_string_variable(): - test_data = {'value_type': 'string', 'name': 'test_text', 'value': 'Hello, World!'} - result = factory.build_variable_from_mapping(test_data) - assert isinstance(result, StringVariable) - - -def test_integer_variable(): - test_data = {'value_type': 'number', 'name': 'test_int', 'value': 42} - result = factory.build_variable_from_mapping(test_data) - assert isinstance(result, IntegerVariable) - - -def test_float_variable(): - test_data = {'value_type': 'number', 'name': 'test_float', 'value': 3.14} - result = factory.build_variable_from_mapping(test_data) - assert isinstance(result, FloatVariable) - - -def test_secret_variable(): - test_data = {'value_type': 'secret', 'name': 'test_secret', 'value': 'secret_value'} - result = factory.build_variable_from_mapping(test_data) - assert isinstance(result, SecretVariable) - - -def test_invalid_value_type(): - test_data = {'value_type': 'unknown', 'name': 'test_invalid', 'value': 'value'} - with pytest.raises(ValueError): - factory.build_variable_from_mapping(test_data) - - def test_frozen_variables(): var = StringVariable(name='text', value='text') with pytest.raises(ValidationError): @@ -64,34 +32,22 @@ def test_frozen_variables(): def test_variable_value_type_immutable(): with pytest.raises(ValidationError): - StringVariable(value_type=SegmentType.ARRAY, name='text', value='text') + StringVariable(value_type=SegmentType.ARRAY_ANY, name='text', value='text') with pytest.raises(ValidationError): StringVariable.model_validate({'value_type': 'not text', 'name': 'text', 'value': 'text'}) var = IntegerVariable(name='integer', value=42) with pytest.raises(ValidationError): - IntegerVariable(value_type=SegmentType.ARRAY, name=var.name, value=var.value) + IntegerVariable(value_type=SegmentType.ARRAY_ANY, name=var.name, value=var.value) var = FloatVariable(name='float', value=3.14) with pytest.raises(ValidationError): - FloatVariable(value_type=SegmentType.ARRAY, name=var.name, value=var.value) + FloatVariable(value_type=SegmentType.ARRAY_ANY, name=var.name, value=var.value) var = SecretVariable(name='secret', value='secret_value') with pytest.raises(ValidationError): - SecretVariable(value_type=SegmentType.ARRAY, name=var.name, value=var.value) - - -def test_build_a_blank_string(): - result = factory.build_variable_from_mapping( - { - 'value_type': 'string', - 'name': 'blank', - 'value': '', - } - ) - assert isinstance(result, StringVariable) - assert result.value == '' + SecretVariable(value_type=SegmentType.ARRAY_ANY, name=var.name, value=var.value) def test_object_variable_to_object(): @@ -104,7 +60,7 @@ def test_object_variable_to_object(): 'key2': StringVariable(name='key2', value='value2'), }, ), - 'key2': ArrayVariable( + 'key2': ArrayAnyVariable( name='array', value=[ StringVariable(name='key5_1', value='value5_1'), @@ -136,13 +92,3 @@ def test_variable_to_object(): assert var.to_object() == 3.14 var = SecretVariable(name='secret', value='secret_value') assert var.to_object() == 'secret_value' - - -def test_build_a_object_variable_with_none_value(): - var = factory.build_anonymous_variable( - { - 'key1': None, - } - ) - assert isinstance(var, ObjectVariable) - assert isinstance(var.value['key1'], NoneVariable) diff --git a/api/tests/unit_tests/core/helper/__init__.py b/api/tests/unit_tests/core/helper/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/helper/test_ssrf_proxy.py b/api/tests/unit_tests/core/helper/test_ssrf_proxy.py new file mode 100644 index 0000000000..d917bb1003 --- /dev/null +++ b/api/tests/unit_tests/core/helper/test_ssrf_proxy.py @@ -0,0 +1,52 @@ +import random +from unittest.mock import MagicMock, patch + +from core.helper.ssrf_proxy import SSRF_DEFAULT_MAX_RETRIES, STATUS_FORCELIST, make_request + + +@patch('httpx.request') +def test_successful_request(mock_request): + mock_response = MagicMock() + mock_response.status_code = 200 + mock_request.return_value = mock_response + + response = make_request('GET', 'http://example.com') + assert response.status_code == 200 + + +@patch('httpx.request') +def test_retry_exceed_max_retries(mock_request): + mock_response = MagicMock() + mock_response.status_code = 500 + + side_effects = [mock_response] * SSRF_DEFAULT_MAX_RETRIES + mock_request.side_effect = side_effects + + try: + make_request('GET', 'http://example.com', max_retries=SSRF_DEFAULT_MAX_RETRIES - 1) + raise AssertionError("Expected Exception not raised") + except Exception as e: + assert str(e) == f"Reached maximum retries ({SSRF_DEFAULT_MAX_RETRIES - 1}) for URL http://example.com" + + +@patch('httpx.request') +def test_retry_logic_success(mock_request): + side_effects = [] + + for _ in range(SSRF_DEFAULT_MAX_RETRIES): + status_code = random.choice(STATUS_FORCELIST) + mock_response = MagicMock() + mock_response.status_code = status_code + side_effects.append(mock_response) + + mock_response_200 = MagicMock() + mock_response_200.status_code = 200 + side_effects.append(mock_response_200) + + mock_request.side_effect = side_effects + + response = make_request('GET', 'http://example.com', max_retries=SSRF_DEFAULT_MAX_RETRIES) + + assert response.status_code == 200 + assert mock_request.call_count == SSRF_DEFAULT_MAX_RETRIES + 1 + assert mock_request.call_args_list[0][1].get('method') == 'GET' diff --git a/api/tests/unit_tests/utils/position_helper/test_position_helper.py b/api/tests/unit_tests/utils/position_helper/test_position_helper.py index c389461454..2237319904 100644 --- a/api/tests/unit_tests/utils/position_helper/test_position_helper.py +++ b/api/tests/unit_tests/utils/position_helper/test_position_helper.py @@ -21,6 +21,20 @@ def prepare_example_positions_yaml(tmp_path, monkeypatch) -> str: return str(tmp_path) +@pytest.fixture +def prepare_empty_commented_positions_yaml(tmp_path, monkeypatch) -> str: + monkeypatch.chdir(tmp_path) + tmp_path.joinpath("example_positions_all_commented.yaml").write_text(dedent( + """\ + # - commented1 + # - commented2 + - + - + + """)) + return str(tmp_path) + + def test_position_helper(prepare_example_positions_yaml): position_map = get_position_map( folder_path=prepare_example_positions_yaml, @@ -32,3 +46,10 @@ def test_position_helper(prepare_example_positions_yaml): 'third': 2, 'forth': 3, } + + +def test_position_helper_with_all_commented(prepare_empty_commented_positions_yaml): + position_map = get_position_map( + folder_path=prepare_empty_commented_positions_yaml, + file_name='example_positions_all_commented.yaml') + assert position_map == {} diff --git a/api/tests/unit_tests/utils/yaml/test_yaml_utils.py b/api/tests/unit_tests/utils/yaml/test_yaml_utils.py index 446588cde1..c0452b4e4d 100644 --- a/api/tests/unit_tests/utils/yaml/test_yaml_utils.py +++ b/api/tests/unit_tests/utils/yaml/test_yaml_utils.py @@ -53,6 +53,9 @@ def test_load_yaml_non_existing_file(): assert load_yaml_file(file_path=NON_EXISTING_YAML_FILE) == {} assert load_yaml_file(file_path='') == {} + with pytest.raises(FileNotFoundError): + load_yaml_file(file_path=NON_EXISTING_YAML_FILE, ignore_error=False) + def test_load_valid_yaml_file(prepare_example_yaml_file): yaml_data = load_yaml_file(file_path=prepare_example_yaml_file) @@ -68,7 +71,7 @@ def test_load_valid_yaml_file(prepare_example_yaml_file): def test_load_invalid_yaml_file(prepare_invalid_yaml_file): # yaml syntax error with pytest.raises(YAMLError): - load_yaml_file(file_path=prepare_invalid_yaml_file) + load_yaml_file(file_path=prepare_invalid_yaml_file, ignore_error=False) # ignore error - assert load_yaml_file(file_path=prepare_invalid_yaml_file, ignore_error=True) == {} + assert load_yaml_file(file_path=prepare_invalid_yaml_file) == {} diff --git a/docker-legacy/docker-compose.milvus.yaml b/docker-legacy/docker-compose.milvus.yaml index c422efbf4b..f4a7afa3a1 100644 --- a/docker-legacy/docker-compose.milvus.yaml +++ b/docker-legacy/docker-compose.milvus.yaml @@ -38,7 +38,7 @@ services: milvus-standalone: container_name: milvus-standalone - image: milvusdb/milvus:v2.3.1 + image: milvusdb/milvus:v2.4.6 command: ["milvus", "run", "standalone"] environment: ETCD_ENDPOINTS: etcd:2379 diff --git a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout.tsx b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout.tsx index a1543230a9..11893ec9de 100644 --- a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout.tsx +++ b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout.tsx @@ -166,8 +166,8 @@ const ExtraInfo = ({ isMobile, relatedApps }: IExtraInfoProps) => { className='inline-flex items-center text-xs text-primary-600 mt-2 cursor-pointer' href={ locale === LanguagesSupported[1] - ? 'https://docs.dify.ai/v/zh-hans/guides/application-design/prompt-engineering' - : 'https://docs.dify.ai/user-guide/creating-dify-apps/prompt-engineering' + ? 'https://docs.dify.ai/v/zh-hans/guides/knowledge-base/integrate_knowledge_within_application' + : 'https://docs.dify.ai/guides/knowledge-base/integrate-knowledge-within-application' } target='_blank' rel='noopener noreferrer' > diff --git a/web/app/components/app/configuration/config-var/index.tsx b/web/app/components/app/configuration/config-var/index.tsx index a4f8b6839f..9ef624ed97 100644 --- a/web/app/components/app/configuration/config-var/index.tsx +++ b/web/app/components/app/configuration/config-var/index.tsx @@ -96,7 +96,7 @@ const ConfigVar: FC = ({ promptVariables, readonly, onPromptVar ...rest, type: type === InputVarType.textInput ? 'string' : type, key: variable, - name: label, + name: label as string, } if (payload.type === InputVarType.textInput) diff --git a/web/app/components/app/configuration/config-voice/param-config-content.tsx b/web/app/components/app/configuration/config-voice/param-config-content.tsx index cced3b0458..9b0d5bbb69 100644 --- a/web/app/components/app/configuration/config-voice/param-config-content.tsx +++ b/web/app/components/app/configuration/config-voice/param-config-content.tsx @@ -31,12 +31,12 @@ const VoiceParamConfig: FC = () => { let languageItem = languages.find(item => item.value === textToSpeechConfig.language) const localLanguagePlaceholder = languageItem?.name || t('common.placeholder.select') - if (languages && !languageItem) + if (languages && !languageItem && languages.length > 0) languageItem = languages[0] const language = languageItem?.value const voiceItems = useSWR({ appId, language }, fetchAppVoices).data let voiceItem = voiceItems?.find(item => item.value === textToSpeechConfig.voice) - if (voiceItems && !voiceItem) + if (voiceItems && !voiceItem && voiceItems.length > 0) voiceItem = voiceItems[0] const localVoicePlaceholder = voiceItem?.name || t('common.placeholder.select') @@ -125,9 +125,11 @@ const VoiceParamConfig: FC = () => {
{t('appDebug.voice.voiceSettings.voice')}
{ + if (!value.value) + return setTextToSpeechConfig({ ...textToSpeechConfig, voice: String(value.value), diff --git a/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx b/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx index d71d26bbed..683617bf25 100644 --- a/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx +++ b/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx @@ -32,6 +32,7 @@ import { RerankingModeEnum } from '@/models/datasets' import cn from '@/utils/classnames' import { useSelectedDatasetsMode } from '@/app/components/workflow/nodes/knowledge-retrieval/hooks' import Switch from '@/app/components/base/switch' +import { useGetLanguage } from '@/context/i18n' type Props = { datasetConfigs: DatasetConfigs @@ -43,6 +44,11 @@ type Props = { selectedDatasets?: DataSet[] } +const LEGACY_LINK_MAP = { + en_US: 'https://docs.dify.ai/guides/knowledge-base/integrate-knowledge-within-application', + zh_Hans: 'https://docs.dify.ai/v/zh-hans/guides/knowledge-base/integrate_knowledge_within_application', +} as Record + const ConfigContent: FC = ({ datasetConfigs, onChange, @@ -53,6 +59,7 @@ const ConfigContent: FC = ({ selectedDatasets = [], }) => { const { t } = useTranslation() + const language = useGetLanguage() const selectedDatasetsMode = useSelectedDatasetsMode(selectedDatasets) const type = datasetConfigs.retrieval_model const setType = (value: RETRIEVE_TYPE) => { @@ -167,7 +174,21 @@ const ConfigContent: FC = ({ title={(
{t('appDebug.datasetConfig.retrieveOneWay.title')} - {t('dataset.nTo1RetrievalLegacy')}
}> + + {t('dataset.nTo1RetrievalLegacy')} + + ({t('dataset.nTo1RetrievalLegacyLink')}) + + + )} + >
legacy
diff --git a/web/app/components/app/configuration/features/chat-group/text-to-speech/index.tsx b/web/app/components/app/configuration/features/chat-group/text-to-speech/index.tsx index 4c5db22513..72d617c3c3 100644 --- a/web/app/components/app/configuration/features/chat-group/text-to-speech/index.tsx +++ b/web/app/components/app/configuration/features/chat-group/text-to-speech/index.tsx @@ -41,6 +41,7 @@ const TextToSpeech: FC = () => { )} diff --git a/web/app/components/app/configuration/prompt-mode/advanced-mode-waring.tsx b/web/app/components/app/configuration/prompt-mode/advanced-mode-waring.tsx index 759a15213d..6fb58ba9a1 100644 --- a/web/app/components/app/configuration/prompt-mode/advanced-mode-waring.tsx +++ b/web/app/components/app/configuration/prompt-mode/advanced-mode-waring.tsx @@ -4,7 +4,6 @@ import React from 'react' import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' import I18n from '@/context/i18n' -import { FlipBackward } from '@/app/components/base/icons/src/vender/line/arrows' import { LanguagesSupported } from '@/i18n/language' type Props = { onReturnToSimpleMode: () => void @@ -38,7 +37,6 @@ const AdvancedModeWarning: FC = ({ onClick={onReturnToSimpleMode} className='shrink-0 flex items-center h-6 px-2 bg-indigo-600 shadow-xs border border-gray-200 rounded-lg text-white text-xs font-semibold cursor-pointer space-x-1' > -
{t('appDebug.promptMode.switchBack')}
= ({ const [childFeedback, setChildFeedback] = useState({ rating: null, }) + const { + config, + } = useChatContext() + const setCurrentLogItem = useAppStore(s => s.setCurrentLogItem) const setShowPromptLogModal = useAppStore(s => s.setShowPromptLogModal) @@ -430,6 +435,7 @@ const GenerationItem: FC = ({ )} diff --git a/web/app/components/base/action-button/index.css b/web/app/components/base/action-button/index.css index b87dbc8d4a..96fbb14c6c 100644 --- a/web/app/components/base/action-button/index.css +++ b/web/app/components/base/action-button/index.css @@ -16,7 +16,7 @@ } .action-btn-l { - @apply p-1.5 w-[34px] h-[34px] rounded-lg + @apply p-1.5 w-8 h-8 rounded-lg } /* m is for the regular button */ @@ -25,7 +25,7 @@ } .action-btn-xs { - @apply p-0 w-5 h-5 rounded + @apply p-0 w-4 h-4 rounded } .action-btn.action-btn-active { diff --git a/web/app/components/base/audio-btn/audio.player.manager.ts b/web/app/components/base/audio-btn/audio.player.manager.ts index 03e9e21f93..17d92f8dc2 100644 --- a/web/app/components/base/audio-btn/audio.player.manager.ts +++ b/web/app/components/base/audio-btn/audio.player.manager.ts @@ -41,7 +41,7 @@ export class AudioPlayerManager { } this.msgId = id - this.audioPlayers = new AudioPlayer(url, isPublic, id, msgContent, callback) + this.audioPlayers = new AudioPlayer(url, isPublic, id, msgContent, voice, callback) return this.audioPlayers } } diff --git a/web/app/components/base/audio-btn/audio.ts b/web/app/components/base/audio-btn/audio.ts index 638626bf8a..a61fd085d4 100644 --- a/web/app/components/base/audio-btn/audio.ts +++ b/web/app/components/base/audio-btn/audio.ts @@ -23,12 +23,13 @@ export default class AudioPlayer { isPublic: boolean callback: ((event: string) => {}) | null - constructor(streamUrl: string, isPublic: boolean, msgId: string | undefined, msgContent: string | null | undefined, callback: ((event: string) => {}) | null) { + constructor(streamUrl: string, isPublic: boolean, msgId: string | undefined, msgContent: string | null | undefined, voice: string | undefined, callback: ((event: string) => {}) | null) { this.audioContext = new AudioContext() this.msgId = msgId this.msgContent = msgContent this.url = streamUrl this.isPublic = isPublic + this.voice = voice this.callback = callback // Compatible with iphone ios17 ManagedMediaSource @@ -154,7 +155,6 @@ export default class AudioPlayer { this.mediaSource?.endOfStream() clearInterval(endTimer) } - console.log('finishStream endOfStream endTimer') }, 10) } @@ -169,7 +169,6 @@ export default class AudioPlayer { const arrayBuffer = this.cacheBuffers.shift()! this.sourceBuffer?.appendBuffer(arrayBuffer) } - console.log('finishStream timer') }, 10) } diff --git a/web/app/components/base/audio-btn/index.tsx b/web/app/components/base/audio-btn/index.tsx index 48081c170c..675f58b530 100644 --- a/web/app/components/base/audio-btn/index.tsx +++ b/web/app/components/base/audio-btn/index.tsx @@ -65,11 +65,11 @@ const AudioBtn = ({ } const handleToggle = async () => { if (audioState === 'playing' || audioState === 'loading') { - setAudioState('paused') + setTimeout(() => setAudioState('paused'), 1) AudioPlayerManager.getInstance().getAudioPlayer(url, isPublic, id, value, voice, audio_finished_call).pauseAudio() } else { - setAudioState('loading') + setTimeout(() => setAudioState('loading'), 1) AudioPlayerManager.getInstance().getAudioPlayer(url, isPublic, id, value, voice, audio_finished_call).playAudio() } } diff --git a/web/app/components/base/chat/chat/answer/operation.tsx b/web/app/components/base/chat/chat/answer/operation.tsx index d46aa34375..8ec5c0f3b2 100644 --- a/web/app/components/base/chat/chat/answer/operation.tsx +++ b/web/app/components/base/chat/chat/answer/operation.tsx @@ -53,10 +53,11 @@ const Operation: FC = ({ content: messageContent, annotation, feedback, + adminFeedback, agent_thoughts, } = item const hasAnnotation = !!annotation?.id - const [localFeedback, setLocalFeedback] = useState(feedback) + const [localFeedback, setLocalFeedback] = useState(config?.supportAnnotation ? adminFeedback : feedback) const content = useMemo(() => { if (agent_thoughts?.length) @@ -125,6 +126,7 @@ const Operation: FC = ({ id={id} value={content} noCache={false} + voice={config?.text_to_speech?.voice} className='hidden group-hover:block' /> diff --git a/web/app/components/base/features/feature-panel/text-to-speech/param-config-content.tsx b/web/app/components/base/features/feature-panel/text-to-speech/param-config-content.tsx index ea1d789d0a..a5a2eb7bb7 100644 --- a/web/app/components/base/features/feature-panel/text-to-speech/param-config-content.tsx +++ b/web/app/components/base/features/feature-panel/text-to-speech/param-config-content.tsx @@ -149,7 +149,7 @@ const VoiceParamConfig = ({
{t('appDebug.voice.voiceSettings.voice')}
{ handleChange({ diff --git a/web/app/components/base/notion-page-selector/notion-page-selector-modal/index.tsx b/web/app/components/base/notion-page-selector/notion-page-selector-modal/index.tsx index b120ef94b2..e7fba57056 100644 --- a/web/app/components/base/notion-page-selector/notion-page-selector-modal/index.tsx +++ b/web/app/components/base/notion-page-selector/notion-page-selector-modal/index.tsx @@ -2,15 +2,15 @@ import { useState } from 'react' import { useTranslation } from 'react-i18next' import { XMarkIcon } from '@heroicons/react/24/outline' import NotionPageSelector from '../base' -import type { NotionPageSelectorValue } from '../base' import s from './index.module.css' +import type { NotionPage } from '@/models/common' import cn from '@/utils/classnames' import Modal from '@/app/components/base/modal' type NotionPageSelectorModalProps = { isShow: boolean onClose: () => void - onSave: (selectedPages: NotionPageSelectorValue[]) => void + onSave: (selectedPages: NotionPage[]) => void datasetId: string } const NotionPageSelectorModal = ({ @@ -20,12 +20,12 @@ const NotionPageSelectorModal = ({ datasetId, }: NotionPageSelectorModalProps) => { const { t } = useTranslation() - const [selectedPages, setSelectedPages] = useState([]) + const [selectedPages, setSelectedPages] = useState([]) const handleClose = () => { onClose() } - const handleSelectPage = (newSelectedPages: NotionPageSelectorValue[]) => { + const handleSelectPage = (newSelectedPages: NotionPage[]) => { setSelectedPages(newSelectedPages) } const handleSave = () => { diff --git a/web/app/components/base/select/index.tsx b/web/app/components/base/select/index.tsx index 24da8855fa..dee983690b 100644 --- a/web/app/components/base/select/index.tsx +++ b/web/app/components/base/select/index.tsx @@ -191,7 +191,7 @@ const SimpleSelect: FC = ({ onClick={(e) => { e.stopPropagation() setSelectedItem(null) - onSelect({ value: null }) + onSelect({ name: '', value: '' }) }} className="h-5 w-5 text-gray-400 cursor-pointer" aria-hidden="false" diff --git a/web/app/components/explore/category.tsx b/web/app/components/explore/category.tsx index 2b6cfbd9be..cf655c5333 100644 --- a/web/app/components/explore/category.tsx +++ b/web/app/components/explore/category.tsx @@ -28,7 +28,7 @@ const Category: FC = ({ allCategoriesEn, }) => { const { t } = useTranslation() - const isAllCategories = !list.includes(value) + const isAllCategories = !list.includes(value as AppCategory) const itemClassName = (isSelected: boolean) => cn( 'flex items-center px-3 py-[7px] h-[32px] rounded-lg border-[0.5px] border-transparent text-gray-700 font-medium leading-[18px] cursor-pointer hover:bg-gray-200', diff --git a/web/app/components/header/account-setting/model-provider-page/declarations.ts b/web/app/components/header/account-setting/model-provider-page/declarations.ts index abc81262b9..1547032163 100644 --- a/web/app/components/header/account-setting/model-provider-page/declarations.ts +++ b/web/app/components/header/account-setting/model-provider-page/declarations.ts @@ -12,6 +12,7 @@ export enum FormTypeEnum { secretInput = 'secret-input', select = 'select', radio = 'radio', + boolean = 'boolean', files = 'files', } diff --git a/web/app/components/header/dataset-nav/index.tsx b/web/app/components/header/dataset-nav/index.tsx index f415658eee..abf76608a8 100644 --- a/web/app/components/header/dataset-nav/index.tsx +++ b/web/app/components/header/dataset-nav/index.tsx @@ -11,6 +11,7 @@ import useSWR from 'swr' import useSWRInfinite from 'swr/infinite' import { flatten } from 'lodash-es' import Nav from '../nav' +import type { NavItem } from '../nav/nav-selector' import { fetchDatasetDetail, fetchDatasets } from '@/service/datasets' import type { DataSetListResponse } from '@/models/datasets' @@ -31,7 +32,7 @@ const DatasetNav = () => { datasetId, } : null, - apiParams => fetchDatasetDetail(apiParams.datasetId)) + apiParams => fetchDatasetDetail(apiParams.datasetId as string)) const { data: datasetsData, setSize } = useSWRInfinite(datasetId ? getKey : () => null, fetchDatasets, { revalidateFirstPage: false, revalidateAll: true }) const datasetItems = flatten(datasetsData?.map(datasetData => datasetData.data)) @@ -46,14 +47,14 @@ const DatasetNav = () => { text={t('common.menus.datasets')} activeSegment='datasets' link='/datasets' - curNav={currentDataset} + curNav={currentDataset as Omit} navs={datasetItems.map(dataset => ({ id: dataset.id, name: dataset.name, link: `/datasets/${dataset.id}/documents`, icon: dataset.icon, icon_background: dataset.icon_background, - }))} + })) as NavItem[]} createText={t('common.menus.newDataset')} onCreate={() => router.push('/datasets/create')} onLoadmore={handleLoadmore} diff --git a/web/app/components/header/nav/nav-selector/index.tsx b/web/app/components/header/nav/nav-selector/index.tsx index 51192c6580..26f538d72d 100644 --- a/web/app/components/header/nav/nav-selector/index.tsx +++ b/web/app/components/header/nav/nav-selector/index.tsx @@ -23,13 +23,13 @@ export type NavItem = { link: string icon: string icon_background: string - mode: string + mode?: string } export type INavSelectorProps = { navs: NavItem[] curNav?: Omit createText: string - isApp: boolean + isApp?: boolean onCreate: (state: string) => void onLoadmore?: () => void } diff --git a/web/app/components/tools/provider/card.tsx b/web/app/components/tools/provider/card.tsx index 7f87d65e3a..6a688186cf 100644 --- a/web/app/components/tools/provider/card.tsx +++ b/web/app/components/tools/provider/card.tsx @@ -36,7 +36,7 @@ const ProviderCard = ({ }, [collection.labels, labelList, language]) return ( -
+
{typeof collection.icon === 'string' && ( diff --git a/web/app/components/tools/provider/detail.tsx b/web/app/components/tools/provider/detail.tsx index ee02e4966d..546b9cd9a1 100644 --- a/web/app/components/tools/provider/detail.tsx +++ b/web/app/components/tools/provider/detail.tsx @@ -85,7 +85,7 @@ const ProviderDetail = ({ const [customCollection, setCustomCollection] = useState(null) const [isShowEditCollectionToolModal, setIsShowEditCustomCollectionModal] = useState(false) const [showConfirmDelete, setShowConfirmDelete] = useState(false) - const [deleteAction, setDeleteAction] = useState(null) + const [deleteAction, setDeleteAction] = useState('') const doUpdateCustomToolCollection = async (data: CustomCollectionBackend) => { await updateCustomCollection(data) onRefreshData() diff --git a/web/app/components/tools/workflow-tool/index.tsx b/web/app/components/tools/workflow-tool/index.tsx index 436b2c55ab..0f9fe4c4c1 100644 --- a/web/app/components/tools/workflow-tool/index.tsx +++ b/web/app/components/tools/workflow-tool/index.tsx @@ -173,7 +173,7 @@ const WorkflowToolAsModal: FC = ({
{t('tools.createTool.description')}