diff --git a/.devcontainer/README.md b/.devcontainer/README.md
index fa989584f5..df12a3c2d6 100644
--- a/.devcontainer/README.md
+++ b/.devcontainer/README.md
@@ -1,4 +1,4 @@
-# Devlopment with devcontainer
+# Development with devcontainer
This project includes a devcontainer configuration that allows you to open the project in a container with a fully configured development environment.
Both frontend and backend environments are initialized when the container is started.
## GitHub Codespaces
@@ -33,5 +33,5 @@ Performance Impact: While usually minimal, programs running inside a devcontaine
if you see such error message when you open this project in codespaces:

-a simple workaround is change `/signin` endpoint into another one, then login with github account and close the tab, then change it back to `/signin` endpoint. Then all things will be fine.
+a simple workaround is to change the `/signin` endpoint to another one, sign in with your GitHub account, and close the tab; then change it back to the `/signin` endpoint and everything will work.
The reason is `signin` endpoint is not allowed in codespaces, details can be found [here](https://github.com/orgs/community/discussions/5204)
\ No newline at end of file
diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml
index fea45de1d3..b596bdb6b0 100644
--- a/.github/ISSUE_TEMPLATE/bug_report.yml
+++ b/.github/ISSUE_TEMPLATE/bug_report.yml
@@ -8,13 +8,13 @@ body:
      label: Self Checks
      description: "To make sure we get to you in time, please check the following :)"
      options:
-        - label: This is only for bug report, if you would like to ask a quesion, please head to [Discussions](https://github.com/langgenius/dify/discussions/categories/general).
+        - label: This is only for bug report, if you would like to ask a question, please head to [Discussions](https://github.com/langgenius/dify/discussions/categories/general).
          required: true
        - label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
          required: true
        - label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
          required: true
-        - label: "Pleas do not modify this template :) and fill in all the required fields."
+        - label: "Please do not modify this template :) and fill in all the required fields."
          required: true
  - type: input
diff --git a/.github/ISSUE_TEMPLATE/document_issue.yml b/.github/ISSUE_TEMPLATE/document_issue.yml
index 44115b2097..c5aeb7fd73 100644
--- a/.github/ISSUE_TEMPLATE/document_issue.yml
+++ b/.github/ISSUE_TEMPLATE/document_issue.yml
@@ -1,7 +1,7 @@
name: "📚 Documentation Issue"
description: Report issues in our documentation
labels:
- - ducumentation
+ - documentation
body:
- type: checkboxes
attributes:
@@ -12,7 +12,7 @@ body:
          required: true
        - label: I confirm that I am using English to submit report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
          required: true
-        - label: "Pleas do not modify this template :) and fill in all the required fields."
+        - label: "Please do not modify this template :) and fill in all the required fields."
          required: true
  - type: textarea
    attributes:
diff --git a/.github/ISSUE_TEMPLATE/feature_request.yml b/.github/ISSUE_TEMPLATE/feature_request.yml
index 694bd3975d..8730f5c11f 100644
--- a/.github/ISSUE_TEMPLATE/feature_request.yml
+++ b/.github/ISSUE_TEMPLATE/feature_request.yml
@@ -12,7 +12,7 @@ body:
          required: true
        - label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
          required: true
-        - label: "Pleas do not modify this template :) and fill in all the required fields."
+        - label: "Please do not modify this template :) and fill in all the required fields."
          required: true
  - type: textarea
    attributes:
diff --git a/.github/ISSUE_TEMPLATE/translation_issue.yml b/.github/ISSUE_TEMPLATE/translation_issue.yml
index 589e071e14..898e2cdf58 100644
--- a/.github/ISSUE_TEMPLATE/translation_issue.yml
+++ b/.github/ISSUE_TEMPLATE/translation_issue.yml
@@ -12,7 +12,7 @@ body:
          required: true
        - label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
          required: true
-        - label: "Pleas do not modify this template :) and fill in all the required fields."
+        - label: "Please do not modify this template :) and fill in all the required fields."
          required: true
  - type: input
    attributes:
diff --git a/.github/workflows/api-tests.yml b/.github/workflows/api-tests.yml
index a0407de843..82cfdcea06 100644
--- a/.github/workflows/api-tests.yml
+++ b/.github/workflows/api-tests.yml
@@ -4,9 +4,17 @@ on:
  pull_request:
    branches:
      - main
+    paths:
+      - api/**
+      - docker/**
+
+concurrency:
+  group: api-tests-${{ github.head_ref || github.run_id }}
+  cancel-in-progress: true

jobs:
  test:
+    name: API Tests
    runs-on: ubuntu-latest
    strategy:
      matrix:
@@ -46,11 +54,12 @@ jobs:
            docker/docker-compose.middleware.yaml
          services: |
            sandbox
+            ssrf_proxy
      - name: Run Workflow
        run: dev/pytest/pytest_workflow.sh

-      - name: Set up Vector Stores (Weaviate, Qdrant, Milvus, PgVecto-RS)
+      - name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS)
        uses: hoverkraft-tech/compose-action@v2.0.0
        with:
          compose-file: |
@@ -58,6 +67,7 @@ jobs:
            docker/docker-compose.qdrant.yaml
            docker/docker-compose.milvus.yaml
            docker/docker-compose.pgvecto-rs.yaml
+            docker/docker-compose.pgvector.yaml
          services: |
            weaviate
            qdrant
@@ -65,6 +75,7 @@ jobs:
            minio
            milvus-standalone
            pgvecto-rs
+            pgvector
      - name: Test Vector Stores
        run: dev/pytest/pytest_vdb.sh
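
The vector-store job boots each database with Docker Compose and then runs `dev/pytest/pytest_vdb.sh` against them. A minimal sketch of the kind of round-trip check such a suite performs, shown against the Qdrant container this workflow starts; the collection name, vector size, and assertions are illustrative, not Dify's actual tests:

```python
# Smoke test against the qdrant service from docker-compose.qdrant.yaml.
# Collection name and vector values are illustrative only.
from qdrant_client import QdrantClient
from qdrant_client.http import models

client = QdrantClient(url="http://localhost:6333")
client.recreate_collection(
    collection_name="smoke_test",
    vectors_config=models.VectorParams(size=4, distance=models.Distance.COSINE),
)
client.upsert(
    collection_name="smoke_test",
    points=[models.PointStruct(id=1, vector=[0.1, 0.2, 0.3, 0.4], payload={"doc_id": "a"})],
)
hits = client.search(collection_name="smoke_test", query_vector=[0.1, 0.2, 0.3, 0.4], limit=1)
assert hits and hits[0].id == 1  # the stored point should be its own nearest neighbour
```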
diff --git a/.github/workflows/db-migration-test.yml b/.github/workflows/db-migration-test.yml
new file mode 100644
index 0000000000..cb8dd06c5e
--- /dev/null
+++ b/.github/workflows/db-migration-test.yml
@@ -0,0 +1,53 @@
+name: DB Migration Test
+
+on:
+  pull_request:
+    branches:
+      - main
+    paths:
+      - api/migrations/**
+
+concurrency:
+  group: db-migration-test-${{ github.ref }}
+  cancel-in-progress: true
+
+jobs:
+  db-migration-test:
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python-version:
+          - "3.10"
+
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v4
+
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v5
+        with:
+          python-version: ${{ matrix.python-version }}
+          cache: 'pip'
+          cache-dependency-path: |
+            ./api/requirements.txt
+
+      - name: Install dependencies
+        run: pip install -r ./api/requirements.txt
+
+      - name: Set up Middleware
+        uses: hoverkraft-tech/compose-action@v2.0.0
+        with:
+          compose-file: |
+            docker/docker-compose.middleware.yaml
+          services: |
+            db
+
+      - name: Prepare configs
+        run: |
+          cd api
+          cp .env.example .env
+
+      - name: Run DB Migration
+        run: |
+          cd api
+          flask db upgrade
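
The job is a thin wrapper around Flask-Migrate: it starts the `db` service from `docker-compose.middleware.yaml`, copies the example env, and runs `flask db upgrade`. When debugging a failing migration locally, the same upgrade can be driven programmatically; a minimal sketch, assuming the `create_app()` factory that Dify's `api/app.py` exposes:

```python
# Apply all pending Alembic migrations, equivalent to `flask db upgrade`.
# Requires the middleware Postgres container to be up and api/.env in place.
from flask_migrate import upgrade

from app import create_app  # Dify's Flask app factory in api/app.py

app = create_app()
with app.app_context():
    upgrade()  # runs migrations in api/migrations up to the latest revision
```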
diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml
index bdbc22b489..7dad707fea 100644
--- a/.github/workflows/style.yml
+++ b/.github/workflows/style.yml
@@ -6,7 +6,7 @@ on:
      - main

concurrency:
-  group: dep-${{ github.head_ref || github.run_id }}
+  group: style-${{ github.head_ref || github.run_id }}
  cancel-in-progress: true
jobs:
@@ -18,54 +18,89 @@ jobs:
      - name: Checkout code
        uses: actions/checkout@v4

+      - name: Check changed files
+        id: changed-files
+        uses: tj-actions/changed-files@v44
+        with:
+          files: api/**
+
      - name: Set up Python
        uses: actions/setup-python@v5
+        if: steps.changed-files.outputs.any_changed == 'true'
        with:
          python-version: '3.10'

      - name: Python dependencies
+        if: steps.changed-files.outputs.any_changed == 'true'
        run: pip install ruff dotenv-linter

      - name: Ruff check
-        run: ruff check ./api
+        if: steps.changed-files.outputs.any_changed == 'true'
+        run: ruff check --preview ./api

      - name: Dotenv check
+        if: steps.changed-files.outputs.any_changed == 'true'
        run: dotenv-linter ./api/.env.example ./web/.env.example

      - name: Lint hints
        if: failure()
        run: echo "Please run 'dev/reformat' to fix the fixable linting errors."

-  test:
-    name: ESLint and SuperLinter
+  web-style:
+    name: Web Style
    runs-on: ubuntu-latest
-    needs: python-style
+    defaults:
+      run:
+        working-directory: ./web

    steps:
      - name: Checkout code
        uses: actions/checkout@v4
+
+      - name: Check changed files
+        id: changed-files
+        uses: tj-actions/changed-files@v44
        with:
-          fetch-depth: 0
+          files: web/**

      - name: Setup NodeJS
        uses: actions/setup-node@v4
+        if: steps.changed-files.outputs.any_changed == 'true'
        with:
          node-version: 20
          cache: yarn
          cache-dependency-path: ./web/package.json

      - name: Web dependencies
-        run: |
-          cd ./web
-          yarn install --frozen-lockfile
+        if: steps.changed-files.outputs.any_changed == 'true'
+        run: yarn install --frozen-lockfile

      - name: Web style check
-        run: |
-          cd ./web
-          yarn run lint
+        if: steps.changed-files.outputs.any_changed == 'true'
+        run: yarn run lint
+
+
+  superlinter:
+    name: SuperLinter
+    runs-on: ubuntu-latest
+
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v4
+
+      - name: Check changed files
+        id: changed-files
+        uses: tj-actions/changed-files@v44
+        with:
+          files: |
+            **.sh
+            **.yaml
+            **.yml
+            Dockerfile

      - name: Super-linter
        uses: super-linter/super-linter/slim@v6
+        if: steps.changed-files.outputs.any_changed == 'true'
        env:
          BASH_SEVERITY: warning
          DEFAULT_BRANCH: main
@@ -76,4 +111,5 @@ jobs:
          VALIDATE_BASH_EXEC: true
          VALIDATE_GITHUB_ACTIONS: true
          VALIDATE_DOCKERFILE_HADOLINT: true
+          VALIDATE_XML: true
          VALIDATE_YAML: true
diff --git a/.github/workflows/tool-test-sdks.yaml b/.github/workflows/tool-test-sdks.yaml
index 575ead4b3b..fb4bcb9d66 100644
--- a/.github/workflows/tool-test-sdks.yaml
+++ b/.github/workflows/tool-test-sdks.yaml
@@ -4,6 +4,13 @@ on:
  pull_request:
    branches:
      - main
+    paths:
+      - sdks/**
+
+concurrency:
+  group: sdk-tests-${{ github.head_ref || github.run_id }}
+  cancel-in-progress: true
+
jobs:
build:
name: unit test for Node.js SDK
diff --git a/.gitignore b/.gitignore
index c957d63174..a51465efbc 100644
--- a/.gitignore
+++ b/.gitignore
@@ -134,7 +134,8 @@ dmypy.json
web/.vscode/settings.json
# Intellij IDEA Files
-.idea/
+.idea/*
+!.idea/vcs.xml
.ideaDataSources/
api/.env
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000000..ae8b1755c5
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,16 @@
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/CONTRIBUTING_JA.md b/CONTRIBUTING_JA.md
new file mode 100644
index 0000000000..c9329d6102
--- /dev/null
+++ b/CONTRIBUTING_JA.md
@@ -0,0 +1,160 @@
+Dify にコントリビュートしたいとお考えなのですね。それは素晴らしいことです。
+私たちは、LLM アプリケーションの構築と管理のための最も直感的なワークフローを設計するという壮大な野望を持っています。人数も資金も限られている新興企業として、コミュニティからの支援は本当に重要です。
+
+私たちは現状を鑑み、機敏かつ迅速に開発をする必要がありますが、同時にあなたのようなコントリビューターの方々に、可能な限りスムーズな貢献体験をしていただきたいと思っています。そのためにこのコントリビュートガイドを作成しました。
+コードベースやコントリビュータの方々と私たちがどのように仕事をしているのかに慣れていただき、楽しいパートにすぐに飛び込めるようにすることが目的です。
+
+このガイドは Dify そのものと同様に、継続的に改善されています。実際のプロジェクトに遅れをとることがあるかもしれませんが、ご理解をお願いします。
+
+ライセンスに関しては、私たちの短い[ライセンスおよびコントリビューター規約](./LICENSE)をお読みください。また、コミュニティは[行動規範](https://github.com/langgenius/.github/blob/main/CODE_OF_CONDUCT.md)を遵守しています。
+
+## 飛び込む前に
+
+[既存の Issue](https://github.com/langgenius/dify/issues?q=is:issue+is:closed) を探すか、[新しい Issue](https://github.com/langgenius/dify/issues/new/choose) を作成してください。私たちは Issue を 2 つのタイプに分類しています。
+
+### 機能リクエスト
+
+* 新しい機能要望を出す場合は、提案する機能が何を実現するものなのかを説明し、可能な限り多くの文脈を含めてください。[@perzeusss](https://github.com/perzeuss)は、あなたの要望を書き出すのに役立つ [Feature Request Copilot](https://udify.app/chat/MK2kVSnw1gakVwMX) を作ってくれました。気軽に試してみてください。
+
+* 既存の課題から 1 つ選びたい場合は、その下にコメントを書いてください。
+
+ 関連する方向で作業しているチームメンバーが参加します。すべてが良好であれば、コーディングを開始する許可が与えられます。私たちが変更を提案した場合にあなたの作業が無駄になることがないよう、それまでこの機能の作業を控えていただくようお願いいたします。
+
+ 提案された機能がどの分野に属するかによって、あなたは異なるチーム・メンバーと話をするかもしれません。以下は、各チームメンバーが現在取り組んでいる分野の概要です。
+
+| Member | Scope |
+| --------------------------------------------------------------------------------------- | ------------------------------------ |
+| [@yeuoly](https://github.com/Yeuoly) | エージェントアーキテクチャ |
+| [@jyong](https://github.com/JohnJyong) | RAG パイプライン設計 |
+| [@GarfieldDai](https://github.com/GarfieldDai) | workflow orchestrations の構築 |
+| [@iamjoel](https://github.com/iamjoel) & [@zxhlyh](https://github.com/zxhlyh) | フロントエンドを使いやすくする |
+| [@guchenhe](https://github.com/guchenhe) & [@crazywoola](https://github.com/crazywoola) | 開発者体験、何でも相談できる窓口 |
+| [@takatost](https://github.com/takatost) | 全体的な製品の方向性とアーキテクチャ |
+
+優先順位の付け方:
+
+| Feature Type | Priority |
+| --------------------------------------------------------------------------------------------------------------------- | --------------- |
+| チームメンバーによってラベル付けされた優先度の高い機能 | High Priority |
+| [community feedback board](https://github.com/langgenius/dify/discussions/categories/feedbacks)の人気の機能リクエスト | Medium Priority |
+| 非コア機能とマイナーな機能強化 | Low Priority |
+| 価値はあるが即効性はない | Future-Feature |
+
+### その他 (バグレポート、パフォーマンスの最適化、誤字の修正など)
+
+* すぐにコーディングを始めてください
+
+優先順位の付け方:
+
+| Issue Type | Priority |
+| -------------------------------------------------------------------------------------- | --------------- |
+| コア機能のバグ(ログインできない、アプリケーションが動作しない、セキュリティの抜け穴) | Critical |
+| 致命的でないバグ、パフォーマンス向上 | Medium Priority |
+| 細かな修正(誤字脱字、機能はするが分かりにくい UI) | Low Priority |
+
+## インストール
+
+Dify を開発用にセットアップする手順は以下の通りです。
+
+### 1. このリポジトリをフォークする
+
+### 2. リポジトリをクローンする
+
+フォークしたリポジトリをターミナルからクローンします。
+
+```
+git clone git@github.com:<github_username>/dify.git
+```
+
+### 3. 依存関係の確認
+
+Dify を構築するには次の依存関係が必要です。それらがシステムにインストールされていることを確認してください。
+
+- [Docker](https://www.docker.com/)
+- [Docker Compose](https://docs.docker.com/compose/install/)
+- [Node.js v18.x (LTS)](http://nodejs.org)
+- [npm](https://www.npmjs.com/) version 8.x.x or [Yarn](https://yarnpkg.com/)
+- [Python](https://www.python.org/) version 3.10.x
+
+### 4. インストール
+
+Dify はバックエンドとフロントエンドから構成されています。
+まず`cd api/`でバックエンドのディレクトリに移動し、[Backend README](api/README.md)に従ってインストールします。
+次に別のターミナルで、`cd web/`でフロントエンドのディレクトリに移動し、[Frontend README](web/README.md)に従ってインストールしてください。
+
+よくある問題とトラブルシューティングの手順については、[installation FAQ](https://docs.dify.ai/getting-started/faq/install-faq) を確認してください。
+
+### 5. ブラウザで dify にアクセスする
+
+設定を確認するために、ブラウザで[http://localhost:3000](http://localhost:3000)(デフォルト、または自分で設定した URL とポート)にアクセスしてください。Dify が起動して実行中であることが確認できるはずです。
+
+## 開発中
+
+モデルプロバイダーを追加する場合は、[このガイド](https://github.com/langgenius/dify/blob/main/api/core/model_runtime/README.md)が役立ちます。
+
+Agent や Workflow にツールプロバイダーを追加する場合は、[このガイド](./api/core/tools/README.md)が役立ちます。
+
+Dify のバックエンドとフロントエンドの概要を簡単に説明します。
+
+### バックエンド
+
+Dify のバックエンドは[Flask](https://flask.palletsprojects.com/en/3.0.x/)を使って Python で書かれています。ORM には[SQLAlchemy](https://www.sqlalchemy.org/)を、タスクキューには[Celery](https://docs.celeryq.dev/en/stable/getting-started/introduction.html)を使っています。認証ロジックは Flask-login 経由で行われます。
+
+```
+[api/]
+├── constants // コードベース全体で使用される定数設定
+├── controllers // APIルート定義とリクエスト処理ロジック
+├── core // アプリケーションの中核的な管理、モデル統合、およびツール
+├── docker // Dockerおよびコンテナ関連の設定
+├── events // イベントのハンドリングと処理
+├── extensions // 第三者のフレームワーク/プラットフォームとの拡張
+├── fields // シリアライゼーション/マーシャリング用のフィールド定義
+├── libs // 再利用可能なライブラリとヘルパー
+├── migrations // データベースマイグレーションスクリプト
+├── models // データベースモデルとスキーマ定義
+├── services // ビジネスロジックの定義
+├── storage // 秘密鍵の保存
+├── tasks // 非同期タスクとバックグラウンドジョブの処理
+└── tests // テスト関連のファイル
+```
+
+### フロントエンド
+
+このウェブサイトは、Typescript の[Next.js](https://nextjs.org/)ボイラープレートでブートストラップされており、スタイリングには[Tailwind CSS](https://tailwindcss.com/)を使用しています。国際化には[React-i18next](https://react.i18next.com/)を使用しています。
+
+```
+[web/]
+├── app // レイアウト、ページ、コンポーネント
+│ ├── (commonLayout) // アプリ全体で共通のレイアウト
+│ ├── (shareLayout) // トークン特有のセッションで共有されるレイアウト
+│ ├── activate // アクティベートページ
+│ ├── components // ページやレイアウトで共有されるコンポーネント
+│ ├── install // インストールページ
+│ ├── signin // サインインページ
+│ └── styles // グローバルに共有されるスタイル
+├── assets // 静的アセット
+├── bin // ビルドステップで実行されるスクリプト
+├── config // 調整可能な設定とオプション
+├── context // アプリの異なる部分で使用される共有コンテキスト
+├── dictionaries // 言語別の翻訳ファイル
+├── docker // コンテナ設定
+├── hooks // 再利用可能なフック
+├── i18n // 国際化設定
+├── models // データモデルとAPIレスポンスの形状を記述
+├── public // ファビコンなどのメタアセット
+├── service // APIアクションの形状を指定
+├── test
+├── types // 関数のパラメータと戻り値の記述
+└── utils // 共有ユーティリティ関数
+```
+
+## PR を投稿する
+
+いよいよ、私たちのリポジトリにプルリクエスト (PR) を提出する時が来ました。主要な機能については、まず `deploy/dev` ブランチにマージしてテストしてから `main` ブランチにマージします。
+マージ競合などの問題が発生した場合、またはプル リクエストを開く方法がわからない場合は、[GitHub's pull request tutorial](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests) をチェックしてみてください。
+
+これで完了です!あなたの PR がマージされると、[README](https://github.com/langgenius/dify/blob/main/README.md) にコントリビューターとして紹介されます。
+
+## ヘルプを得る
+
+コントリビュート中に行き詰まったり、疑問が生じたりした場合は、GitHub の関連する issue から質問していただくか、[Discord](https://discord.gg/8Tpq4AcN9c)でチャットしてください。
diff --git a/README.md b/README.md
index 0dabb16c67..c43b52d7ad 100644
--- a/README.md
+++ b/README.md
@@ -35,13 +35,10 @@
+
-#
-
-
-
Dify is an open-source LLM app development platform. Its intuitive interface combines AI workflow, RAG pipeline, agent capabilities, model management, observability features and more, letting you quickly go from prototype to production. Here's a list of the core features:
@@ -109,7 +106,7 @@ Dify is an open-source LLM app development platform. Its intuitive interface com
Agent
✅
✅
-
✅
+
❌
✅
@@ -127,7 +124,7 @@ Dify is an open-source LLM app development platform. Its intuitive interface com
@@ -242,4 +243,4 @@ docker compose up -d
## ライセンス
-このリポジトリは、Dify Open Source License にいくつかの追加制限を加えた[Difyオープンソースライセンス](LICENSE)の下で利用可能です。
\ No newline at end of file
+このリポジトリは、Dify Open Source License にいくつかの追加制限を加えた[Difyオープンソースライセンス](LICENSE)の下で利用可能です。
diff --git a/README_KL.md b/README_KL.md
index 600649c459..b1eb5073f6 100644
--- a/README_KL.md
+++ b/README_KL.md
@@ -35,6 +35,7 @@
+
#
@@ -111,7 +112,7 @@ Dify is an open-source LLM app development platform. Its intuitive interface com
+
+
+ Dify는 오픈 소스 LLM 앱 개발 플랫폼입니다. 직관적인 인터페이스를 통해 AI 워크플로우, RAG 파이프라인, 에이전트 기능, 모델 관리, 관찰 기능 등을 결합하여 프로토타입에서 프로덕션까지 빠르게 전환할 수 있습니다. 주요 기능 목록은 다음과 같습니다:
+
+**1. 워크플로우**:
+ 다음 기능들을 비롯한 다양한 기능을 활용하여 시각적 캔버스에서 강력한 AI 워크플로우를 구축하고 테스트하세요.
+
+
+ https://github.com/langgenius/dify/assets/13230914/356df23e-1604-483d-80a6-9517ece318aa
+
+
+
+**2. 포괄적인 모델 지원**:
+
+수십 개의 추론 제공업체와 자체 호스팅 솔루션에서 제공하는 수백 개의 독점 및 오픈 소스 LLM과 원활하게 통합되며, GPT, Mistral, Llama3 및 모든 OpenAI API 호환 모델을 포함합니다. 지원되는 모델 제공업체의 전체 목록은 [여기](https://docs.dify.ai/getting-started/readme/model-providers)에서 확인할 수 있습니다.
+
+
+
+**3. 통합 개발환경**:
+ 프롬프트를 작성하고, 모델 성능을 비교하며, 텍스트-음성 변환과 같은 추가 기능을 채팅 기반 앱에 추가할 수 있는 직관적인 인터페이스를 제공합니다.
+
+**4. RAG 파이프라인**:
+ 문서 수집부터 검색까지 모든 것을 다루며, PDF, PPT 및 기타 일반적인 문서 형식에서 텍스트 추출을 위한 기본 지원이 포함되어 있는 광범위한 RAG 기능을 제공합니다.
+
+**5. 에이전트 기능**:
+ LLM 함수 호출 또는 ReAct를 기반으로 에이전트를 정의하고 에이전트에 대해 사전 구축된 도구나 사용자 정의 도구를 추가할 수 있습니다. Dify는 Google Search, DALL·E, Stable Diffusion, WolframAlpha 등 AI 에이전트를 위한 50개 이상의 내장 도구를 제공합니다.
+
+**6. LLMOps**:
+ 시간 경과에 따른 애플리케이션 로그와 성능을 모니터링하고 분석합니다. 생산 데이터와 주석을 기반으로 프롬프트, 데이터세트, 모델을 지속적으로 개선할 수 있습니다.
+
+**7. Backend-as-a-Service**:
+ Dify의 모든 제품에는 해당 API가 함께 제공되므로 Dify를 자신의 비즈니스 로직에 쉽게 통합할 수 있습니다.
+
+## 기능 비교
+
+| 기능 | Dify.AI | LangChain | Flowise | OpenAI Assistants API |
+| --- | --- | --- | --- | --- |
+| 프로그래밍 접근 방식 | API + 앱 중심 | Python 코드 | 앱 중심 | API 중심 |
+| 지원되는 LLMs | 다양한 종류 | 다양한 종류 | 다양한 종류 | OpenAI 전용 |
+| RAG 엔진 | ✅ | ✅ | ✅ | ✅ |
+| 에이전트 | ✅ | ✅ | ❌ | ✅ |
+| 워크플로우 | ✅ | ❌ | ✅ | ❌ |
+| 가시성 | ✅ | ✅ | ❌ | ❌ |
+| 기업용 기능 (SSO/접근 제어) | ✅ | ❌ | ❌ | ❌ |
+| 로컬 배포 | ✅ | ✅ | ✅ | ❌ |
+
+## Dify 사용하기
+
+- **클라우드**
+ 우리는 누구나 설정이 필요 없이 사용해 볼 수 있도록 [Dify 클라우드](https://dify.ai) 서비스를 호스팅합니다. 이는 자체 배포 버전의 모든 기능을 제공하며, 샌드박스 플랜에서 무료로 200회의 GPT-4 호출을 포함합니다.
+
+- **셀프-호스팅 Dify 커뮤니티 에디션**
+ 환경에서 Dify를 빠르게 실행하려면 이 [스타터 가이드를](#quick-start) 참조하세요.
+ 추가 참조 및 더 심층적인 지침은 [문서](https://docs.dify.ai)를 사용하세요.
+
+- **기업 / 조직을 위한 Dify**
+ 우리는 추가적인 기업 중심 기능을 제공합니다. 당사와 [미팅일정](https://cal.com/guchenhe/30min)을 잡거나 [이메일 보내기](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry)를 통해 기업 요구 사항을 논의하십시오.
+ > AWS를 사용하는 스타트업 및 중소기업의 경우 [AWS Marketplace에서 Dify Premium](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6)을 확인하고 한 번의 클릭으로 자체 AWS VPC에 배포하십시오. 맞춤형 로고와 브랜딩이 포함된 앱을 생성할 수 있는 옵션이 포함된 저렴한 AMI 제품입니다.
+
+
+
+## 앞서가기
+
+GitHub에서 Dify에 별표를 찍어 새로운 릴리스를 즉시 알림 받으세요.
+
+
+
+
+
+## 빠른 시작
+>Dify를 설치하기 전에 컴퓨터가 다음과 같은 최소 시스템 요구 사항을 충족하는지 확인하세요 :
+>- CPU >= 2 Core
+>- RAM >= 4GB
+
+
+
+Dify 서버를 시작하는 가장 쉬운 방법은 [docker-compose.yml](docker/docker-compose.yaml) 파일을 실행하는 것입니다. 설치 명령을 실행하기 전에 [Docker](https://docs.docker.com/get-docker/) 및 [Docker Compose](https://docs.docker.com/compose/install/)가 머신에 설치되어 있는지 확인하세요.
+
+```bash
+cd docker
+docker compose up -d
+```
+
+실행 후 브라우저의 [http://localhost/install](http://localhost/install) 에서 Dify 대시보드에 액세스하고 초기화 프로세스를 시작할 수 있습니다.
+
+> Dify에 기여하거나 추가 개발을 하고 싶다면 소스 코드에서 [배포에 대한 가이드](https://docs.dify.ai/getting-started/install-self-hosted/local-source-code)를 참조하세요.
+
+## 다음 단계
+
+구성 커스터마이징이 필요한 경우, [docker-compose.yml](docker/docker-compose.yaml) 파일의 코멘트를 참조하여 환경 구성을 수동으로 설정하십시오. 변경 후 `docker-compose up -d` 를 다시 실행하십시오. 환경 변수의 전체 목록은 [여기](https://docs.dify.ai/getting-started/install-self-hosted/environments)에서 확인할 수 있습니다.
+
+
+고가용성 설정을 구성하려면 Dify를 Kubernetes에 배포할 수 있는 커뮤니티 제공 [Helm Charts](https://helm.sh/)가 있습니다.
+
+- [Helm Chart by @LeoQuote](https://github.com/douban/charts/tree/master/charts/dify)
+- [Helm Chart by @BorisPolonsky](https://github.com/BorisPolonsky/dify-helm)
+
+
+## 기여
+
+코드에 기여하고 싶은 분들은 [기여 가이드](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)를 참조하세요.
+동시에 Dify를 소셜 미디어와 행사 및 컨퍼런스에 공유하여 지원하는 것을 고려해 주시기 바랍니다.
+
+
+> 우리는 Dify를 중국어나 영어 이외의 언어로 번역하는 데 도움을 줄 수 있는 기여자를 찾고 있습니다. 도움을 주고 싶으시다면 [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n/README.md)에서 더 많은 정보를 확인하시고 [Discord 커뮤니티 서버](https://discord.gg/8Tpq4AcN9c)의 `global-users` 채널에 댓글을 남겨주세요.
+
+**기여자**
+
+
+
+
+
+## 커뮤니티 & 연락처
+
+* [GitHub 토론](https://github.com/langgenius/dify/discussions). 피드백 공유 및 질문하기에 적합합니다.
+* [GitHub 이슈](https://github.com/langgenius/dify/issues). Dify.AI 사용 중 발견한 버그와 기능 제안에 적합합니다. [기여 가이드](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)를 참조하세요.
+* [이메일](mailto:support@dify.ai?subject=[GitHub]Questions%20About%20Dify). Dify.AI 사용에 대한 질문하기에 적합합니다.
+* [디스코드](https://discord.gg/FngNHpbcY7). 애플리케이션 공유 및 커뮤니티와 소통하기에 적합합니다.
+* [트위터](https://twitter.com/dify_ai). 애플리케이션 공유 및 커뮤니티와 소통하기에 적합합니다.
+
+또는 팀원과 직접 미팅을 예약하세요:
+
+| 연락처 | 목적 |
+| --- | --- |
+| [미팅 예약](https://cal.com/guchenhe/30min) | 비즈니스 문의 및 제품 피드백 |
+| [GitHub](https://github.com/langgenius/dify) | 기여, 이슈 및 기능 요청 |
+
+## Star 히스토리
+
+[![Star History Chart](https://api.star-history.com/svg?repos=langgenius/dify&type=Date)](https://star-history.com/#langgenius/dify&Date)
+
+
+## 보안 공개
+
+개인정보 보호를 위해 보안 문제를 GitHub에 게시하지 마십시오. 대신 security@dify.ai로 질문을 보내주시면 더 자세한 답변을 드리겠습니다.
+
+## 라이선스
+
+이 저장소는 기본적으로 몇 가지 추가 제한 사항이 있는 Apache 2.0인 [Dify 오픈 소스 라이선스](LICENSE)에 따라 사용할 수 있습니다.
diff --git a/api/.env.example b/api/.env.example
index 30bbf331a4..f112721a7e 100644
--- a/api/.env.example
+++ b/api/.env.example
@@ -17,6 +17,9 @@ APP_WEB_URL=http://127.0.0.1:3000
# Files URL
FILES_URL=http://127.0.0.1:5001
+# The time in seconds after which a file signature expires and access is rejected
+FILES_ACCESS_TIMEOUT=300
+
# celery configuration
CELERY_BROKER_URL=redis://:difyai123456@localhost:6379/1
@@ -65,7 +68,7 @@ GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON=your-google-service-account-json-base64-stri
WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
-# Vector database configuration, support: weaviate, qdrant, milvus, relyt, pgvecto_rs
+# Vector database configuration, support: weaviate, qdrant, milvus, relyt, pgvecto_rs, pgvector
VECTOR_STORE=weaviate
# Weaviate configuration
@@ -102,6 +105,20 @@ PGVECTO_RS_USER=postgres
PGVECTO_RS_PASSWORD=difyai123456
PGVECTO_RS_DATABASE=postgres
+# PGVector configuration
+PGVECTOR_HOST=127.0.0.1
+PGVECTOR_PORT=5433
+PGVECTOR_USER=postgres
+PGVECTOR_PASSWORD=postgres
+PGVECTOR_DATABASE=postgres
+
+# TiDB Vector configuration
+TIDB_VECTOR_HOST=xxx.eu-central-1.xxx.aws.tidbcloud.com
+TIDB_VECTOR_PORT=4000
+TIDB_VECTOR_USER=xxx.root
+TIDB_VECTOR_PASSWORD=xxxxxx
+TIDB_VECTOR_DATABASE=dify
+
# Upload configuration
UPLOAD_FILE_SIZE_LIMIT=15
UPLOAD_FILE_BATCH_LIMIT=5
@@ -117,10 +134,11 @@ RESEND_API_KEY=
RESEND_API_URL=https://api.resend.com
# smtp configuration
SMTP_SERVER=smtp.gmail.com
-SMTP_PORT=587
+SMTP_PORT=465
SMTP_USERNAME=123
SMTP_PASSWORD=abc
-SMTP_USE_TLS=false
+SMTP_USE_TLS=true
+SMTP_OPPORTUNISTIC_TLS=false
# Sentry configuration
SENTRY_DSN=
@@ -137,6 +155,7 @@ NOTION_INTERNAL_SECRET=you-internal-secret
ETL_TYPE=dify
UNSTRUCTURED_API_URL=
+UNSTRUCTURED_API_KEY=
SSRF_PROXY_HTTP_URL=
SSRF_PROXY_HTTPS_URL=
@@ -163,6 +182,16 @@ API_TOOL_DEFAULT_READ_TIMEOUT=60
HTTP_REQUEST_MAX_CONNECT_TIMEOUT=300
HTTP_REQUEST_MAX_READ_TIMEOUT=600
HTTP_REQUEST_MAX_WRITE_TIMEOUT=600
+HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760 # 10MB
+HTTP_REQUEST_NODE_MAX_TEXT_SIZE=1048576 # 1MB
# Log file path
LOG_FILE=
+
+# Indexing configuration
+INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH=1000
+
+# Workflow runtime configuration
+WORKFLOW_MAX_EXECUTION_STEPS=500
+WORKFLOW_MAX_EXECUTION_TIME=1200
+WORKFLOW_CALL_MAX_DEPTH=5
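
The SMTP defaults above move from opportunistic STARTTLS on port 587 to implicit TLS on port 465, with the new `SMTP_OPPORTUNISTIC_TLS` flag covering the old behaviour. The two modes differ in when encryption starts; a minimal sketch with Python's standard `smtplib`, reusing the placeholder host and credentials from this file:

```python
import smtplib

# SMTP_USE_TLS=true with SMTP_PORT=465: the connection is TLS from the first byte.
with smtplib.SMTP_SSL("smtp.gmail.com", 465) as server:
    server.login("123", "abc")

# SMTP_OPPORTUNISTIC_TLS=true with SMTP_PORT=587: start in plaintext,
# then upgrade the session with STARTTLS before authenticating.
with smtplib.SMTP("smtp.gmail.com", 587) as server:
    server.starttls()
    server.login("123", "abc")
```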
diff --git a/api/commands.py b/api/commands.py
index b82f7ac3f8..186b97c3fa 100644
--- a/api/commands.py
+++ b/api/commands.py
@@ -1,11 +1,13 @@
import base64
import json
import secrets
+from typing import Optional
import click
from flask import current_app
from werkzeug.exceptions import NotFound
+from constants.languages import languages
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.models.document import Document
from extensions.ext_database import db
@@ -17,6 +19,7 @@ from models.dataset import Dataset, DatasetCollectionBinding, DocumentSegment
from models.dataset import Document as DatasetDocument
from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation
from models.provider import Provider, ProviderModel
+from services.account_service import RegisterService, TenantService
@click.command('reset-password', help='Reset the account password.')
@@ -57,7 +60,7 @@ def reset_password(email, new_password, password_confirm):
    account.password = base64_password_hashed
    account.password_salt = base64_salt
    db.session.commit()
-    click.echo(click.style('Congratulations!, password has been reset.', fg='green'))
+    click.echo(click.style('Congratulations! Password has been reset.', fg='green'))
@click.command('reset-email', help='Reset the account email.')
@@ -305,6 +308,14 @@ def migrate_knowledge_vector_database():
"vector_store": {"class_prefix": collection_name}
}
dataset.index_struct = json.dumps(index_struct_dict)
+ elif vector_type == "pgvector":
+ dataset_id = dataset.id
+ collection_name = Dataset.gen_collection_name_by_id(dataset_id)
+ index_struct_dict = {
+ "type": 'pgvector',
+ "vector_store": {"class_prefix": collection_name}
+ }
+ dataset.index_struct = json.dumps(index_struct_dict)
else:
raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")
@@ -440,9 +451,105 @@ def convert_to_agent_apps():
    click.echo(click.style('Congratulations! Converted {} agent apps.'.format(len(proceeded_app_ids)), fg='green'))


+@click.command('add-qdrant-doc-id-index', help='Add qdrant doc_id index.')
+@click.option('--field', default='metadata.doc_id', prompt=False, help='Index field, default is metadata.doc_id.')
+def add_qdrant_doc_id_index(field: str):
+    click.echo(click.style('Starting to add qdrant doc_id index.', fg='green'))
+    config = current_app.config
+    vector_type = config.get('VECTOR_STORE')
+    if vector_type != "qdrant":
+        click.echo(click.style('Sorry, only the qdrant vector store is supported.', fg='red'))
+        return
+    create_count = 0
+
+    try:
+        bindings = db.session.query(DatasetCollectionBinding).all()
+        if not bindings:
+            click.echo(click.style('Sorry, no dataset collection bindings found.', fg='red'))
+            return
+        import qdrant_client
+        from qdrant_client.http.exceptions import UnexpectedResponse
+        from qdrant_client.http.models import PayloadSchemaType
+
+        from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantConfig
+        for binding in bindings:
+            qdrant_config = QdrantConfig(
+                endpoint=config.get('QDRANT_URL'),
+                api_key=config.get('QDRANT_API_KEY'),
+                root_path=current_app.root_path,
+                timeout=config.get('QDRANT_CLIENT_TIMEOUT'),
+                grpc_port=config.get('QDRANT_GRPC_PORT'),
+                prefer_grpc=config.get('QDRANT_GRPC_ENABLED')
+            )
+            try:
+                client = qdrant_client.QdrantClient(**qdrant_config.to_qdrant_params())
+                # create the payload index on the configured field
+                client.create_payload_index(binding.collection_name, field,
+                                            field_schema=PayloadSchemaType.KEYWORD)
+                create_count += 1
+            except UnexpectedResponse as e:
+                # the collection does not exist, so skip it
+                if e.status_code == 404:
+                    click.echo(click.style(f'Collection not found, collection_name:{binding.collection_name}.', fg='red'))
+                    continue
+                # some other error occurred; report it and move on
+                else:
+                    click.echo(click.style(f'Failed to create qdrant index, collection_name:{binding.collection_name}.', fg='red'))
+
+    except Exception as e:
+        click.echo(click.style('Failed to create qdrant client.', fg='red'))
+
+    click.echo(
+        click.style(f'Congratulations! Created {create_count} collection indexes.',
+                    fg='green'))
+
+
+@click.command('create-tenant', help='Create account and tenant.')
+@click.option('--email', prompt=True, help='The email address of the tenant account.')
+@click.option('--language', prompt=True, help='Account language, default: en-US.')
+def create_tenant(email: str, language: Optional[str] = None):
+    """
+    Create tenant account
+    """
+    if not email:
+        click.echo(click.style('Sorry, email is required.', fg='red'))
+        return
+
+    # Create account
+    email = email.strip()
+
+    if '@' not in email:
+        click.echo(click.style('Sorry, invalid email address.', fg='red'))
+        return
+
+    account_name = email.split('@')[0]
+
+    if language not in languages:
+        language = 'en-US'
+
+    # generate a random password
+    new_password = secrets.token_urlsafe(16)
+
+    # register the account
+    account = RegisterService.register(
+        email=email,
+        name=account_name,
+        password=new_password,
+        language=language
+    )
+
+    TenantService.create_owner_tenant_if_not_exist(account)
+
+    click.echo(click.style('Congratulations! Account and tenant created.\n'
+                           'Account: {}\nPassword: {}'.format(email, new_password), fg='green'))
+
+
def register_commands(app):
    app.cli.add_command(reset_password)
    app.cli.add_command(reset_email)
    app.cli.add_command(reset_encrypt_key_pair)
    app.cli.add_command(vdb_migrate)
    app.cli.add_command(convert_to_agent_apps)
+    app.cli.add_command(add_qdrant_doc_id_index)
+    app.cli.add_command(create_tenant)
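
The registered `create-tenant` command is normally run as `flask create-tenant --email <address>`. It can also be exercised through Click's test runner; a minimal sketch, assuming Dify's `create_app()` factory and noting that the services it calls need an application context (the email below is a placeholder):

```python
from click.testing import CliRunner

from app import create_app  # Dify's Flask app factory
from commands import create_tenant

app = create_app()
with app.app_context():  # RegisterService/TenantService need an app context for db access
    result = CliRunner().invoke(create_tenant, ["--email", "owner@example.com", "--language", "en-US"])

print(result.output)  # on success, prints the account email and the generated random password
```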
diff --git a/api/config.py b/api/config.py
index d6c345c579..286b3336a2 100644
--- a/api/config.py
+++ b/api/config.py
@@ -23,14 +23,17 @@ DEFAULTS = {
'SERVICE_API_URL': 'https://api.dify.ai',
'APP_WEB_URL': 'https://udify.app',
'FILES_URL': '',
+ 'FILES_ACCESS_TIMEOUT': 300,
'S3_ADDRESS_STYLE': 'auto',
'STORAGE_TYPE': 'local',
'STORAGE_LOCAL_PATH': 'storage',
'CHECK_UPDATE_URL': 'https://updates.dify.ai',
'DEPLOY_ENV': 'PRODUCTION',
+ 'SQLALCHEMY_DATABASE_URI_SCHEME': 'postgresql',
'SQLALCHEMY_POOL_SIZE': 30,
'SQLALCHEMY_MAX_OVERFLOW': 10,
'SQLALCHEMY_POOL_RECYCLE': 3600,
+ 'SQLALCHEMY_POOL_PRE_PING': 'False',
'SQLALCHEMY_ECHO': 'False',
'SENTRY_TRACES_SAMPLE_RATE': 1.0,
'SENTRY_PROFILES_SAMPLE_RATE': 1.0,
@@ -67,6 +70,7 @@ DEFAULTS = {
'INVITE_EXPIRY_HOURS': 72,
'BILLING_ENABLED': 'False',
'CAN_REPLACE_LOGO': 'False',
+ 'MODEL_LB_ENABLED': 'False',
'ETL_TYPE': 'dify',
'KEYWORD_STORE': 'jieba',
'BATCH_UPLOAD_LIMIT': 20,
@@ -77,6 +81,10 @@ DEFAULTS = {
'KEYWORD_DATA_SOURCE_TYPE': 'database',
'INNER_API': 'False',
'ENTERPRISE_ENABLED': 'False',
+ 'INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH': 1000,
+ 'WORKFLOW_MAX_EXECUTION_STEPS': 500,
+ 'WORKFLOW_MAX_EXECUTION_TIME': 1200,
+ 'WORKFLOW_CALL_MAX_DEPTH': 5,
}
@@ -107,7 +115,7 @@ class Config:
# ------------------------
# General Configurations.
# ------------------------
- self.CURRENT_VERSION = "0.6.6"
+ self.CURRENT_VERSION = "0.6.10"
self.COMMIT_SHA = get_env('COMMIT_SHA')
self.EDITION = get_env('EDITION')
self.DEPLOY_ENV = get_env('DEPLOY_ENV')
@@ -116,6 +124,7 @@ class Config:
self.LOG_FILE = get_env('LOG_FILE')
self.LOG_FORMAT = get_env('LOG_FORMAT')
self.LOG_DATEFORMAT = get_env('LOG_DATEFORMAT')
+ self.API_COMPRESSION_ENABLED = get_bool_env('API_COMPRESSION_ENABLED')
# The backend URL prefix of the console API.
# used to concatenate the login authorization callback or notion integration callback.
@@ -138,6 +147,10 @@ class Config:
# Url is signed and has expiration time.
self.FILES_URL = get_env('FILES_URL') if get_env('FILES_URL') else self.CONSOLE_API_URL
+        # File Access Timeout specifies the interval in seconds during which a signed file URL remains valid.
+        # The default value is 300 seconds.
+        self.FILES_ACCESS_TIMEOUT = int(get_env('FILES_ACCESS_TIMEOUT'))
+
# Your App secret key will be used for securely signing the session cookie
# Make sure you are changing this key for your deployment with a strong key.
# You can generate a strong key using `openssl rand -base64 42`.
@@ -165,14 +178,17 @@ class Config:
key: get_env(key) for key in
['DB_USERNAME', 'DB_PASSWORD', 'DB_HOST', 'DB_PORT', 'DB_DATABASE', 'DB_CHARSET']
}
+ self.SQLALCHEMY_DATABASE_URI_SCHEME = get_env('SQLALCHEMY_DATABASE_URI_SCHEME')
db_extras = f"?client_encoding={db_credentials['DB_CHARSET']}" if db_credentials['DB_CHARSET'] else ""
- self.SQLALCHEMY_DATABASE_URI = f"postgresql://{db_credentials['DB_USERNAME']}:{db_credentials['DB_PASSWORD']}@{db_credentials['DB_HOST']}:{db_credentials['DB_PORT']}/{db_credentials['DB_DATABASE']}{db_extras}"
+ self.SQLALCHEMY_DATABASE_URI = f"{self.SQLALCHEMY_DATABASE_URI_SCHEME}://{db_credentials['DB_USERNAME']}:{db_credentials['DB_PASSWORD']}@{db_credentials['DB_HOST']}:{db_credentials['DB_PORT']}/{db_credentials['DB_DATABASE']}{db_extras}"
self.SQLALCHEMY_ENGINE_OPTIONS = {
'pool_size': int(get_env('SQLALCHEMY_POOL_SIZE')),
'max_overflow': int(get_env('SQLALCHEMY_MAX_OVERFLOW')),
- 'pool_recycle': int(get_env('SQLALCHEMY_POOL_RECYCLE'))
+ 'pool_recycle': int(get_env('SQLALCHEMY_POOL_RECYCLE')),
+ 'pool_pre_ping': get_bool_env('SQLALCHEMY_POOL_PRE_PING'),
+ 'connect_args': {'options': '-c timezone=UTC'},
}
self.SQLALCHEMY_ECHO = get_bool_env('SQLALCHEMY_ECHO')
@@ -196,36 +212,51 @@ class Config:
if self.CELERY_BACKEND == 'database' else self.CELERY_BROKER_URL
self.BROKER_USE_SSL = self.CELERY_BROKER_URL.startswith('rediss://')
+ # ------------------------
+ # Code Execution Sandbox Configurations.
+ # ------------------------
+ self.CODE_EXECUTION_ENDPOINT = get_env('CODE_EXECUTION_ENDPOINT')
+ self.CODE_EXECUTION_API_KEY = get_env('CODE_EXECUTION_API_KEY')
+
# ------------------------
# File Storage Configurations.
# ------------------------
self.STORAGE_TYPE = get_env('STORAGE_TYPE')
self.STORAGE_LOCAL_PATH = get_env('STORAGE_LOCAL_PATH')
+
+ # S3 Storage settings
self.S3_ENDPOINT = get_env('S3_ENDPOINT')
self.S3_BUCKET_NAME = get_env('S3_BUCKET_NAME')
self.S3_ACCESS_KEY = get_env('S3_ACCESS_KEY')
self.S3_SECRET_KEY = get_env('S3_SECRET_KEY')
self.S3_REGION = get_env('S3_REGION')
self.S3_ADDRESS_STYLE = get_env('S3_ADDRESS_STYLE')
+
+ # Azure Blob Storage settings
self.AZURE_BLOB_ACCOUNT_NAME = get_env('AZURE_BLOB_ACCOUNT_NAME')
self.AZURE_BLOB_ACCOUNT_KEY = get_env('AZURE_BLOB_ACCOUNT_KEY')
self.AZURE_BLOB_CONTAINER_NAME = get_env('AZURE_BLOB_CONTAINER_NAME')
self.AZURE_BLOB_ACCOUNT_URL = get_env('AZURE_BLOB_ACCOUNT_URL')
- self.ALIYUN_OSS_BUCKET_NAME=get_env('ALIYUN_OSS_BUCKET_NAME')
- self.ALIYUN_OSS_ACCESS_KEY=get_env('ALIYUN_OSS_ACCESS_KEY')
- self.ALIYUN_OSS_SECRET_KEY=get_env('ALIYUN_OSS_SECRET_KEY')
- self.ALIYUN_OSS_ENDPOINT=get_env('ALIYUN_OSS_ENDPOINT')
- self.ALIYUN_OSS_REGION=get_env('ALIYUN_OSS_REGION')
- self.ALIYUN_OSS_AUTH_VERSION=get_env('ALIYUN_OSS_AUTH_VERSION')
+
+ # Aliyun Storage settings
+ self.ALIYUN_OSS_BUCKET_NAME = get_env('ALIYUN_OSS_BUCKET_NAME')
+ self.ALIYUN_OSS_ACCESS_KEY = get_env('ALIYUN_OSS_ACCESS_KEY')
+ self.ALIYUN_OSS_SECRET_KEY = get_env('ALIYUN_OSS_SECRET_KEY')
+ self.ALIYUN_OSS_ENDPOINT = get_env('ALIYUN_OSS_ENDPOINT')
+ self.ALIYUN_OSS_REGION = get_env('ALIYUN_OSS_REGION')
+ self.ALIYUN_OSS_AUTH_VERSION = get_env('ALIYUN_OSS_AUTH_VERSION')
+
+ # Google Cloud Storage settings
self.GOOGLE_STORAGE_BUCKET_NAME = get_env('GOOGLE_STORAGE_BUCKET_NAME')
self.GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64 = get_env('GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64')
# ------------------------
# Vector Store Configurations.
- # Currently, only support: qdrant, milvus, zilliz, weaviate, relyt
+ # Currently, only support: qdrant, milvus, zilliz, weaviate, relyt, pgvector
# ------------------------
self.VECTOR_STORE = get_env('VECTOR_STORE')
self.KEYWORD_STORE = get_env('KEYWORD_STORE')
+
# qdrant settings
self.QDRANT_URL = get_env('QDRANT_URL')
self.QDRANT_API_KEY = get_env('QDRANT_API_KEY')
@@ -261,6 +292,20 @@ class Config:
self.PGVECTO_RS_PASSWORD = get_env('PGVECTO_RS_PASSWORD')
self.PGVECTO_RS_DATABASE = get_env('PGVECTO_RS_DATABASE')
+ # pgvector settings
+ self.PGVECTOR_HOST = get_env('PGVECTOR_HOST')
+ self.PGVECTOR_PORT = get_env('PGVECTOR_PORT')
+ self.PGVECTOR_USER = get_env('PGVECTOR_USER')
+ self.PGVECTOR_PASSWORD = get_env('PGVECTOR_PASSWORD')
+ self.PGVECTOR_DATABASE = get_env('PGVECTOR_DATABASE')
+
+ # tidb-vector settings
+ self.TIDB_VECTOR_HOST = get_env('TIDB_VECTOR_HOST')
+ self.TIDB_VECTOR_PORT = get_env('TIDB_VECTOR_PORT')
+ self.TIDB_VECTOR_USER = get_env('TIDB_VECTOR_USER')
+ self.TIDB_VECTOR_PASSWORD = get_env('TIDB_VECTOR_PASSWORD')
+ self.TIDB_VECTOR_DATABASE = get_env('TIDB_VECTOR_DATABASE')
+
# ------------------------
# Mail Configurations.
# ------------------------
@@ -274,7 +319,8 @@ class Config:
self.SMTP_USERNAME = get_env('SMTP_USERNAME')
self.SMTP_PASSWORD = get_env('SMTP_PASSWORD')
self.SMTP_USE_TLS = get_bool_env('SMTP_USE_TLS')
-
+ self.SMTP_OPPORTUNISTIC_TLS = get_bool_env('SMTP_OPPORTUNISTIC_TLS')
+
# ------------------------
# Workspace Configurations.
# ------------------------
@@ -301,6 +347,23 @@ class Config:
self.UPLOAD_FILE_SIZE_LIMIT = int(get_env('UPLOAD_FILE_SIZE_LIMIT'))
self.UPLOAD_FILE_BATCH_LIMIT = int(get_env('UPLOAD_FILE_BATCH_LIMIT'))
self.UPLOAD_IMAGE_FILE_SIZE_LIMIT = int(get_env('UPLOAD_IMAGE_FILE_SIZE_LIMIT'))
+ self.BATCH_UPLOAD_LIMIT = get_env('BATCH_UPLOAD_LIMIT')
+
+ # RAG ETL Configurations.
+ self.ETL_TYPE = get_env('ETL_TYPE')
+ self.UNSTRUCTURED_API_URL = get_env('UNSTRUCTURED_API_URL')
+ self.UNSTRUCTURED_API_KEY = get_env('UNSTRUCTURED_API_KEY')
+ self.KEYWORD_DATA_SOURCE_TYPE = get_env('KEYWORD_DATA_SOURCE_TYPE')
+
+ # Indexing Configurations.
+ self.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH = get_env('INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH')
+
+ # Tool Configurations.
+ self.TOOL_ICON_CACHE_MAX_AGE = get_env('TOOL_ICON_CACHE_MAX_AGE')
+
+ self.WORKFLOW_MAX_EXECUTION_STEPS = int(get_env('WORKFLOW_MAX_EXECUTION_STEPS'))
+ self.WORKFLOW_MAX_EXECUTION_TIME = int(get_env('WORKFLOW_MAX_EXECUTION_TIME'))
+ self.WORKFLOW_CALL_MAX_DEPTH = int(get_env('WORKFLOW_CALL_MAX_DEPTH'))
# Moderation in app Configurations.
self.OUTPUT_MODERATION_BUFFER_SIZE = int(get_env('OUTPUT_MODERATION_BUFFER_SIZE'))
@@ -352,18 +415,15 @@ class Config:
self.HOSTED_FETCH_APP_TEMPLATES_MODE = get_env('HOSTED_FETCH_APP_TEMPLATES_MODE')
self.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN = get_env('HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN')
- self.ETL_TYPE = get_env('ETL_TYPE')
- self.UNSTRUCTURED_API_URL = get_env('UNSTRUCTURED_API_URL')
+ # Model Load Balancing Configurations.
+ self.MODEL_LB_ENABLED = get_bool_env('MODEL_LB_ENABLED')
+
+ # Platform Billing Configurations.
self.BILLING_ENABLED = get_bool_env('BILLING_ENABLED')
- self.CAN_REPLACE_LOGO = get_bool_env('CAN_REPLACE_LOGO')
- self.BATCH_UPLOAD_LIMIT = get_env('BATCH_UPLOAD_LIMIT')
-
- self.CODE_EXECUTION_ENDPOINT = get_env('CODE_EXECUTION_ENDPOINT')
- self.CODE_EXECUTION_API_KEY = get_env('CODE_EXECUTION_API_KEY')
-
- self.API_COMPRESSION_ENABLED = get_bool_env('API_COMPRESSION_ENABLED')
- self.TOOL_ICON_CACHE_MAX_AGE = get_env('TOOL_ICON_CACHE_MAX_AGE')
-
- self.KEYWORD_DATA_SOURCE_TYPE = get_env('KEYWORD_DATA_SOURCE_TYPE')
+ # ------------------------
+ # Enterprise feature Configurations.
+ # **Before using, please contact business@dify.ai by email to inquire about licensing matters.**
+ # ------------------------
self.ENTERPRISE_ENABLED = get_bool_env('ENTERPRISE_ENABLED')
+ self.CAN_REPLACE_LOGO = get_bool_env('CAN_REPLACE_LOGO')
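
Of the new settings, `FILES_ACCESS_TIMEOUT` bounds how long a signed file URL stays valid. A minimal sketch of the timestamp-plus-HMAC pattern such a setting typically guards; the helper names are illustrative, not Dify's actual implementation:

```python
import base64
import hashlib
import hmac
import time

SECRET_KEY = b"change-me"   # in a real deployment this would come from the app secret key
FILES_ACCESS_TIMEOUT = 300  # seconds, mirroring the new default

def sign_file(file_id: str, timestamp: int) -> str:
    msg = f"{file_id}:{timestamp}".encode()
    return base64.urlsafe_b64encode(hmac.new(SECRET_KEY, msg, hashlib.sha256).digest()).decode()

def verify_file(file_id: str, timestamp: int, signature: str) -> bool:
    if time.time() - timestamp > FILES_ACCESS_TIMEOUT:
        return False  # the signature is rejected once the timeout has elapsed
    return hmac.compare_digest(sign_file(file_id, timestamp), signature)
```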
diff --git a/api/constants/languages.py b/api/constants/languages.py
index bdfd8022a3..b4626cf51f 100644
--- a/api/constants/languages.py
+++ b/api/constants/languages.py
@@ -1,6 +1,6 @@
-languages = ['en-US', 'zh-Hans', 'zh-Hant', 'pt-BR', 'es-ES', 'fr-FR', 'de-DE', 'ja-JP', 'ko-KR', 'ru-RU', 'it-IT', 'uk-UA', 'vi-VN']
+languages = ['en-US', 'zh-Hans', 'zh-Hant', 'pt-BR', 'es-ES', 'fr-FR', 'de-DE', 'ja-JP', 'ko-KR', 'ru-RU', 'it-IT', 'uk-UA', 'vi-VN', 'pl-PL']
language_timezone_mapping = {
'en-US': 'America/New_York',
@@ -16,6 +16,8 @@ language_timezone_mapping = {
'it-IT': 'Europe/Rome',
'uk-UA': 'Europe/Kyiv',
'vi-VN': 'Asia/Ho_Chi_Minh',
+ 'ro-RO': 'Europe/Bucharest',
+ 'pl-PL': 'Europe/Warsaw',
}
diff --git a/api/constants/recommended_apps.json b/api/constants/recommended_apps.json
index 8a1ee808e4..68c913f80a 100644
--- a/api/constants/recommended_apps.json
+++ b/api/constants/recommended_apps.json
@@ -24,7 +24,8 @@
"description": "Welcome to your personalized Investment Analysis Copilot service, where we delve into the depths of stock analysis to provide you with comprehensive insights. \n",
"is_listed": true,
"position": 0,
- "privacy_policy": null
+ "privacy_policy": null,
+ "custom_disclaimer": null
},
{
"app": {
@@ -40,7 +41,8 @@
"description": "Code interpreter, clarifying the syntax and semantics of the code.",
"is_listed": true,
"position": 13,
- "privacy_policy": "https://dify.ai"
+ "privacy_policy": "https://dify.ai",
+ "custom_disclaimer": null
},
{
"app": {
@@ -56,7 +58,8 @@
"description": "Hello, I am your creative partner in bringing ideas to vivid life! I can assist you in creating stunning designs by leveraging abilities of DALL\u00b7E 3. ",
"is_listed": true,
"position": 4,
- "privacy_policy": null
+ "privacy_policy": null,
+ "custom_disclaimer": null
},
{
"app": {
@@ -72,7 +75,8 @@
"description": "Fully SEO Optimized Article including FAQs",
"is_listed": true,
"position": 1,
- "privacy_policy": null
+ "privacy_policy": null,
+ "custom_disclaimer": null
},
{
"app": {
@@ -88,7 +92,8 @@
"description": "Generate Flat Style Image",
"is_listed": true,
"position": 10,
- "privacy_policy": null
+ "privacy_policy": null,
+ "custom_disclaimer": null
},
{
"app": {
@@ -104,7 +109,8 @@
"description": "A multilingual translator that provides translation capabilities in multiple languages. Input the text you need to translate and select the target language.",
"is_listed": true,
"position": 10,
- "privacy_policy": "https://dify.ai"
+ "privacy_policy": "https://dify.ai",
+ "custom_disclaimer": null
},
{
"app": {
@@ -120,7 +126,8 @@
"description": "I am a YouTube Channel Data Analysis Copilot, I am here to provide expert data analysis tailored to your needs. ",
"is_listed": true,
"position": 2,
- "privacy_policy": null
+ "privacy_policy": null,
+ "custom_disclaimer": null
},
{
"app": {
@@ -136,7 +143,8 @@
"description": "Meeting minutes generator",
"is_listed": true,
"position": 0,
- "privacy_policy": "https://dify.ai"
+ "privacy_policy": "https://dify.ai",
+ "custom_disclaimer": null
},
{
"app": {
@@ -152,7 +160,8 @@
"description": "Tell me the main elements, I will generate a cyberpunk style image for you. ",
"is_listed": true,
"position": 10,
- "privacy_policy": null
+ "privacy_policy": null,
+ "custom_disclaimer": null
},
{
"app": {
@@ -168,7 +177,8 @@
"description": "Write SQL from natural language by pasting in your schema with the request.Please describe your query requirements in natural language and select the target database type.",
"is_listed": true,
"position": 13,
- "privacy_policy": "https://dify.ai"
+ "privacy_policy": "https://dify.ai",
+ "custom_disclaimer": null
},
{
"app": {
@@ -184,7 +194,8 @@
"description": "Welcome to your personalized travel service with Consultant! \ud83c\udf0d\u2708\ufe0f Ready to embark on a journey filled with adventure and relaxation? Let's dive into creating your unforgettable travel experience. ",
"is_listed": true,
"position": 3,
- "privacy_policy": null
+ "privacy_policy": null,
+ "custom_disclaimer": null
},
{
"app": {
@@ -200,7 +211,8 @@
"description": "I can answer your questions related to strategic marketing.",
"is_listed": true,
"position": 10,
- "privacy_policy": "https://dify.ai"
+ "privacy_policy": "https://dify.ai",
+ "custom_disclaimer": null
},
{
"app": {
@@ -216,7 +228,8 @@
"description": "A simulated front-end interviewer that tests the skill level of front-end development through questioning.",
"is_listed": true,
"position": 19,
- "privacy_policy": "https://dify.ai"
+ "privacy_policy": "https://dify.ai",
+ "custom_disclaimer": null
},
{
"app": {
@@ -232,7 +245,8 @@
"description": "I'm here to hear about your feature request about Dify and help you flesh it out further. What's on your mind?",
"is_listed": true,
"position": 6,
- "privacy_policy": null
+ "privacy_policy": null,
+ "custom_disclaimer": null
}
]
},
@@ -261,7 +275,8 @@
"description": "\u4e00\u4e2a\u6a21\u62df\u7684\u524d\u7aef\u9762\u8bd5\u5b98\uff0c\u901a\u8fc7\u63d0\u95ee\u7684\u65b9\u5f0f\u5bf9\u524d\u7aef\u5f00\u53d1\u7684\u6280\u80fd\u6c34\u5e73\u8fdb\u884c\u68c0\u9a8c\u3002",
"is_listed": true,
"position": 20,
- "privacy_policy": null
+ "privacy_policy": null,
+ "custom_disclaimer": null
},
{
"app": {
@@ -277,7 +292,8 @@
"description": "\u8f93\u5165\u76f8\u5173\u5143\u7d20\uff0c\u4e3a\u4f60\u751f\u6210\u6241\u5e73\u63d2\u753b\u98ce\u683c\u7684\u5c01\u9762\u56fe\u7247",
"is_listed": true,
"position": 10,
- "privacy_policy": null
+ "privacy_policy": null,
+ "custom_disclaimer": null
},
{
"app": {
@@ -293,7 +309,8 @@
"description": "\u4e00\u4e2a\u591a\u8bed\u8a00\u7ffb\u8bd1\u5668\uff0c\u63d0\u4f9b\u591a\u79cd\u8bed\u8a00\u7ffb\u8bd1\u80fd\u529b\uff0c\u8f93\u5165\u4f60\u9700\u8981\u7ffb\u8bd1\u7684\u6587\u672c\uff0c\u9009\u62e9\u76ee\u6807\u8bed\u8a00\u5373\u53ef\u3002\u63d0\u793a\u8bcd\u6765\u81ea\u5b9d\u7389\u3002",
"is_listed": true,
"position": 10,
- "privacy_policy": null
+ "privacy_policy": null,
+ "custom_disclaimer": null
},
{
"app": {
@@ -309,7 +326,8 @@
"description": "\u6211\u5c06\u5e2e\u52a9\u4f60\u628a\u81ea\u7136\u8bed\u8a00\u8f6c\u5316\u6210\u6307\u5b9a\u7684\u6570\u636e\u5e93\u67e5\u8be2 SQL \u8bed\u53e5\uff0c\u8bf7\u5728\u4e0b\u65b9\u8f93\u5165\u4f60\u9700\u8981\u67e5\u8be2\u7684\u6761\u4ef6\uff0c\u5e76\u9009\u62e9\u76ee\u6807\u6570\u636e\u5e93\u7c7b\u578b\u3002",
"is_listed": true,
"position": 12,
- "privacy_policy": null
+ "privacy_policy": null,
+ "custom_disclaimer": null
},
{
"app": {
@@ -325,7 +343,8 @@
"description": "\u9610\u660e\u4ee3\u7801\u7684\u8bed\u6cd5\u548c\u8bed\u4e49\u3002",
"is_listed": true,
"position": 2,
- "privacy_policy": "https://dify.ai"
+ "privacy_policy": "https://dify.ai",
+ "custom_disclaimer": null
},
{
"app": {
@@ -341,7 +360,8 @@
"description": "\u8f93\u5165\u76f8\u5173\u5143\u7d20\uff0c\u4e3a\u4f60\u751f\u6210\u8d5b\u535a\u670b\u514b\u98ce\u683c\u7684\u63d2\u753b",
"is_listed": true,
"position": 10,
- "privacy_policy": null
+ "privacy_policy": null,
+ "custom_disclaimer": null
},
{
"app": {
@@ -357,7 +377,8 @@
"description": "\u6211\u662f\u4e00\u540dSEO\u4e13\u5bb6\uff0c\u53ef\u4ee5\u6839\u636e\u60a8\u63d0\u4f9b\u7684\u6807\u9898\u3001\u5173\u952e\u8bcd\u3001\u76f8\u5173\u4fe1\u606f\u6765\u6279\u91cf\u751f\u6210SEO\u6587\u7ae0\u3002",
"is_listed": true,
"position": 10,
- "privacy_policy": null
+ "privacy_policy": null,
+ "custom_disclaimer": null
},
{
"app": {
@@ -373,7 +394,8 @@
"description": "\u5e2e\u4f60\u91cd\u65b0\u7ec4\u7ec7\u548c\u8f93\u51fa\u6df7\u4e71\u590d\u6742\u7684\u4f1a\u8bae\u7eaa\u8981\u3002",
"is_listed": true,
"position": 6,
- "privacy_policy": "https://dify.ai"
+ "privacy_policy": "https://dify.ai",
+ "custom_disclaimer": null
},
{
"app": {
@@ -389,7 +411,8 @@
"description": "\u6b22\u8fce\u4f7f\u7528\u60a8\u7684\u4e2a\u6027\u5316\u7f8e\u80a1\u6295\u8d44\u5206\u6790\u52a9\u624b\uff0c\u5728\u8fd9\u91cc\u6211\u4eec\u6df1\u5165\u7684\u8fdb\u884c\u80a1\u7968\u5206\u6790\uff0c\u4e3a\u60a8\u63d0\u4f9b\u5168\u9762\u7684\u6d1e\u5bdf\u3002",
"is_listed": true,
"position": 0,
- "privacy_policy": null
+ "privacy_policy": null,
+ "custom_disclaimer": null
},
{
"app": {
@@ -405,7 +428,8 @@
"description": "\u60a8\u597d\uff0c\u6211\u662f\u60a8\u7684\u521b\u610f\u4f19\u4f34\uff0c\u5c06\u5e2e\u52a9\u60a8\u5c06\u60f3\u6cd5\u751f\u52a8\u5730\u5b9e\u73b0\uff01\u6211\u53ef\u4ee5\u534f\u52a9\u60a8\u5229\u7528DALL\u00b7E 3\u7684\u80fd\u529b\u521b\u9020\u51fa\u4ee4\u4eba\u60ca\u53f9\u7684\u8bbe\u8ba1\u3002",
"is_listed": true,
"position": 4,
- "privacy_policy": null
+ "privacy_policy": null,
+ "custom_disclaimer": null
},
{
"app": {
@@ -421,7 +445,8 @@
"description": "\u7ffb\u8bd1\u4e13\u5bb6\uff1a\u63d0\u4f9b\u4e2d\u82f1\u6587\u4e92\u8bd1",
"is_listed": true,
"position": 4,
- "privacy_policy": "https://dify.ai"
+ "privacy_policy": "https://dify.ai",
+ "custom_disclaimer": null
},
{
"app": {
@@ -437,7 +462,8 @@
"description": "\u60a8\u7684\u79c1\u4eba\u5b66\u4e60\u5bfc\u5e08\uff0c\u5e2e\u60a8\u5236\u5b9a\u5b66\u4e60\u8ba1\u5212\u5e76\u8f85\u5bfc",
"is_listed": true,
"position": 26,
- "privacy_policy": "https://dify.ai"
+ "privacy_policy": "https://dify.ai",
+ "custom_disclaimer": null
},
{
"app": {
@@ -453,7 +479,8 @@
"description": "\u5e2e\u4f60\u64b0\u5199\u8bba\u6587\u6587\u732e\u7efc\u8ff0",
"is_listed": true,
"position": 7,
- "privacy_policy": "https://dify.ai"
+ "privacy_policy": "https://dify.ai",
+ "custom_disclaimer": null
},
{
"app": {
@@ -469,7 +496,8 @@
"description": "\u4f60\u597d\uff0c\u544a\u8bc9\u6211\u60a8\u60f3\u5206\u6790\u7684 YouTube \u9891\u9053\uff0c\u6211\u5c06\u4e3a\u60a8\u6574\u7406\u4e00\u4efd\u5b8c\u6574\u7684\u6570\u636e\u5206\u6790\u62a5\u544a\u3002",
"is_listed": true,
"position": 0,
- "privacy_policy": null
+ "privacy_policy": null,
+ "custom_disclaimer": null
},
{
"app": {
@@ -485,7 +513,8 @@
"description": "\u6b22\u8fce\u4f7f\u7528\u60a8\u7684\u4e2a\u6027\u5316\u65c5\u884c\u670d\u52a1\u987e\u95ee\uff01\ud83c\udf0d\u2708\ufe0f \u51c6\u5907\u597d\u8e0f\u4e0a\u4e00\u6bb5\u5145\u6ee1\u5192\u9669\u4e0e\u653e\u677e\u7684\u65c5\u7a0b\u4e86\u5417\uff1f\u8ba9\u6211\u4eec\u4e00\u8d77\u6df1\u5165\u6253\u9020\u60a8\u96be\u5fd8\u7684\u65c5\u884c\u4f53\u9a8c\u5427\u3002",
"is_listed": true,
"position": 0,
- "privacy_policy": null
+ "privacy_policy": null,
+ "custom_disclaimer": null
}
]
},
diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py
index 498557cd51..306b7384cf 100644
--- a/api/controllers/console/__init__.py
+++ b/api/controllers/console/__init__.py
@@ -37,9 +37,6 @@ from .billing import billing
# Import datasets controllers
from .datasets import data_source, datasets, datasets_document, datasets_segments, file, hit_testing
-# Import enterprise controllers
-from .enterprise import enterprise_sso
-
# Import explore controllers
from .explore import (
audio,
@@ -57,4 +54,4 @@ from .explore import (
from .tag import tags
# Import workspace controllers
-from .workspace import account, members, model_providers, models, tool_providers, workspace
+from .workspace import account, load_balancing_config, members, model_providers, models, tool_providers, workspace
diff --git a/api/controllers/console/admin.py b/api/controllers/console/admin.py
index aaa737f83a..028be5de54 100644
--- a/api/controllers/console/admin.py
+++ b/api/controllers/console/admin.py
@@ -48,6 +48,7 @@ class InsertExploreAppListApi(Resource):
parser.add_argument('desc', type=str, location='json')
parser.add_argument('copyright', type=str, location='json')
parser.add_argument('privacy_policy', type=str, location='json')
+ parser.add_argument('custom_disclaimer', type=str, location='json')
parser.add_argument('language', type=supported_language, required=True, nullable=False, location='json')
parser.add_argument('category', type=str, required=True, nullable=False, location='json')
parser.add_argument('position', type=int, required=True, nullable=False, location='json')
@@ -62,6 +63,7 @@ class InsertExploreAppListApi(Resource):
desc = args['desc'] if args['desc'] else ''
copy_right = args['copyright'] if args['copyright'] else ''
privacy_policy = args['privacy_policy'] if args['privacy_policy'] else ''
+ custom_disclaimer = args['custom_disclaimer'] if args['custom_disclaimer'] else ''
else:
desc = site.description if site.description else \
args['desc'] if args['desc'] else ''
@@ -69,6 +71,8 @@ class InsertExploreAppListApi(Resource):
args['copyright'] if args['copyright'] else ''
privacy_policy = site.privacy_policy if site.privacy_policy else \
args['privacy_policy'] if args['privacy_policy'] else ''
+ custom_disclaimer = site.custom_disclaimer if site.custom_disclaimer else \
+ args['custom_disclaimer'] if args['custom_disclaimer'] else ''
recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args['app_id']).first()
@@ -78,6 +82,7 @@ class InsertExploreAppListApi(Resource):
description=desc,
copyright=copy_right,
privacy_policy=privacy_policy,
+ custom_disclaimer=custom_disclaimer,
language=args['language'],
category=args['category'],
position=args['position']
@@ -93,6 +98,7 @@ class InsertExploreAppListApi(Resource):
recommended_app.description = desc
recommended_app.copyright = copy_right
recommended_app.privacy_policy = privacy_policy
+ recommended_app.custom_disclaimer = custom_disclaimer
recommended_app.language = args['language']
recommended_app.category = args['category']
recommended_app.position = args['position']
diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py
index 29d89ae460..51322c92d3 100644
--- a/api/controllers/console/app/audio.py
+++ b/api/controllers/console/app/audio.py
@@ -85,7 +85,7 @@ class ChatMessageTextApi(Resource):
response = AudioService.transcript_tts(
app_model=app_model,
text=request.form['text'],
- voice=request.form.get('voice'),
+ voice=request.form['voice'],
streaming=False
)
diff --git a/api/controllers/console/app/error.py b/api/controllers/console/app/error.py
index b1abb38248..fbe42fbd2a 100644
--- a/api/controllers/console/app/error.py
+++ b/api/controllers/console/app/error.py
@@ -91,3 +91,9 @@ class DraftWorkflowNotExist(BaseHTTPException):
error_code = 'draft_workflow_not_exist'
description = "Draft workflow need to be initialized."
code = 400
+
+
+class DraftWorkflowNotSync(BaseHTTPException):
+    error_code = 'draft_workflow_not_sync'
+    description = "Workflow graph might have been modified; please refresh and resubmit."
+    code = 400
diff --git a/api/controllers/console/app/site.py b/api/controllers/console/app/site.py
index 256824981e..592009fd88 100644
--- a/api/controllers/console/app/site.py
+++ b/api/controllers/console/app/site.py
@@ -23,6 +23,7 @@ def parse_app_site_args():
parser.add_argument('customize_domain', type=str, required=False, location='json')
parser.add_argument('copyright', type=str, required=False, location='json')
parser.add_argument('privacy_policy', type=str, required=False, location='json')
+ parser.add_argument('custom_disclaimer', type=str, required=False, location='json')
parser.add_argument('customize_token_strategy', type=str, choices=['must', 'allow', 'not_allow'],
required=False,
location='json')
@@ -56,6 +57,7 @@ class AppSite(Resource):
'customize_domain',
'copyright',
'privacy_policy',
+ 'custom_disclaimer',
'customize_token_strategy',
'prompt_public'
]:
diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py
index b88a9b7fcc..641997f3f3 100644
--- a/api/controllers/console/app/workflow.py
+++ b/api/controllers/console/app/workflow.py
@@ -7,7 +7,7 @@ from werkzeug.exceptions import InternalServerError, NotFound
import services
from controllers.console import api
-from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist
+from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync
from controllers.console.app.wraps import get_app_model
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
@@ -20,6 +20,7 @@ from libs.helper import TimestampField, uuid_value
from libs.login import current_user, login_required
from models.model import App, AppMode
from services.app_generate_service import AppGenerateService
+from services.errors.app import WorkflowHashNotEqualError
from services.workflow_service import WorkflowService
logger = logging.getLogger(__name__)
@@ -59,6 +60,7 @@ class DraftWorkflowApi(Resource):
            parser = reqparse.RequestParser()
            parser.add_argument('graph', type=dict, required=True, nullable=False, location='json')
            parser.add_argument('features', type=dict, required=True, nullable=False, location='json')
+            parser.add_argument('hash', type=str, required=False, location='json')
            args = parser.parse_args()
        elif 'text/plain' in content_type:
            try:
@@ -71,7 +73,8 @@ class DraftWorkflowApi(Resource):
args = {
'graph': data.get('graph'),
- 'features': data.get('features')
+ 'features': data.get('features'),
+ 'hash': data.get('hash')
}
except json.JSONDecodeError:
return {'message': 'Invalid JSON data'}, 400
@@ -79,15 +82,21 @@ class DraftWorkflowApi(Resource):
abort(415)
workflow_service = WorkflowService()
- workflow = workflow_service.sync_draft_workflow(
- app_model=app_model,
- graph=args.get('graph'),
- features=args.get('features'),
- account=current_user
- )
+
+ try:
+ workflow = workflow_service.sync_draft_workflow(
+ app_model=app_model,
+ graph=args.get('graph'),
+ features=args.get('features'),
+ unique_hash=args.get('hash'),
+ account=current_user
+ )
+ except WorkflowHashNotEqualError:
+ raise DraftWorkflowNotSync()
return {
"result": "success",
+ "hash": workflow.unique_hash,
"updated_at": TimestampField().format(workflow.updated_at or workflow.created_at)
}
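
The `hash` round-trip above turns draft saving into a compare-and-swap: the client echoes the `unique_hash` it received on the last save, and the server refuses the write with `draft_workflow_not_sync` if the stored draft has changed since. A minimal client-side sketch, assuming the console API prefix and treating the base URL, token, and error-body field names as placeholders:

    import requests

    def save_draft(base_url: str, token: str, app_id: str,
                   graph: dict, features: dict, last_hash: str | None) -> str:
        # Echo the hash from the previous save; None on the first save.
        resp = requests.post(
            f"{base_url}/console/api/apps/{app_id}/workflows/draft",
            headers={"Authorization": f"Bearer {token}"},
            json={"graph": graph, "features": features, "hash": last_hash},
        )
        if resp.status_code == 400 and resp.json().get("code") == "draft_workflow_not_sync":
            # Another editor saved first: refetch the draft, merge, resubmit.
            raise RuntimeError("stale draft, refresh and resubmit")
        resp.raise_for_status()
        return resp.json()["hash"]  # keep for the next save
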
@@ -128,6 +137,71 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
logging.exception("internal server error.")
raise InternalServerError()
+class AdvancedChatDraftRunIterationNodeApi(Resource):
+ @setup_required
+ @login_required
+ @account_initialization_required
+ @get_app_model(mode=[AppMode.ADVANCED_CHAT])
+ def post(self, app_model: App, node_id: str):
+ """
+ Run draft workflow iteration node
+ """
+ parser = reqparse.RequestParser()
+ parser.add_argument('inputs', type=dict, location='json')
+ args = parser.parse_args()
+
+ try:
+ response = AppGenerateService.generate_single_iteration(
+ app_model=app_model,
+ user=current_user,
+ node_id=node_id,
+ args=args,
+ streaming=True
+ )
+
+ return helper.compact_generate_response(response)
+ except services.errors.conversation.ConversationNotExistsError:
+ raise NotFound("Conversation Not Exists.")
+ except services.errors.conversation.ConversationCompletedError:
+ raise ConversationCompletedError()
+ except ValueError as e:
+ raise e
+ except Exception as e:
+ logging.exception("internal server error.")
+ raise InternalServerError()
+
+class WorkflowDraftRunIterationNodeApi(Resource):
+ @setup_required
+ @login_required
+ @account_initialization_required
+ @get_app_model(mode=[AppMode.WORKFLOW])
+ def post(self, app_model: App, node_id: str):
+ """
+ Run draft workflow iteration node
+ """
+ parser = reqparse.RequestParser()
+ parser.add_argument('inputs', type=dict, location='json')
+ args = parser.parse_args()
+
+ try:
+ response = AppGenerateService.generate_single_iteration(
+ app_model=app_model,
+ user=current_user,
+ node_id=node_id,
+ args=args,
+ streaming=True
+ )
+
+ return helper.compact_generate_response(response)
+ except services.errors.conversation.ConversationNotExistsError:
+ raise NotFound("Conversation Not Exists.")
+ except services.errors.conversation.ConversationCompletedError:
+ raise ConversationCompletedError()
+ except ValueError as e:
+ raise e
+ except Exception as e:
+ logging.exception("internal server error.")
+ raise InternalServerError()
class DraftWorkflowRunApi(Resource):
@setup_required
@@ -317,6 +391,8 @@ api.add_resource(AdvancedChatDraftWorkflowRunApi, '/apps/<uuid:app_id>/advanced-
api.add_resource(DraftWorkflowRunApi, '/apps/<uuid:app_id>/workflows/draft/run')
api.add_resource(WorkflowTaskStopApi, '/apps/<uuid:app_id>/workflow-runs/tasks/<string:task_id>/stop')
api.add_resource(DraftWorkflowNodeRunApi, '/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/run')
+api.add_resource(AdvancedChatDraftRunIterationNodeApi, '/apps/<uuid:app_id>/advanced-chat/workflows/draft/iteration/nodes/<string:node_id>/run')
+api.add_resource(WorkflowDraftRunIterationNodeApi, '/apps/<uuid:app_id>/workflows/draft/iteration/nodes/<string:node_id>/run')
api.add_resource(PublishedWorkflowApi, '/apps/<uuid:app_id>/workflows/publish')
api.add_resource(DefaultBlockConfigsApi, '/apps/<uuid:app_id>/workflows/default-workflow-block-configs')
api.add_resource(DefaultBlockConfigApi, '/apps/<uuid:app_id>/workflows/default-workflow-block-configs'
'/<string:block_type>')
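
The two new iteration endpoints run a single iteration node against the current draft and stream events back. A usage sketch; the URL values and the `inputs` shape are placeholders, since what a given node accepts depends on its configuration:

    import requests

    base_url, token = "http://localhost:5001", "<console-token>"  # placeholders
    app_id, node_id = "<app-uuid>", "<iteration-node-id>"

    resp = requests.post(
        f"{base_url}/console/api/apps/{app_id}/workflows/draft/iteration/nodes/{node_id}/run",
        headers={"Authorization": f"Bearer {token}"},
        json={"inputs": {"items": ["a", "b", "c"]}},  # placeholder inputs
        stream=True,  # the handler responds with a streaming generation
    )
    for line in resp.iter_lines():
        if line:
            print(line.decode())
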
diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py
index 40ded54120..72c4c09055 100644
--- a/api/controllers/console/datasets/datasets.py
+++ b/api/controllers/console/datasets/datasets.py
@@ -476,13 +476,13 @@ class DatasetRetrievalSettingApi(Resource):
@account_initialization_required
def get(self):
vector_type = current_app.config['VECTOR_STORE']
- if vector_type == 'milvus' or vector_type == 'pgvecto_rs' or vector_type == 'relyt':
+ if vector_type in {"milvus", "relyt", "pgvector", "pgvecto_rs", 'tidb_vector'}:
return {
'retrieval_method': [
'semantic_search'
]
}
- elif vector_type == 'qdrant' or vector_type == 'weaviate':
+ elif vector_type in {"qdrant", "weaviate"}:
return {
'retrieval_method': [
'semantic_search', 'full_text_search', 'hybrid_search'
@@ -497,14 +497,13 @@ class DatasetRetrievalSettingMockApi(Resource):
@login_required
@account_initialization_required
def get(self, vector_type):
-
- if vector_type == 'milvus' or vector_type == 'relyt':
+ if vector_type in {'milvus', 'relyt', 'pgvector', 'tidb_vector'}:
return {
'retrieval_method': [
'semantic_search'
]
}
- elif vector_type == 'qdrant' or vector_type == 'weaviate':
+ elif vector_type in {'qdrant', 'weaviate'}:
return {
'retrieval_method': [
'semantic_search', 'full_text_search', 'hybrid_search'
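
Both endpoints now branch on set membership instead of chained `==` comparisons. The same store-to-methods relationship could also live in a single table; this is a sketch of an alternative shape, not what the diff does:

    # One mapping instead of parallel if/elif chains in the two resources.
    RETRIEVAL_METHODS: dict[str, list[str]] = {
        **dict.fromkeys(
            ['milvus', 'relyt', 'pgvector', 'pgvecto_rs', 'tidb_vector'],
            ['semantic_search'],
        ),
        **dict.fromkeys(
            ['qdrant', 'weaviate'],
            ['semantic_search', 'full_text_search', 'hybrid_search'],
        ),
    }

    def retrieval_methods(vector_type: str) -> list[str]:
        # Default chosen for the sketch; the real handlers decide per endpoint.
        return RETRIEVAL_METHODS.get(vector_type, ['semantic_search'])
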
diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py
index 9dedcefe0f..5498d22e78 100644
--- a/api/controllers/console/datasets/datasets_document.py
+++ b/api/controllers/console/datasets/datasets_document.py
@@ -1,10 +1,12 @@
import logging
+from argparse import ArgumentTypeError
from datetime import datetime, timezone
from flask import request
from flask_login import current_user
from flask_restful import Resource, fields, marshal, marshal_with, reqparse
from sqlalchemy import asc, desc
+from transformers.hf_argparser import string_to_bool
from werkzeug.exceptions import Forbidden, NotFound
import services
@@ -141,7 +143,11 @@ class DatasetDocumentListApi(Resource):
limit = request.args.get('limit', default=20, type=int)
search = request.args.get('keyword', default=None, type=str)
sort = request.args.get('sort', default='-created_at', type=str)
- fetch = request.args.get('fetch', default=False, type=bool)
+ # "yes", "true", "t", "y", "1" convert to True, while others convert to False.
+ try:
+ fetch = string_to_bool(request.args.get('fetch', default='false'))
+ except (ArgumentTypeError, ValueError, Exception) as e:
+ fetch = False
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound('Dataset not found.')
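
`string_to_bool` here comes from `transformers.hf_argparser`: it maps the usual truthy/falsy spellings and raises `ArgumentTypeError` for anything else, which is why the handler falls back to `False`. A self-contained equivalent, should pulling it from `transformers` ever become undesirable:

    from argparse import ArgumentTypeError

    def string_to_bool(v) -> bool:
        # Close mirror of transformers.hf_argparser.string_to_bool.
        if isinstance(v, bool):
            return v
        if v.lower() in ('yes', 'true', 't', 'y', '1'):
            return True
        if v.lower() in ('no', 'false', 'f', 'n', '0'):
            return False
        raise ArgumentTypeError(
            f'Truthy value expected: got {v!r} but expected one of yes/no, true/false, t/f, y/n, 0/1.'
        )
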
@@ -924,6 +930,28 @@ class DocumentRetryApi(DocumentResource):
return {'result': 'success'}, 204
+class DocumentRenameApi(DocumentResource):
+ @setup_required
+ @login_required
+ @account_initialization_required
+ @marshal_with(document_fields)
+ def post(self, dataset_id, document_id):
+ # The role of the current user in the ta table must be admin or owner
+ if not current_user.is_admin_or_owner:
+ raise Forbidden()
+
+ parser = reqparse.RequestParser()
+ parser.add_argument('name', type=str, required=True, nullable=False, location='json')
+ args = parser.parse_args()
+
+ try:
+ document = DocumentService.rename_document(dataset_id, document_id, args['name'])
+ except services.errors.document.DocumentIndexingError:
+ raise DocumentIndexingError('Cannot rename document during indexing.')
+
+ return document
+
+
api.add_resource(GetProcessRuleApi, '/datasets/process-rule')
api.add_resource(DatasetDocumentListApi,
'/datasets/<uuid:dataset_id>/documents')
@@ -950,3 +978,5 @@ api.add_resource(DocumentStatusApi,
api.add_resource(DocumentPauseApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/pause')
api.add_resource(DocumentRecoverApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/resume')
api.add_resource(DocumentRetryApi, '/datasets/<uuid:dataset_id>/retry')
+api.add_resource(DocumentRenameApi,
+ '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/rename')
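
Usage of the new rename route, as a sketch; the caller must pass the `is_admin_or_owner` check, and all identifiers here are placeholders:

    import requests

    base_url, token = "http://localhost:5001", "<console-token>"  # placeholders
    dataset_id, document_id = "<dataset-uuid>", "<document-uuid>"

    resp = requests.post(
        f"{base_url}/console/api/datasets/{dataset_id}/documents/{document_id}/rename",
        headers={"Authorization": f"Bearer {token}"},
        json={"name": "renamed-document"},
    )
    resp.raise_for_status()  # 403 Forbidden unless the user is admin or owner
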
diff --git a/api/controllers/console/enterprise/enterprise_sso.py b/api/controllers/console/enterprise/enterprise_sso.py
deleted file mode 100644
index f6a2897d5a..0000000000
--- a/api/controllers/console/enterprise/enterprise_sso.py
+++ /dev/null
@@ -1,59 +0,0 @@
-from flask import current_app, redirect
-from flask_restful import Resource, reqparse
-
-from controllers.console import api
-from controllers.console.setup import setup_required
-from services.enterprise.enterprise_sso_service import EnterpriseSSOService
-
-
-class EnterpriseSSOSamlLogin(Resource):
-
- @setup_required
- def get(self):
- return EnterpriseSSOService.get_sso_saml_login()
-
-
-class EnterpriseSSOSamlAcs(Resource):
-
- @setup_required
- def post(self):
- parser = reqparse.RequestParser()
- parser.add_argument('SAMLResponse', type=str, required=True, location='form')
- args = parser.parse_args()
- saml_response = args['SAMLResponse']
-
- try:
- token = EnterpriseSSOService.post_sso_saml_acs(saml_response)
- return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}/signin?console_token={token}')
- except Exception as e:
- return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}/signin?message={str(e)}')
-
-
-class EnterpriseSSOOidcLogin(Resource):
-
- @setup_required
- def get(self):
- return EnterpriseSSOService.get_sso_oidc_login()
-
-
-class EnterpriseSSOOidcCallback(Resource):
-
- @setup_required
- def get(self):
- parser = reqparse.RequestParser()
- parser.add_argument('state', type=str, required=True, location='args')
- parser.add_argument('code', type=str, required=True, location='args')
- parser.add_argument('oidc-state', type=str, required=True, location='cookies')
- args = parser.parse_args()
-
- try:
- token = EnterpriseSSOService.get_sso_oidc_callback(args)
- return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}/signin?console_token={token}')
- except Exception as e:
- return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}/signin?message={str(e)}')
-
-
-api.add_resource(EnterpriseSSOSamlLogin, '/enterprise/sso/saml/login')
-api.add_resource(EnterpriseSSOSamlAcs, '/enterprise/sso/saml/acs')
-api.add_resource(EnterpriseSSOOidcLogin, '/enterprise/sso/oidc/login')
-api.add_resource(EnterpriseSSOOidcCallback, '/enterprise/sso/oidc/callback')
diff --git a/api/controllers/console/explore/audio.py b/api/controllers/console/explore/audio.py
index f03663f1a2..d869cd38ed 100644
--- a/api/controllers/console/explore/audio.py
+++ b/api/controllers/console/explore/audio.py
@@ -76,7 +76,7 @@ class ChatTextApi(InstalledAppResource):
response = AudioService.transcript_tts(
app_model=app_model,
text=request.form['text'],
- voice=request.form.get('voice'),
+ voice=request.form['voice'] if request.form.get('voice') else app_model.app_model_config.text_to_speech_dict.get('voice'),
streaming=False
)
return {'data': response.data.decode('latin1')}
diff --git a/api/controllers/console/explore/recommended_app.py b/api/controllers/console/explore/recommended_app.py
index 2787b7cdba..6e10e2ec92 100644
--- a/api/controllers/console/explore/recommended_app.py
+++ b/api/controllers/console/explore/recommended_app.py
@@ -21,6 +21,7 @@ recommended_app_fields = {
'description': fields.String(attribute='description'),
'copyright': fields.String,
'privacy_policy': fields.String,
+ 'custom_disclaimer': fields.String,
'category': fields.String,
'position': fields.Integer,
'is_listed': fields.Boolean
diff --git a/api/controllers/console/feature.py b/api/controllers/console/feature.py
index 325652a447..44d9d67522 100644
--- a/api/controllers/console/feature.py
+++ b/api/controllers/console/feature.py
@@ -1,24 +1,28 @@
from flask_login import current_user
from flask_restful import Resource
-from services.enterprise.enterprise_feature_service import EnterpriseFeatureService
+from libs.login import login_required
from services.feature_service import FeatureService
from . import api
-from .wraps import cloud_utm_record
+from .setup import setup_required
+from .wraps import account_initialization_required, cloud_utm_record
class FeatureApi(Resource):
+ @setup_required
+ @login_required
+ @account_initialization_required
@cloud_utm_record
def get(self):
return FeatureService.get_features(current_user.current_tenant_id).dict()
-class EnterpriseFeatureApi(Resource):
+class SystemFeatureApi(Resource):
def get(self):
- return EnterpriseFeatureService.get_enterprise_features().dict()
+ return FeatureService.get_system_features().dict()
api.add_resource(FeatureApi, '/features')
-api.add_resource(EnterpriseFeatureApi, '/enterprise-features')
+api.add_resource(SystemFeatureApi, '/system-features')
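
Unlike `/features`, the new `/system-features` route carries no auth decorators, so a client can read it before sign-in; the SSO-enforcement flow later in this patch relies on exactly that. A sketch, with the base URL as a placeholder:

    import requests

    base_url = "http://localhost:5001"  # placeholder

    # No Authorization header: the endpoint is intentionally public.
    features = requests.get(f"{base_url}/console/api/system-features").json()
    if features.get("sso_enforced_for_web"):
        print("web sessions must authenticate through SSO")
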
diff --git a/api/controllers/console/version.py b/api/controllers/console/version.py
index a50e4c41a8..faf36c4f40 100644
--- a/api/controllers/console/version.py
+++ b/api/controllers/console/version.py
@@ -17,13 +17,19 @@ class VersionApi(Resource):
args = parser.parse_args()
check_update_url = current_app.config['CHECK_UPDATE_URL']
- if not check_update_url:
- return {
- 'version': '0.0.0',
- 'release_date': '',
- 'release_notes': '',
- 'can_auto_update': False
+ result = {
+ 'version': current_app.config['CURRENT_VERSION'],
+ 'release_date': '',
+ 'release_notes': '',
+ 'can_auto_update': False,
+ 'features': {
+ 'can_replace_logo': current_app.config['CAN_REPLACE_LOGO'],
+ 'model_load_balancing_enabled': current_app.config['MODEL_LB_ENABLED']
}
+ }
+
+ if not check_update_url:
+ return result
try:
response = requests.get(check_update_url, {
@@ -31,20 +37,15 @@ class VersionApi(Resource):
})
except Exception as error:
logging.warning("Check update version error: {}.".format(str(error)))
- return {
- 'version': args.get('current_version'),
- 'release_date': '',
- 'release_notes': '',
- 'can_auto_update': False
- }
+ result['version'] = args.get('current_version')
+ return result
content = json.loads(response.content)
- return {
- 'version': content['version'],
- 'release_date': content['releaseDate'],
- 'release_notes': content['releaseNotes'],
- 'can_auto_update': content['canAutoUpdate']
- }
+ result['version'] = content['version']
+ result['release_date'] = content['releaseDate']
+ result['release_notes'] = content['releaseNotes']
+ result['can_auto_update'] = content['canAutoUpdate']
+ return result
api.add_resource(VersionApi, '/version')
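
After this refactor every exit path returns the same shape; only the values differ depending on whether the update check succeeded. An illustrative response body (values are examples):

    # Shape of the /version response after this change.
    {
        "version": "0.6.0",        # CURRENT_VERSION, the client's version, or the remote one
        "release_date": "",
        "release_notes": "",
        "can_auto_update": False,
        "features": {
            "can_replace_logo": True,
            "model_load_balancing_enabled": False,
        },
    }
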
diff --git a/api/controllers/console/workspace/load_balancing_config.py b/api/controllers/console/workspace/load_balancing_config.py
new file mode 100644
index 0000000000..50514e39f6
--- /dev/null
+++ b/api/controllers/console/workspace/load_balancing_config.py
@@ -0,0 +1,106 @@
+from flask_restful import Resource, reqparse
+from werkzeug.exceptions import Forbidden
+
+from controllers.console import api
+from controllers.console.setup import setup_required
+from controllers.console.wraps import account_initialization_required
+from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from libs.login import current_user, login_required
+from models.account import TenantAccountRole
+from services.model_load_balancing_service import ModelLoadBalancingService
+
+
+class LoadBalancingCredentialsValidateApi(Resource):
+ @setup_required
+ @login_required
+ @account_initialization_required
+ def post(self, provider: str):
+ if not TenantAccountRole.is_privileged_role(current_user.current_tenant.current_role):
+ raise Forbidden()
+
+ tenant_id = current_user.current_tenant_id
+
+ parser = reqparse.RequestParser()
+ parser.add_argument('model', type=str, required=True, nullable=False, location='json')
+ parser.add_argument('model_type', type=str, required=True, nullable=False,
+ choices=[mt.value for mt in ModelType], location='json')
+ parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
+ args = parser.parse_args()
+
+ # validate model load balancing credentials
+ model_load_balancing_service = ModelLoadBalancingService()
+
+ result = True
+ error = None
+
+ try:
+ model_load_balancing_service.validate_load_balancing_credentials(
+ tenant_id=tenant_id,
+ provider=provider,
+ model=args['model'],
+ model_type=args['model_type'],
+ credentials=args['credentials']
+ )
+ except CredentialsValidateFailedError as ex:
+ result = False
+ error = str(ex)
+
+ response = {'result': 'success' if result else 'error'}
+
+ if not result:
+ response['error'] = error
+
+ return response
+
+
+class LoadBalancingConfigCredentialsValidateApi(Resource):
+ @setup_required
+ @login_required
+ @account_initialization_required
+ def post(self, provider: str, config_id: str):
+ if not TenantAccountRole.is_privileged_role(current_user.current_tenant.current_role):
+ raise Forbidden()
+
+ tenant_id = current_user.current_tenant_id
+
+ parser = reqparse.RequestParser()
+ parser.add_argument('model', type=str, required=True, nullable=False, location='json')
+ parser.add_argument('model_type', type=str, required=True, nullable=False,
+ choices=[mt.value for mt in ModelType], location='json')
+ parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
+ args = parser.parse_args()
+
+ # validate model load balancing config credentials
+ model_load_balancing_service = ModelLoadBalancingService()
+
+ result = True
+ error = None
+
+ try:
+ model_load_balancing_service.validate_load_balancing_credentials(
+ tenant_id=tenant_id,
+ provider=provider,
+ model=args['model'],
+ model_type=args['model_type'],
+ credentials=args['credentials'],
+ config_id=config_id,
+ )
+ except CredentialsValidateFailedError as ex:
+ result = False
+ error = str(ex)
+
+ response = {'result': 'success' if result else 'error'}
+
+ if not result:
+ response['error'] = error
+
+ return response
+
+
+# Load Balancing Config
+api.add_resource(LoadBalancingCredentialsValidateApi,
+ '/workspaces/current/model-providers/<string:provider>/models/load-balancing-configs/credentials-validate')
+
+api.add_resource(LoadBalancingConfigCredentialsValidateApi,
+ '/workspaces/current/model-providers/<string:provider>/models/load-balancing-configs/<string:config_id>/credentials-validate')
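
A request sketch for the first route; the provider, model, and credential key names are placeholders and vary by provider:

    import requests

    base_url, token = "http://localhost:5001", "<console-token>"  # placeholders

    resp = requests.post(
        f"{base_url}/console/api/workspaces/current/model-providers/openai"
        "/models/load-balancing-configs/credentials-validate",
        headers={"Authorization": f"Bearer {token}"},
        json={
            "model": "gpt-4o",
            "model_type": "llm",
            "credentials": {"openai_api_key": "sk-..."},
        },
    )
    print(resp.json())  # {'result': 'success'} or {'result': 'error', 'error': '...'}
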
diff --git a/api/controllers/console/workspace/models.py b/api/controllers/console/workspace/models.py
index 23239b1902..76ae6a4ab9 100644
--- a/api/controllers/console/workspace/models.py
+++ b/api/controllers/console/workspace/models.py
@@ -12,6 +12,7 @@ from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.utils.encoders import jsonable_encoder
from libs.login import login_required
from models.account import TenantAccountRole
+from services.model_load_balancing_service import ModelLoadBalancingService
from services.model_provider_service import ModelProviderService
@@ -104,21 +105,56 @@ class ModelProviderModelApi(Resource):
parser.add_argument('model', type=str, required=True, nullable=False, location='json')
parser.add_argument('model_type', type=str, required=True, nullable=False,
choices=[mt.value for mt in ModelType], location='json')
- parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
+ parser.add_argument('credentials', type=dict, required=False, nullable=True, location='json')
+ parser.add_argument('load_balancing', type=dict, required=False, nullable=True, location='json')
+ parser.add_argument('config_from', type=str, required=False, nullable=True, location='json')
args = parser.parse_args()
- model_provider_service = ModelProviderService()
+ model_load_balancing_service = ModelLoadBalancingService()
- try:
- model_provider_service.save_model_credentials(
+ if ('load_balancing' in args and args['load_balancing'] and
+ 'enabled' in args['load_balancing'] and args['load_balancing']['enabled']):
+ if 'configs' not in args['load_balancing']:
+ raise ValueError('invalid load balancing configs')
+
+ # save load balancing configs
+ model_load_balancing_service.update_load_balancing_configs(
tenant_id=tenant_id,
provider=provider,
model=args['model'],
model_type=args['model_type'],
- credentials=args['credentials']
+ configs=args['load_balancing']['configs']
)
- except CredentialsValidateFailedError as ex:
- raise ValueError(str(ex))
+
+ # enable load balancing
+ model_load_balancing_service.enable_model_load_balancing(
+ tenant_id=tenant_id,
+ provider=provider,
+ model=args['model'],
+ model_type=args['model_type']
+ )
+ else:
+ # disable load balancing
+ model_load_balancing_service.disable_model_load_balancing(
+ tenant_id=tenant_id,
+ provider=provider,
+ model=args['model'],
+ model_type=args['model_type']
+ )
+
+ if args.get('config_from', '') != 'predefined-model':
+ model_provider_service = ModelProviderService()
+
+ try:
+ model_provider_service.save_model_credentials(
+ tenant_id=tenant_id,
+ provider=provider,
+ model=args['model'],
+ model_type=args['model_type'],
+ credentials=args['credentials']
+ )
+ except CredentialsValidateFailedError as ex:
+ raise ValueError(str(ex))
return {'result': 'success'}, 200
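
The handler now treats load balancing and plain credentials as alternatives: when `load_balancing.enabled` is true it saves and activates the config list, and per-model credentials are only saved when `config_from` is not 'predefined-model'. An illustrative body; the shape of each config entry is an assumption, since only `enabled` and `configs` are checked in this hunk:

    payload = {
        "model": "gpt-4o",                      # placeholder model
        "model_type": "llm",
        "config_from": "predefined-model",      # skip per-model credential saving
        "load_balancing": {
            "enabled": True,
            "configs": [                        # entry shape assumed, not shown in this diff
                {"name": "key-1", "credentials": {"openai_api_key": "sk-..."}},
                {"name": "key-2", "credentials": {"openai_api_key": "sk-..."}},
            ],
        },
    }
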
@@ -170,11 +206,73 @@ class ModelProviderModelCredentialApi(Resource):
model=args['model']
)
+ model_load_balancing_service = ModelLoadBalancingService()
+ is_load_balancing_enabled, load_balancing_configs = model_load_balancing_service.get_load_balancing_configs(
+ tenant_id=tenant_id,
+ provider=provider,
+ model=args['model'],
+ model_type=args['model_type']
+ )
+
return {
- "credentials": credentials
+ "credentials": credentials,
+ "load_balancing": {
+ "enabled": is_load_balancing_enabled,
+ "configs": load_balancing_configs
+ }
}
+class ModelProviderModelEnableApi(Resource):
+
+ @setup_required
+ @login_required
+ @account_initialization_required
+ def patch(self, provider: str):
+ tenant_id = current_user.current_tenant_id
+
+ parser = reqparse.RequestParser()
+ parser.add_argument('model', type=str, required=True, nullable=False, location='json')
+ parser.add_argument('model_type', type=str, required=True, nullable=False,
+ choices=[mt.value for mt in ModelType], location='json')
+ args = parser.parse_args()
+
+ model_provider_service = ModelProviderService()
+ model_provider_service.enable_model(
+ tenant_id=tenant_id,
+ provider=provider,
+ model=args['model'],
+ model_type=args['model_type']
+ )
+
+ return {'result': 'success'}
+
+
+class ModelProviderModelDisableApi(Resource):
+
+ @setup_required
+ @login_required
+ @account_initialization_required
+ def patch(self, provider: str):
+ tenant_id = current_user.current_tenant_id
+
+ parser = reqparse.RequestParser()
+ parser.add_argument('model', type=str, required=True, nullable=False, location='json')
+ parser.add_argument('model_type', type=str, required=True, nullable=False,
+ choices=[mt.value for mt in ModelType], location='json')
+ args = parser.parse_args()
+
+ model_provider_service = ModelProviderService()
+ model_provider_service.disable_model(
+ tenant_id=tenant_id,
+ provider=provider,
+ model=args['model'],
+ model_type=args['model_type']
+ )
+
+ return {'result': 'success'}
+
+
class ModelProviderModelValidateApi(Resource):
@setup_required
@@ -259,6 +357,10 @@ class ModelProviderAvailableModelApi(Resource):
api.add_resource(ModelProviderModelApi, '/workspaces/current/model-providers/<string:provider>/models')
+api.add_resource(ModelProviderModelEnableApi, '/workspaces/current/model-providers/<string:provider>/models/enable',
+ endpoint='model-provider-model-enable')
+api.add_resource(ModelProviderModelDisableApi, '/workspaces/current/model-providers/<string:provider>/models/disable',
+ endpoint='model-provider-model-disable')
api.add_resource(ModelProviderModelCredentialApi,
'/workspaces/current/model-providers/<string:provider>/models/credentials')
api.add_resource(ModelProviderModelValidateApi,
diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py
index b02008339e..a911e9b2cb 100644
--- a/api/controllers/console/workspace/tool_providers.py
+++ b/api/controllers/console/workspace/tool_providers.py
@@ -9,8 +9,13 @@ from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.model_runtime.utils.encoders import jsonable_encoder
+from libs.helper import alphanumeric, uuid_value
from libs.login import login_required
-from services.tools_manage_service import ToolManageService
+from services.tools.api_tools_manage_service import ApiToolManageService
+from services.tools.builtin_tools_manage_service import BuiltinToolManageService
+from services.tools.tool_labels_service import ToolLabelsService
+from services.tools.tools_manage_service import ToolCommonService
+from services.tools.workflow_tools_manage_service import WorkflowToolManageService
class ToolProviderListApi(Resource):
@@ -21,7 +26,11 @@ class ToolProviderListApi(Resource):
user_id = current_user.id
tenant_id = current_user.current_tenant_id
- return ToolManageService.list_tool_providers(user_id, tenant_id)
+ req = reqparse.RequestParser()
+ req.add_argument('type', type=str, choices=['builtin', 'model', 'api', 'workflow'], required=False, nullable=True, location='args')
+ args = req.parse_args()
+
+ return ToolCommonService.list_tool_providers(user_id, tenant_id, args.get('type', None))
class ToolBuiltinProviderListToolsApi(Resource):
@setup_required
@@ -31,7 +40,7 @@ class ToolBuiltinProviderListToolsApi(Resource):
user_id = current_user.id
tenant_id = current_user.current_tenant_id
- return jsonable_encoder(ToolManageService.list_builtin_tool_provider_tools(
+ return jsonable_encoder(BuiltinToolManageService.list_builtin_tool_provider_tools(
user_id,
tenant_id,
provider,
@@ -48,7 +57,7 @@ class ToolBuiltinProviderDeleteApi(Resource):
user_id = current_user.id
tenant_id = current_user.current_tenant_id
- return ToolManageService.delete_builtin_tool_provider(
+ return BuiltinToolManageService.delete_builtin_tool_provider(
user_id,
tenant_id,
provider,
@@ -70,7 +79,7 @@ class ToolBuiltinProviderUpdateApi(Resource):
args = parser.parse_args()
- return ToolManageService.update_builtin_tool_provider(
+ return BuiltinToolManageService.update_builtin_tool_provider(
user_id,
tenant_id,
provider,
@@ -85,7 +94,7 @@ class ToolBuiltinProviderGetCredentialsApi(Resource):
user_id = current_user.id
tenant_id = current_user.current_tenant_id
- return ToolManageService.get_builtin_tool_provider_credentials(
+ return BuiltinToolManageService.get_builtin_tool_provider_credentials(
user_id,
tenant_id,
provider,
@@ -94,35 +103,10 @@ class ToolBuiltinProviderGetCredentialsApi(Resource):
class ToolBuiltinProviderIconApi(Resource):
@setup_required
def get(self, provider):
- icon_bytes, mimetype = ToolManageService.get_builtin_tool_provider_icon(provider)
+ icon_bytes, mimetype = BuiltinToolManageService.get_builtin_tool_provider_icon(provider)
icon_cache_max_age = int(current_app.config.get('TOOL_ICON_CACHE_MAX_AGE'))
return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age)
-class ToolModelProviderIconApi(Resource):
- @setup_required
- def get(self, provider):
- icon_bytes, mimetype = ToolManageService.get_model_tool_provider_icon(provider)
- return send_file(io.BytesIO(icon_bytes), mimetype=mimetype)
-
-class ToolModelProviderListToolsApi(Resource):
- @setup_required
- @login_required
- @account_initialization_required
- def get(self):
- user_id = current_user.id
- tenant_id = current_user.current_tenant_id
-
- parser = reqparse.RequestParser()
- parser.add_argument('provider', type=str, required=True, nullable=False, location='args')
-
- args = parser.parse_args()
-
- return jsonable_encoder(ToolManageService.list_model_tool_provider_tools(
- user_id,
- tenant_id,
- args['provider'],
- ))
-
class ToolApiProviderAddApi(Resource):
@setup_required
@login_required
@@ -141,10 +125,12 @@ class ToolApiProviderAddApi(Resource):
parser.add_argument('provider', type=str, required=True, nullable=False, location='json')
parser.add_argument('icon', type=dict, required=True, nullable=False, location='json')
parser.add_argument('privacy_policy', type=str, required=False, nullable=True, location='json')
+ parser.add_argument('labels', type=list[str], required=False, nullable=True, location='json', default=[])
+ parser.add_argument('custom_disclaimer', type=str, required=False, nullable=True, location='json')
args = parser.parse_args()
- return ToolManageService.create_api_tool_provider(
+ return ApiToolManageService.create_api_tool_provider(
user_id,
tenant_id,
args['provider'],
@@ -153,6 +139,8 @@ class ToolApiProviderAddApi(Resource):
args['schema_type'],
args['schema'],
args.get('privacy_policy', ''),
+ args.get('custom_disclaimer', ''),
+ args.get('labels', []),
)
class ToolApiProviderGetRemoteSchemaApi(Resource):
@@ -166,7 +154,7 @@ class ToolApiProviderGetRemoteSchemaApi(Resource):
args = parser.parse_args()
- return ToolManageService.get_api_tool_provider_remote_schema(
+ return ApiToolManageService.get_api_tool_provider_remote_schema(
current_user.id,
current_user.current_tenant_id,
args['url'],
@@ -186,7 +174,7 @@ class ToolApiProviderListToolsApi(Resource):
args = parser.parse_args()
- return jsonable_encoder(ToolManageService.list_api_tool_provider_tools(
+ return jsonable_encoder(ApiToolManageService.list_api_tool_provider_tools(
user_id,
tenant_id,
args['provider'],
@@ -211,10 +199,12 @@ class ToolApiProviderUpdateApi(Resource):
parser.add_argument('original_provider', type=str, required=True, nullable=False, location='json')
parser.add_argument('icon', type=dict, required=True, nullable=False, location='json')
parser.add_argument('privacy_policy', type=str, required=True, nullable=True, location='json')
+ parser.add_argument('labels', type=list[str], required=False, nullable=True, location='json')
+ parser.add_argument('custom_disclaimer', type=str, required=True, nullable=True, location='json')
args = parser.parse_args()
- return ToolManageService.update_api_tool_provider(
+ return ApiToolManageService.update_api_tool_provider(
user_id,
tenant_id,
args['provider'],
@@ -224,6 +214,8 @@ class ToolApiProviderUpdateApi(Resource):
args['schema_type'],
args['schema'],
args['privacy_policy'],
+ args['custom_disclaimer'],
+ args.get('labels', []),
)
class ToolApiProviderDeleteApi(Resource):
@@ -243,7 +235,7 @@ class ToolApiProviderDeleteApi(Resource):
args = parser.parse_args()
- return ToolManageService.delete_api_tool_provider(
+ return ApiToolManageService.delete_api_tool_provider(
user_id,
tenant_id,
args['provider'],
@@ -263,7 +255,7 @@ class ToolApiProviderGetApi(Resource):
args = parser.parse_args()
- return ToolManageService.get_api_tool_provider(
+ return ApiToolManageService.get_api_tool_provider(
user_id,
tenant_id,
args['provider'],
@@ -274,7 +266,7 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource):
@login_required
@account_initialization_required
def get(self, provider):
- return ToolManageService.list_builtin_provider_credentials_schema(provider)
+ return BuiltinToolManageService.list_builtin_provider_credentials_schema(provider)
class ToolApiProviderSchemaApi(Resource):
@setup_required
@@ -287,7 +279,7 @@ class ToolApiProviderSchemaApi(Resource):
args = parser.parse_args()
- return ToolManageService.parser_api_schema(
+ return ApiToolManageService.parser_api_schema(
schema=args['schema'],
)
@@ -307,7 +299,7 @@ class ToolApiProviderPreviousTestApi(Resource):
args = parser.parse_args()
- return ToolManageService.test_api_tool_preview(
+ return ApiToolManageService.test_api_tool_preview(
current_user.current_tenant_id,
args['provider_name'] if args['provider_name'] else '',
args['tool_name'],
@@ -317,6 +309,153 @@ class ToolApiProviderPreviousTestApi(Resource):
args['schema'],
)
+class ToolWorkflowProviderCreateApi(Resource):
+ @setup_required
+ @login_required
+ @account_initialization_required
+ def post(self):
+ if not current_user.is_admin_or_owner:
+ raise Forbidden()
+
+ user_id = current_user.id
+ tenant_id = current_user.current_tenant_id
+
+ reqparser = reqparse.RequestParser()
+ reqparser.add_argument('workflow_app_id', type=uuid_value, required=True, nullable=False, location='json')
+ reqparser.add_argument('name', type=alphanumeric, required=True, nullable=False, location='json')
+ reqparser.add_argument('label', type=str, required=True, nullable=False, location='json')
+ reqparser.add_argument('description', type=str, required=True, nullable=False, location='json')
+ reqparser.add_argument('icon', type=dict, required=True, nullable=False, location='json')
+ reqparser.add_argument('parameters', type=list[dict], required=True, nullable=False, location='json')
+ reqparser.add_argument('privacy_policy', type=str, required=False, nullable=True, location='json', default='')
+ reqparser.add_argument('labels', type=list[str], required=False, nullable=True, location='json')
+
+ args = reqparser.parse_args()
+
+ return WorkflowToolManageService.create_workflow_tool(
+ user_id,
+ tenant_id,
+ args['workflow_app_id'],
+ args['name'],
+ args['label'],
+ args['icon'],
+ args['description'],
+ args['parameters'],
+ args['privacy_policy'],
+ args.get('labels', []),
+ )
+
+class ToolWorkflowProviderUpdateApi(Resource):
+ @setup_required
+ @login_required
+ @account_initialization_required
+ def post(self):
+ if not current_user.is_admin_or_owner:
+ raise Forbidden()
+
+ user_id = current_user.id
+ tenant_id = current_user.current_tenant_id
+
+ reqparser = reqparse.RequestParser()
+ reqparser.add_argument('workflow_tool_id', type=uuid_value, required=True, nullable=False, location='json')
+ reqparser.add_argument('name', type=alphanumeric, required=True, nullable=False, location='json')
+ reqparser.add_argument('label', type=str, required=True, nullable=False, location='json')
+ reqparser.add_argument('description', type=str, required=True, nullable=False, location='json')
+ reqparser.add_argument('icon', type=dict, required=True, nullable=False, location='json')
+ reqparser.add_argument('parameters', type=list[dict], required=True, nullable=False, location='json')
+ reqparser.add_argument('privacy_policy', type=str, required=False, nullable=True, location='json', default='')
+ reqparser.add_argument('labels', type=list[str], required=False, nullable=True, location='json')
+
+ args = reqparser.parse_args()
+
+ if not args['workflow_tool_id']:
+ raise ValueError('incorrect workflow_tool_id')
+
+ return WorkflowToolManageService.update_workflow_tool(
+ user_id,
+ tenant_id,
+ args['workflow_tool_id'],
+ args['name'],
+ args['label'],
+ args['icon'],
+ args['description'],
+ args['parameters'],
+ args['privacy_policy'],
+ args.get('labels', []),
+ )
+
+class ToolWorkflowProviderDeleteApi(Resource):
+ @setup_required
+ @login_required
+ @account_initialization_required
+ def post(self):
+ if not current_user.is_admin_or_owner:
+ raise Forbidden()
+
+ user_id = current_user.id
+ tenant_id = current_user.current_tenant_id
+
+ reqparser = reqparse.RequestParser()
+ reqparser.add_argument('workflow_tool_id', type=uuid_value, required=True, nullable=False, location='json')
+
+ args = reqparser.parse_args()
+
+ return WorkflowToolManageService.delete_workflow_tool(
+ user_id,
+ tenant_id,
+ args['workflow_tool_id'],
+ )
+
+class ToolWorkflowProviderGetApi(Resource):
+ @setup_required
+ @login_required
+ @account_initialization_required
+ def get(self):
+ user_id = current_user.id
+ tenant_id = current_user.current_tenant_id
+
+ parser = reqparse.RequestParser()
+ parser.add_argument('workflow_tool_id', type=uuid_value, required=False, nullable=True, location='args')
+ parser.add_argument('workflow_app_id', type=uuid_value, required=False, nullable=True, location='args')
+
+ args = parser.parse_args()
+
+ if args.get('workflow_tool_id'):
+ tool = WorkflowToolManageService.get_workflow_tool_by_tool_id(
+ user_id,
+ tenant_id,
+ args['workflow_tool_id'],
+ )
+ elif args.get('workflow_app_id'):
+ tool = WorkflowToolManageService.get_workflow_tool_by_app_id(
+ user_id,
+ tenant_id,
+ args['workflow_app_id'],
+ )
+ else:
+ raise ValueError('incorrect workflow_tool_id or workflow_app_id')
+
+ return jsonable_encoder(tool)
+
+class ToolWorkflowProviderListToolApi(Resource):
+ @setup_required
+ @login_required
+ @account_initialization_required
+ def get(self):
+ user_id = current_user.id
+ tenant_id = current_user.current_tenant_id
+
+ parser = reqparse.RequestParser()
+ parser.add_argument('workflow_tool_id', type=uuid_value, required=True, nullable=False, location='args')
+
+ args = parser.parse_args()
+
+ return jsonable_encoder(WorkflowToolManageService.list_single_workflow_tools(
+ user_id,
+ tenant_id,
+ args['workflow_tool_id'],
+ ))
+
class ToolBuiltinListApi(Resource):
@setup_required
@login_required
@@ -325,7 +464,7 @@ class ToolBuiltinListApi(Resource):
user_id = current_user.id
tenant_id = current_user.current_tenant_id
- return jsonable_encoder([provider.to_dict() for provider in ToolManageService.list_builtin_tools(
+ return jsonable_encoder([provider.to_dict() for provider in BuiltinToolManageService.list_builtin_tools(
user_id,
tenant_id,
)])
@@ -338,20 +477,43 @@ class ToolApiListApi(Resource):
user_id = current_user.id
tenant_id = current_user.current_tenant_id
- return jsonable_encoder([provider.to_dict() for provider in ToolManageService.list_api_tools(
+ return jsonable_encoder([provider.to_dict() for provider in ApiToolManageService.list_api_tools(
user_id,
tenant_id,
)])
+
+class ToolWorkflowListApi(Resource):
+ @setup_required
+ @login_required
+ @account_initialization_required
+ def get(self):
+ user_id = current_user.id
+ tenant_id = current_user.current_tenant_id
+ return jsonable_encoder([provider.to_dict() for provider in WorkflowToolManageService.list_tenant_workflow_tools(
+ user_id,
+ tenant_id,
+ )])
+
+class ToolLabelsApi(Resource):
+ @setup_required
+ @login_required
+ @account_initialization_required
+ def get(self):
+ return jsonable_encoder(ToolLabelsService.list_tool_labels())
+
+# tool provider
api.add_resource(ToolProviderListApi, '/workspaces/current/tool-providers')
+
+# builtin tool provider
api.add_resource(ToolBuiltinProviderListToolsApi, '/workspaces/current/tool-provider/builtin/<provider>/tools')
api.add_resource(ToolBuiltinProviderDeleteApi, '/workspaces/current/tool-provider/builtin/<provider>/delete')
api.add_resource(ToolBuiltinProviderUpdateApi, '/workspaces/current/tool-provider/builtin/<provider>/update')
api.add_resource(ToolBuiltinProviderGetCredentialsApi, '/workspaces/current/tool-provider/builtin/<provider>/credentials')
api.add_resource(ToolBuiltinProviderCredentialsSchemaApi, '/workspaces/current/tool-provider/builtin/<provider>/credentials_schema')
api.add_resource(ToolBuiltinProviderIconApi, '/workspaces/current/tool-provider/builtin/<provider>/icon')
-api.add_resource(ToolModelProviderIconApi, '/workspaces/current/tool-provider/model/<provider>/icon')
-api.add_resource(ToolModelProviderListToolsApi, '/workspaces/current/tool-provider/model/tools')
+
+# api tool provider
api.add_resource(ToolApiProviderAddApi, '/workspaces/current/tool-provider/api/add')
api.add_resource(ToolApiProviderGetRemoteSchemaApi, '/workspaces/current/tool-provider/api/remote')
api.add_resource(ToolApiProviderListToolsApi, '/workspaces/current/tool-provider/api/tools')
@@ -361,5 +523,15 @@ api.add_resource(ToolApiProviderGetApi, '/workspaces/current/tool-provider/api/g
api.add_resource(ToolApiProviderSchemaApi, '/workspaces/current/tool-provider/api/schema')
api.add_resource(ToolApiProviderPreviousTestApi, '/workspaces/current/tool-provider/api/test/pre')
+# workflow tool provider
+api.add_resource(ToolWorkflowProviderCreateApi, '/workspaces/current/tool-provider/workflow/create')
+api.add_resource(ToolWorkflowProviderUpdateApi, '/workspaces/current/tool-provider/workflow/update')
+api.add_resource(ToolWorkflowProviderDeleteApi, '/workspaces/current/tool-provider/workflow/delete')
+api.add_resource(ToolWorkflowProviderGetApi, '/workspaces/current/tool-provider/workflow/get')
+api.add_resource(ToolWorkflowProviderListToolApi, '/workspaces/current/tool-provider/workflow/tools')
+
api.add_resource(ToolBuiltinListApi, '/workspaces/current/tools/builtin')
-api.add_resource(ToolApiListApi, '/workspaces/current/tools/api')
\ No newline at end of file
+api.add_resource(ToolApiListApi, '/workspaces/current/tools/api')
+api.add_resource(ToolWorkflowListApi, '/workspaces/current/tools/workflow')
+
+api.add_resource(ToolLabelsApi, '/workspaces/current/tool-labels')
\ No newline at end of file
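
Create and update for workflow tools take the same body shape; `name` must pass the `alphanumeric` validator and `workflow_app_id` must be a UUID, per the parsers above. An illustrative body, where the `parameters` entry shape is an assumption inferred from the reqparse arguments:

    body = {
        "workflow_app_id": "00000000-0000-0000-0000-000000000000",  # placeholder UUID
        "name": "my_workflow_tool",      # alphanumeric; becomes the tool's name
        "label": "My workflow tool",
        "description": "Runs my workflow as an agent tool",
        "icon": {"content": "T", "background": "#FFEAD5"},
        "parameters": [{"name": "query", "description": "search terms", "form": "llm"}],
        "privacy_policy": "",
        "labels": ["productivity"],
    }
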
diff --git a/api/controllers/console/workspace/workspace.py b/api/controllers/console/workspace/workspace.py
index cd72872b62..7a11a45ae8 100644
--- a/api/controllers/console/workspace/workspace.py
+++ b/api/controllers/console/workspace/workspace.py
@@ -161,13 +161,13 @@ class CustomConfigWorkspaceApi(Resource):
parser.add_argument('replace_webapp_logo', type=str, location='json')
args = parser.parse_args()
+ tenant = db.session.query(Tenant).filter(Tenant.id == current_user.current_tenant_id).one_or_404()
+
custom_config_dict = {
'remove_webapp_brand': args['remove_webapp_brand'],
- 'replace_webapp_logo': args['replace_webapp_logo'],
+ 'replace_webapp_logo': args['replace_webapp_logo'] if args['replace_webapp_logo'] is not None else tenant.custom_config_dict.get('replace_webapp_logo'),
}
- tenant = db.session.query(Tenant).filter(Tenant.id == current_user.current_tenant_id).one_or_404()
-
tenant.custom_config_dict = custom_config_dict
db.session.commit()
diff --git a/api/controllers/service_api/app/message.py b/api/controllers/service_api/app/message.py
index 08f382b0a7..c8b44cfa38 100644
--- a/api/controllers/service_api/app/message.py
+++ b/api/controllers/service_api/app/message.py
@@ -97,7 +97,7 @@ class MessageListApi(Resource):
class MessageFeedbackApi(Resource):
- @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
+ @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser, message_id):
message_id = str(message_id)
@@ -114,7 +114,7 @@ class MessageFeedbackApi(Resource):
class MessageSuggestedApi(Resource):
- @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
+ @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY, required=True))
def get(self, app_model: App, end_user: EndUser, message_id):
message_id = str(message_id)
app_mode = AppMode.value_of(app_model.mode)
diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py
index 8ae81531ae..819512edf0 100644
--- a/api/controllers/service_api/wraps.py
+++ b/api/controllers/service_api/wraps.py
@@ -8,7 +8,7 @@ from flask import current_app, request
from flask_login import user_logged_in
from flask_restful import Resource
from pydantic import BaseModel
-from werkzeug.exceptions import Forbidden, NotFound, Unauthorized
+from werkzeug.exceptions import Forbidden, Unauthorized
from extensions.ext_database import db
from libs.login import _get_user
@@ -39,17 +39,17 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio
app_model = db.session.query(App).filter(App.id == api_token.app_id).first()
if not app_model:
- raise NotFound()
+ raise Forbidden("The app no longer exists.")
if app_model.status != 'normal':
- raise NotFound()
+ raise Forbidden("The app's status is abnormal.")
if not app_model.enable_api:
- raise NotFound()
+ raise Forbidden("The app's API service has been disabled.")
tenant = db.session.query(Tenant).filter(Tenant.id == app_model.tenant_id).first()
if tenant.status == TenantStatus.ARCHIVE:
- raise NotFound()
+ raise Forbidden("The workspace's status is archived.")
kwargs['app_model'] = app_model
diff --git a/api/controllers/web/__init__.py b/api/controllers/web/__init__.py
index b6d46d4081..aa19bdc034 100644
--- a/api/controllers/web/__init__.py
+++ b/api/controllers/web/__init__.py
@@ -6,4 +6,4 @@ bp = Blueprint('web', __name__, url_prefix='/api')
api = ExternalApi(bp)
-from . import app, audio, completion, conversation, file, message, passport, saved_message, site, workflow
+from . import app, audio, completion, conversation, feature, file, message, passport, saved_message, site, workflow
diff --git a/api/controllers/web/app.py b/api/controllers/web/app.py
index 2586f2e6ec..91d9015c33 100644
--- a/api/controllers/web/app.py
+++ b/api/controllers/web/app.py
@@ -1,14 +1,10 @@
-import json
-
from flask import current_app
from flask_restful import fields, marshal_with
from controllers.web import api
from controllers.web.error import AppUnavailableError
from controllers.web.wraps import WebApiResource
-from extensions.ext_database import db
-from models.model import App, AppMode, AppModelConfig
-from models.tools import ApiToolProvider
+from models.model import App, AppMode
from services.app_service import AppService
diff --git a/api/controllers/web/audio.py b/api/controllers/web/audio.py
index e0074c452f..ca6d774e9d 100644
--- a/api/controllers/web/audio.py
+++ b/api/controllers/web/audio.py
@@ -74,7 +74,7 @@ class TextApi(WebApiResource):
app_model=app_model,
text=request.form['text'],
end_user=end_user.external_user_id,
- voice=request.form.get('voice'),
+ voice=request.form['voice'] if request.form.get('voice') else app_model.app_model_config.text_to_speech_dict.get('voice'),
streaming=False
)
diff --git a/api/controllers/web/error.py b/api/controllers/web/error.py
index 390e3fe7d1..bc87f51051 100644
--- a/api/controllers/web/error.py
+++ b/api/controllers/web/error.py
@@ -115,3 +115,9 @@ class UnsupportedFileTypeError(BaseHTTPException):
error_code = 'unsupported_file_type'
description = "File type not allowed."
code = 415
+
+
+class WebSSOAuthRequiredError(BaseHTTPException):
+ error_code = 'web_sso_auth_required'
+ description = "Web SSO authentication required."
+ code = 401
diff --git a/api/controllers/web/feature.py b/api/controllers/web/feature.py
new file mode 100644
index 0000000000..65842d78c6
--- /dev/null
+++ b/api/controllers/web/feature.py
@@ -0,0 +1,12 @@
+from flask_restful import Resource
+
+from controllers.web import api
+from services.feature_service import FeatureService
+
+
+class SystemFeatureApi(Resource):
+ def get(self):
+ return FeatureService.get_system_features().dict()
+
+
+api.add_resource(SystemFeatureApi, '/system-features')
diff --git a/api/controllers/web/passport.py b/api/controllers/web/passport.py
index 92b28d8125..ccc8683a79 100644
--- a/api/controllers/web/passport.py
+++ b/api/controllers/web/passport.py
@@ -5,14 +5,21 @@ from flask_restful import Resource
from werkzeug.exceptions import NotFound, Unauthorized
from controllers.web import api
+from controllers.web.error import WebSSOAuthRequiredError
from extensions.ext_database import db
from libs.passport import PassportService
from models.model import App, EndUser, Site
+from services.feature_service import FeatureService
class PassportResource(Resource):
"""Base resource for passport."""
def get(self):
+
+ system_features = FeatureService.get_system_features()
+ if system_features.sso_enforced_for_web:
+ raise WebSSOAuthRequiredError()
+
app_code = request.headers.get('X-App-Code')
if app_code is None:
raise Unauthorized('X-App-Code header is missing.')
@@ -28,7 +35,7 @@ class PassportResource(Resource):
app_model = db.session.query(App).filter(App.id == site.app_id).first()
if not app_model or app_model.status != 'normal' or not app_model.enable_site:
raise NotFound()
-
+
end_user = EndUser(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
@@ -36,6 +43,7 @@ class PassportResource(Resource):
is_anonymous=True,
session_id=generate_session_id(),
)
+
db.session.add(end_user)
db.session.commit()
@@ -53,8 +61,10 @@ class PassportResource(Resource):
'access_token': tk,
}
+
api.add_resource(PassportResource, '/passport')
+
def generate_session_id():
"""
Generate a unique session ID.
diff --git a/api/controllers/web/site.py b/api/controllers/web/site.py
index 49b0a8bfc0..a084b56b08 100644
--- a/api/controllers/web/site.py
+++ b/api/controllers/web/site.py
@@ -31,6 +31,7 @@ class AppSiteApi(WebApiResource):
'description': fields.String,
'copyright': fields.String,
'privacy_policy': fields.String,
+ 'custom_disclaimer': fields.String,
'default_language': fields.String,
'prompt_public': fields.Boolean
}
diff --git a/api/controllers/web/wraps.py b/api/controllers/web/wraps.py
index bdaa476f34..f5ab49d7e1 100644
--- a/api/controllers/web/wraps.py
+++ b/api/controllers/web/wraps.py
@@ -2,11 +2,13 @@ from functools import wraps
from flask import request
from flask_restful import Resource
-from werkzeug.exceptions import NotFound, Unauthorized
+from werkzeug.exceptions import BadRequest, NotFound, Unauthorized
+from controllers.web.error import WebSSOAuthRequiredError
from extensions.ext_database import db
from libs.passport import PassportService
from models.model import App, EndUser, Site
+from services.feature_service import FeatureService
def validate_jwt_token(view=None):
@@ -21,34 +23,60 @@ def validate_jwt_token(view=None):
return decorator(view)
return decorator
+
def decode_jwt_token():
- auth_header = request.headers.get('Authorization')
- if auth_header is None:
- raise Unauthorized('Authorization header is missing.')
+ system_features = FeatureService.get_system_features()
- if ' ' not in auth_header:
- raise Unauthorized('Invalid Authorization header format. Expected \'Bearer \' format.')
-
- auth_scheme, tk = auth_header.split(None, 1)
- auth_scheme = auth_scheme.lower()
+ try:
+ auth_header = request.headers.get('Authorization')
+ if auth_header is None:
+ raise Unauthorized('Authorization header is missing.')
- if auth_scheme != 'bearer':
- raise Unauthorized('Invalid Authorization header format. Expected \'Bearer \' format.')
- decoded = PassportService().verify(tk)
- app_code = decoded.get('app_code')
- app_model = db.session.query(App).filter(App.id == decoded['app_id']).first()
- site = db.session.query(Site).filter(Site.code == app_code).first()
- if not app_model:
- raise NotFound()
- if not app_code or not site:
- raise Unauthorized('Site URL is no longer valid.')
- if app_model.enable_site is False:
- raise Unauthorized('Site is disabled.')
- end_user = db.session.query(EndUser).filter(EndUser.id == decoded['end_user_id']).first()
- if not end_user:
- raise NotFound()
+ if ' ' not in auth_header:
+ raise Unauthorized('Invalid Authorization header format. Expected \'Bearer \' format.')
+
+ auth_scheme, tk = auth_header.split(None, 1)
+ auth_scheme = auth_scheme.lower()
+
+ if auth_scheme != 'bearer':
+ raise Unauthorized('Invalid Authorization header format. Expected \'Bearer \' format.')
+ decoded = PassportService().verify(tk)
+ app_code = decoded.get('app_code')
+ app_model = db.session.query(App).filter(App.id == decoded['app_id']).first()
+ site = db.session.query(Site).filter(Site.code == app_code).first()
+ if not app_model:
+ raise NotFound()
+ if not app_code or not site:
+ raise BadRequest('Site URL is no longer valid.')
+ if app_model.enable_site is False:
+ raise BadRequest('Site is disabled.')
+ end_user = db.session.query(EndUser).filter(EndUser.id == decoded['end_user_id']).first()
+ if not end_user:
+ raise NotFound()
+
+ _validate_web_sso_token(decoded, system_features)
+
+ return app_model, end_user
+ except Unauthorized as e:
+ if system_features.sso_enforced_for_web:
+ raise WebSSOAuthRequiredError()
+
+ raise Unauthorized(e.description)
+
+
+def _validate_web_sso_token(decoded, system_features):
+ # Check if SSO is enforced for web, and if the token source is not SSO, raise an error and redirect to SSO login
+ if system_features.sso_enforced_for_web:
+ source = decoded.get('token_source')
+ if not source or source != 'sso':
+ raise WebSSOAuthRequiredError()
+
+ # Check if SSO is not enforced for web, and if the token source is SSO, raise an error and redirect to normal passport login
+ if not system_features.sso_enforced_for_web:
+ source = decoded.get('token_source')
+ if source and source == 'sso':
+ raise Unauthorized('SSO token expired.')
- return app_model, end_user
class WebApiResource(Resource):
method_decorators = [validate_jwt_token]
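
The decisive claim is `token_source`. A compact restatement of the rules that `_validate_web_sso_token` and the surrounding `Unauthorized` handler enforce, as a sketch:

    def token_is_acceptable(sso_enforced_for_web: bool, token_source: str | None) -> bool:
        # enforced + sso token        -> pass
        # enforced + non-sso token    -> WebSSOAuthRequiredError (client redirects to SSO)
        # not enforced + sso token    -> Unauthorized (stale SSO token)
        # not enforced + normal token -> pass
        if sso_enforced_for_web:
            return token_source == 'sso'
        return token_source != 'sso'
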
diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py
index 485633cab1..d228a3ac29 100644
--- a/api/core/agent/base_agent_runner.py
+++ b/api/core/agent/base_agent_runner.py
@@ -39,6 +39,7 @@ from core.tools.entities.tool_entities import (
from core.tools.tool.dataset_retriever_tool import DatasetRetrieverTool
from core.tools.tool.tool import Tool
from core.tools.tool_manager import ToolManager
+from core.tools.utils.tool_parameter_converter import ToolParameterConverter
from extensions.ext_database import db
from models.model import Conversation, Message, MessageAgentThought
from models.tools import ToolConversationVariables
@@ -128,6 +129,8 @@ class BaseAgentRunner(AppRunner):
self.files = application_generate_entity.files
else:
self.files = []
+ self.query = None
+ self._current_thoughts: list[PromptMessage] = []
def _repack_app_generate_entity(self, app_generate_entity: AgentChatAppGenerateEntity) \
-> AgentChatAppGenerateEntity:
@@ -165,6 +168,7 @@ class BaseAgentRunner(AppRunner):
tenant_id=self.tenant_id,
app_id=self.app_config.app_id,
agent_tool=tool,
+ invoke_from=self.application_generate_entity.invoke_from
)
tool_entity.load_variables(self.variables_pool)
@@ -183,21 +187,11 @@ class BaseAgentRunner(AppRunner):
if parameter.form != ToolParameter.ToolParameterForm.LLM:
continue
- parameter_type = 'string'
+ parameter_type = ToolParameterConverter.get_parameter_type(parameter.type)
enum = []
- if parameter.type == ToolParameter.ToolParameterType.STRING:
- parameter_type = 'string'
- elif parameter.type == ToolParameter.ToolParameterType.BOOLEAN:
- parameter_type = 'boolean'
- elif parameter.type == ToolParameter.ToolParameterType.NUMBER:
- parameter_type = 'number'
- elif parameter.type == ToolParameter.ToolParameterType.SELECT:
- for option in parameter.options:
- enum.append(option.value)
- parameter_type = 'string'
- else:
- raise ValueError(f"parameter type {parameter.type} is not supported")
-
+ if parameter.type == ToolParameter.ToolParameterType.SELECT:
+ enum = [option.value for option in parameter.options]
+
message_tool.parameters['properties'][parameter.name] = {
"type": parameter_type,
"description": parameter.llm_description or '',
@@ -278,20 +272,10 @@ class BaseAgentRunner(AppRunner):
if parameter.form != ToolParameter.ToolParameterForm.LLM:
continue
- parameter_type = 'string'
+ parameter_type = ToolParameterConverter.get_parameter_type(parameter.type)
enum = []
- if parameter.type == ToolParameter.ToolParameterType.STRING:
- parameter_type = 'string'
- elif parameter.type == ToolParameter.ToolParameterType.BOOLEAN:
- parameter_type = 'boolean'
- elif parameter.type == ToolParameter.ToolParameterType.NUMBER:
- parameter_type = 'number'
- elif parameter.type == ToolParameter.ToolParameterType.SELECT:
- for option in parameter.options:
- enum.append(option.value)
- parameter_type = 'string'
- else:
- raise ValueError(f"parameter type {parameter.type} is not supported")
+ if parameter.type == ToolParameter.ToolParameterType.SELECT:
+ enum = [option.value for option in parameter.options]
prompt_tool.parameters['properties'][parameter.name] = {
"type": parameter_type,
@@ -463,7 +447,7 @@ class BaseAgentRunner(AppRunner):
for message in messages:
if message.id == self.message.id:
continue
-
+
result.append(self.organize_agent_user_prompt(message))
agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
if agent_thoughts:
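
`ToolParameterConverter` itself is not shown in this diff. Given the branches it replaces, `get_parameter_type` presumably reduces to a lookup along these lines; this is an assumed sketch, not the actual implementation:

    class ToolParameterConverter:
        @staticmethod
        def get_parameter_type(parameter_type) -> str:
            # SELECT still surfaces as a JSON-schema 'string'; the caller builds
            # the enum from parameter.options separately, as the diff shows.
            mapping = {
                'string': 'string',
                'select': 'string',
                'boolean': 'boolean',
                'number': 'number',
            }
            key = getattr(parameter_type, 'value', parameter_type)
            if key not in mapping:
                raise ValueError(f"parameter type {parameter_type} is not supported")
            return mapping[key]
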
diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py
index 12554f42b3..40cfb20d0b 100644
--- a/api/core/agent/cot_agent_runner.py
+++ b/api/core/agent/cot_agent_runner.py
@@ -15,6 +15,7 @@ from core.model_runtime.entities.message_entities import (
ToolPromptMessage,
UserPromptMessage,
)
+from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool.tool import Tool
from core.tools.tool_engine import ToolEngine
@@ -121,7 +122,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
raise ValueError("failed to invoke llm")
usage_dict = {}
- react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks)
+ react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)
scratchpad = AgentScratchpadUnit(
agent_response='',
thought='',
@@ -189,7 +190,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
if not scratchpad.action:
# failed to extract action, return final answer directly
- final_answer = scratchpad.agent_response or ''
+ final_answer = ''
else:
if scratchpad.action.action_name.lower() == "final answer":
# action is final answer, return final answer directly
@@ -373,7 +374,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
return message
- def _organize_historic_prompt_messages(self) -> list[PromptMessage]:
+ def _organize_historic_prompt_messages(self, current_session_messages: list[PromptMessage] = None) -> list[PromptMessage]:
"""
organize historic prompt messages
"""
@@ -381,6 +382,13 @@ class CotAgentRunner(BaseAgentRunner, ABC):
scratchpad: list[AgentScratchpadUnit] = []
current_scratchpad: AgentScratchpadUnit = None
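+ # AgentHistoryPromptTransform is expected to trim history so it fits the context window alongside the current session messages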
+ self.history_prompt_messages = AgentHistoryPromptTransform(
+ model_config=self.model_config,
+ prompt_messages=current_session_messages or [],
+ history_messages=self.history_prompt_messages,
+ memory=self.memory
+ ).get_prompt()
+
for message in self.history_prompt_messages:
if isinstance(message, AssistantPromptMessage):
current_scratchpad = AgentScratchpadUnit(
diff --git a/api/core/agent/cot_chat_agent_runner.py b/api/core/agent/cot_chat_agent_runner.py
index a904f3e641..e8b05373ab 100644
--- a/api/core/agent/cot_chat_agent_runner.py
+++ b/api/core/agent/cot_chat_agent_runner.py
@@ -32,9 +32,6 @@ class CotChatAgentRunner(CotAgentRunner):
# organize system prompt
system_message = self._organize_system_prompt()
- # organize historic prompt messages
- historic_messages = self._historic_prompt_messages
-
# organize current assistant messages
agent_scratchpad = self._agent_scratchpad
if not agent_scratchpad:
@@ -57,6 +54,13 @@ class CotChatAgentRunner(CotAgentRunner):
query_messages = UserPromptMessage(content=self._query)
if assistant_messages:
+ # organize historic prompt messages
+ historic_messages = self._organize_historic_prompt_messages([
+ system_message,
+ query_messages,
+ *assistant_messages,
+ UserPromptMessage(content='continue')
+ ])
messages = [
system_message,
*historic_messages,
@@ -65,6 +69,8 @@ class CotChatAgentRunner(CotAgentRunner):
UserPromptMessage(content='continue')
]
else:
+ # organize historic prompt messages
+ historic_messages = self._organize_historic_prompt_messages([system_message, query_messages])
messages = [system_message, *historic_messages, query_messages]
# join all messages
diff --git a/api/core/agent/cot_completion_agent_runner.py b/api/core/agent/cot_completion_agent_runner.py
index 3f0298d5a3..9e6eb54f4f 100644
--- a/api/core/agent/cot_completion_agent_runner.py
+++ b/api/core/agent/cot_completion_agent_runner.py
@@ -19,11 +19,11 @@ class CotCompletionAgentRunner(CotAgentRunner):
return system_prompt
- def _organize_historic_prompt(self) -> str:
+ def _organize_historic_prompt(self, current_session_messages: list[PromptMessage] = None) -> str:
"""
Organize historic prompt
"""
- historic_prompt_messages = self._historic_prompt_messages
+ historic_prompt_messages = self._organize_historic_prompt_messages(current_session_messages)
historic_prompt = ""
for message in historic_prompt_messages:
diff --git a/api/core/agent/entities.py b/api/core/agent/entities.py
index 5284faa02e..5274224de5 100644
--- a/api/core/agent/entities.py
+++ b/api/core/agent/entities.py
@@ -8,7 +8,7 @@ class AgentToolEntity(BaseModel):
"""
Agent Tool Entity.
"""
- provider_type: Literal["builtin", "api"]
+ provider_type: Literal["builtin", "api", "workflow"]
provider_id: str
tool_name: str
tool_parameters: dict[str, Any] = {}
diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py
index a9b3a80073..d416a319a4 100644
--- a/api/core/agent/fc_agent_runner.py
+++ b/api/core/agent/fc_agent_runner.py
@@ -17,6 +17,7 @@ from core.model_runtime.entities.message_entities import (
ToolPromptMessage,
UserPromptMessage,
)
+from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool_engine import ToolEngine
from models.model import Message
@@ -24,21 +25,18 @@ from models.model import Message
logger = logging.getLogger(__name__)
class FunctionCallAgentRunner(BaseAgentRunner):
+
def run(self,
message: Message, query: str, **kwargs: Any
) -> Generator[LLMResultChunk, None, None]:
"""
Run FunctionCall agent application
"""
+ self.query = query
app_generate_entity = self.application_generate_entity
app_config = self.app_config
- prompt_template = app_config.prompt_template.simple_prompt_template or ''
- prompt_messages = self.history_prompt_messages
- prompt_messages = self._init_system_message(prompt_template, prompt_messages)
- prompt_messages = self._organize_user_query(query, prompt_messages)
-
# convert tools into ModelRuntime Tool format
tool_instances, prompt_messages_tools = self._init_prompt_tools()
@@ -81,6 +79,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
)
# recalc llm max tokens
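+ # rebuild the prompt each iteration so history, query and accumulated thoughts stay consistent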
+ prompt_messages = self._organize_prompt_messages()
self.recalc_llm_max_tokens(self.model_config, prompt_messages)
# invoke model
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = model_instance.invoke_llm(
@@ -203,7 +202,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
else:
assistant_message.content = response
- prompt_messages.append(assistant_message)
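+ # record the assistant message in _current_thoughts; it is replayed into the prompt on the next iteration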
+ self._current_thoughts.append(assistant_message)
# save thought
self.save_agent_thought(
@@ -265,12 +264,14 @@ class FunctionCallAgentRunner(BaseAgentRunner):
}
tool_responses.append(tool_response)
- prompt_messages = self._organize_assistant_message(
- tool_call_id=tool_call_id,
- tool_call_name=tool_call_name,
- tool_response=tool_response['tool_response'],
- prompt_messages=prompt_messages,
- )
+ if tool_response['tool_response'] is not None:
+ self._current_thoughts.append(
+ ToolPromptMessage(
+ content=tool_response['tool_response'],
+ tool_call_id=tool_call_id,
+ name=tool_call_name,
+ )
+ )
if len(tool_responses) > 0:
# save agent thought
@@ -300,8 +301,6 @@ class FunctionCallAgentRunner(BaseAgentRunner):
iteration_step += 1
- prompt_messages = self._clear_user_prompt_image_messages(prompt_messages)
-
self.update_db_variables(self.variables_pool, self.db_variables_pool)
# publish end event
self.queue_manager.publish(QueueMessageEndEvent(llm_result=LLMResult(
@@ -393,24 +392,6 @@ class FunctionCallAgentRunner(BaseAgentRunner):
return prompt_messages
- def _organize_assistant_message(self, tool_call_id: str = None, tool_call_name: str = None, tool_response: str = None,
- prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
- """
- Organize assistant message
- """
- prompt_messages = deepcopy(prompt_messages)
-
- if tool_response is not None:
- prompt_messages.append(
- ToolPromptMessage(
- content=tool_response,
- tool_call_id=tool_call_id,
- name=tool_call_name,
- )
- )
-
- return prompt_messages
-
def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
"""
For now, GPT supports both function calling and vision only at the first iteration.
@@ -428,4 +409,26 @@ class FunctionCallAgentRunner(BaseAgentRunner):
for content in prompt_message.content
])
- return prompt_messages
\ No newline at end of file
+ return prompt_messages
+
+ def _organize_prompt_messages(self):
+ prompt_template = self.app_config.prompt_template.simple_prompt_template or ''
+ self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages)
+ query_prompt_messages = self._organize_user_query(self.query, [])
+
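+ # history is trimmed to leave room for the current query and this turn's thoughts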
+ self.history_prompt_messages = AgentHistoryPromptTransform(
+ model_config=self.model_config,
+ prompt_messages=[*query_prompt_messages, *self._current_thoughts],
+ history_messages=self.history_prompt_messages,
+ memory=self.memory
+ ).get_prompt()
+
+ prompt_messages = [
+ *self.history_prompt_messages,
+ *query_prompt_messages,
+ *self._current_thoughts
+ ]
+ if len(self._current_thoughts) != 0:
+ # strip image contents from user prompts after the first iteration
+ prompt_messages = self._clear_user_prompt_image_messages(prompt_messages)
+ return prompt_messages
diff --git a/api/core/agent/output_parser/cot_output_parser.py b/api/core/agent/output_parser/cot_output_parser.py
index 91ac41143b..5a445e9e59 100644
--- a/api/core/agent/output_parser/cot_output_parser.py
+++ b/api/core/agent/output_parser/cot_output_parser.py
@@ -9,7 +9,7 @@ from core.model_runtime.entities.llm_entities import LLMResultChunk
class CotAgentOutputParser:
@classmethod
- def handle_react_stream_output(cls, llm_response: Generator[LLMResultChunk, None, None]) -> \
+ def handle_react_stream_output(cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict) -> \
Generator[Union[str, AgentScratchpadUnit.Action], None, None]:
def parse_action(json_str):
try:
@@ -58,6 +58,8 @@ class CotAgentOutputParser:
thought_idx = 0
for response in llm_response:
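+ # token usage arrives on streaming deltas; stash it in the caller-supplied usage_dict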
+ if response.delta.usage:
+ usage_dict['usage'] = response.delta.usage
response = response.delta.message.content
if not isinstance(response, str):
continue
diff --git a/api/core/tools/prompt/template.py b/api/core/agent/prompt/template.py
similarity index 100%
rename from api/core/tools/prompt/template.py
rename to api/core/agent/prompt/template.py
diff --git a/api/core/app/app_config/common/sensitive_word_avoidance/manager.py b/api/core/app/app_config/common/sensitive_word_avoidance/manager.py
index 66d4a3275b..1ca8b1e3b8 100644
--- a/api/core/app/app_config/common/sensitive_word_avoidance/manager.py
+++ b/api/core/app/app_config/common/sensitive_word_avoidance/manager.py
@@ -11,7 +11,7 @@ class SensitiveWordAvoidanceConfigManager:
if not sensitive_word_avoidance_dict:
return None
- if 'enabled' in sensitive_word_avoidance_dict and sensitive_word_avoidance_dict['enabled']:
+ if sensitive_word_avoidance_dict.get('enabled'):
return SensitiveWordAvoidanceEntity(
type=sensitive_word_avoidance_dict.get('type'),
config=sensitive_word_avoidance_dict.get('config'),
diff --git a/api/core/app/app_config/easy_ui_based_app/agent/manager.py b/api/core/app/app_config/easy_ui_based_app/agent/manager.py
index a48316728b..f271aeed0c 100644
--- a/api/core/app/app_config/easy_ui_based_app/agent/manager.py
+++ b/api/core/app/app_config/easy_ui_based_app/agent/manager.py
@@ -1,7 +1,7 @@
from typing import Optional
from core.agent.entities import AgentEntity, AgentPromptEntity, AgentToolEntity
-from core.tools.prompt.template import REACT_PROMPT_TEMPLATES
+from core.agent.prompt.template import REACT_PROMPT_TEMPLATES
class AgentConfigManager:
diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py
index 101e25d582..d6b6d89416 100644
--- a/api/core/app/app_config/entities.py
+++ b/api/core/app/app_config/entities.py
@@ -239,4 +239,4 @@ class WorkflowUIBasedAppConfig(AppConfig):
"""
Workflow UI Based App Config Entity.
"""
- workflow_id: str
+ workflow_id: str
\ No newline at end of file
diff --git a/api/core/app/app_config/features/file_upload/manager.py b/api/core/app/app_config/features/file_upload/manager.py
index acffa6e9e7..2049b573cd 100644
--- a/api/core/app/app_config/features/file_upload/manager.py
+++ b/api/core/app/app_config/features/file_upload/manager.py
@@ -14,7 +14,7 @@ class FileUploadConfigManager:
"""
file_upload_dict = config.get('file_upload')
if file_upload_dict:
- if 'image' in file_upload_dict and file_upload_dict['image']:
+ if file_upload_dict.get('image'):
if 'enabled' in file_upload_dict['image'] and file_upload_dict['image']['enabled']:
image_config = {
'number_limits': file_upload_dict['image']['number_limits'],
diff --git a/api/core/app/app_config/features/more_like_this/manager.py b/api/core/app/app_config/features/more_like_this/manager.py
index ec2a9a6796..2ba99a5c40 100644
--- a/api/core/app/app_config/features/more_like_this/manager.py
+++ b/api/core/app/app_config/features/more_like_this/manager.py
@@ -9,7 +9,7 @@ class MoreLikeThisConfigManager:
more_like_this = False
more_like_this_dict = config.get('more_like_this')
if more_like_this_dict:
- if 'enabled' in more_like_this_dict and more_like_this_dict['enabled']:
+ if more_like_this_dict.get('enabled'):
more_like_this = True
return more_like_this
diff --git a/api/core/app/app_config/features/retrieval_resource/manager.py b/api/core/app/app_config/features/retrieval_resource/manager.py
index 0694cb954e..fca58e12e8 100644
--- a/api/core/app/app_config/features/retrieval_resource/manager.py
+++ b/api/core/app/app_config/features/retrieval_resource/manager.py
@@ -4,7 +4,7 @@ class RetrievalResourceConfigManager:
show_retrieve_source = False
retriever_resource_dict = config.get('retriever_resource')
if retriever_resource_dict:
- if 'enabled' in retriever_resource_dict and retriever_resource_dict['enabled']:
+ if retriever_resource_dict.get('enabled'):
show_retrieve_source = True
return show_retrieve_source
diff --git a/api/core/app/app_config/features/speech_to_text/manager.py b/api/core/app/app_config/features/speech_to_text/manager.py
index b98699bfff..88b4be25d3 100644
--- a/api/core/app/app_config/features/speech_to_text/manager.py
+++ b/api/core/app/app_config/features/speech_to_text/manager.py
@@ -9,7 +9,7 @@ class SpeechToTextConfigManager:
speech_to_text = False
speech_to_text_dict = config.get('speech_to_text')
if speech_to_text_dict:
- if 'enabled' in speech_to_text_dict and speech_to_text_dict['enabled']:
+ if speech_to_text_dict.get('enabled'):
speech_to_text = True
return speech_to_text
diff --git a/api/core/app/app_config/features/suggested_questions_after_answer/manager.py b/api/core/app/app_config/features/suggested_questions_after_answer/manager.py
index 5aacd3b32d..c6cab01220 100644
--- a/api/core/app/app_config/features/suggested_questions_after_answer/manager.py
+++ b/api/core/app/app_config/features/suggested_questions_after_answer/manager.py
@@ -9,7 +9,7 @@ class SuggestedQuestionsAfterAnswerConfigManager:
suggested_questions_after_answer = False
suggested_questions_after_answer_dict = config.get('suggested_questions_after_answer')
if suggested_questions_after_answer_dict:
- if 'enabled' in suggested_questions_after_answer_dict and suggested_questions_after_answer_dict['enabled']:
+ if suggested_questions_after_answer_dict.get('enabled'):
suggested_questions_after_answer = True
return suggested_questions_after_answer
diff --git a/api/core/app/app_config/features/text_to_speech/manager.py b/api/core/app/app_config/features/text_to_speech/manager.py
index 1ff31034ad..b516fa46ab 100644
--- a/api/core/app/app_config/features/text_to_speech/manager.py
+++ b/api/core/app/app_config/features/text_to_speech/manager.py
@@ -12,7 +12,7 @@ class TextToSpeechConfigManager:
text_to_speech = False
text_to_speech_dict = config.get('text_to_speech')
if text_to_speech_dict:
- if 'enabled' in text_to_speech_dict and text_to_speech_dict['enabled']:
+ if text_to_speech_dict.get('enabled'):
text_to_speech = TextToSpeechEntity(
enabled=text_to_speech_dict.get('enabled'),
voice=text_to_speech_dict.get('voice'),
diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py
index e5cf585f82..3b1ee3578d 100644
--- a/api/core/app/apps/advanced_chat/app_generator.py
+++ b/api/core/app/apps/advanced_chat/app_generator.py
@@ -66,7 +66,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user)
# parse files
- files = args['files'] if 'files' in args and args['files'] else []
+ files = args['files'] if args.get('files') else []
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
if file_extra_config:
@@ -98,6 +98,90 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
extras=extras
)
+ return self._generate(
+ app_model=app_model,
+ workflow=workflow,
+ user=user,
+ invoke_from=invoke_from,
+ application_generate_entity=application_generate_entity,
+ conversation=conversation,
+ stream=stream
+ )
+
+ def single_iteration_generate(self, app_model: App,
+ workflow: Workflow,
+ node_id: str,
+ user: Account,
+ args: dict,
+ stream: bool = True) \
+ -> Union[dict, Generator[dict, None, None]]:
+ """
+ Generate App response for a single iteration run.
+
+ :param app_model: App
+ :param workflow: Workflow
+ :param node_id: id of the iteration node to run
+ :param user: account or end user
+ :param args: request args
+ :param stream: is stream
+ """
+ if not node_id:
+ raise ValueError('node_id is required')
+
+ if args.get('inputs') is None:
+ raise ValueError('inputs is required')
+
+ extras = {
+ "auto_generate_conversation_name": False
+ }
+
+ # get conversation
+ conversation = None
+ if args.get('conversation_id'):
+ conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user)
+
+ # convert to app config
+ app_config = AdvancedChatAppConfigManager.get_app_config(
+ app_model=app_model,
+ workflow=workflow
+ )
+
+ # init application generate entity
+ application_generate_entity = AdvancedChatAppGenerateEntity(
+ task_id=str(uuid.uuid4()),
+ app_config=app_config,
+ conversation_id=conversation.id if conversation else None,
+ inputs={},
+ query='',
+ files=[],
+ user_id=user.id,
+ stream=stream,
+ invoke_from=InvokeFrom.DEBUGGER,
+ extras=extras,
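+ # inputs and query stay empty: only the target node runs, fed by args['inputs'] below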
+ single_iteration_run=AdvancedChatAppGenerateEntity.SingleIterationRunEntity(
+ node_id=node_id,
+ inputs=args['inputs']
+ )
+ )
+
+ return self._generate(
+ app_model=app_model,
+ workflow=workflow,
+ user=user,
+ invoke_from=InvokeFrom.DEBUGGER,
+ application_generate_entity=application_generate_entity,
+ conversation=conversation,
+ stream=stream
+ )
+
+ def _generate(self, app_model: App,
+ workflow: Workflow,
+ user: Union[Account, EndUser],
+ invoke_from: InvokeFrom,
+ application_generate_entity: AdvancedChatAppGenerateEntity,
+ conversation: Conversation = None,
+ stream: bool = True) \
+ -> Union[dict, Generator[dict, None, None]]:
is_first_conversation = False
if not conversation:
is_first_conversation = True
@@ -167,18 +251,30 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
"""
with flask_app.app_context():
try:
- # get conversation and message
- conversation = self._get_conversation(conversation_id)
- message = self._get_message(message_id)
-
- # chatbot app
runner = AdvancedChatAppRunner()
- runner.run(
- application_generate_entity=application_generate_entity,
- queue_manager=queue_manager,
- conversation=conversation,
- message=message
- )
+ if application_generate_entity.single_iteration_run:
+ single_iteration_run = application_generate_entity.single_iteration_run
+ runner.single_iteration_run(
+ app_id=application_generate_entity.app_config.app_id,
+ workflow_id=application_generate_entity.app_config.workflow_id,
+ queue_manager=queue_manager,
+ inputs=single_iteration_run.inputs,
+ node_id=single_iteration_run.node_id,
+ user_id=application_generate_entity.user_id
+ )
+ else:
+ # get conversation and message
+ conversation = self._get_conversation(conversation_id)
+ message = self._get_message(message_id)
+
+ # chatbot app
+ runner = AdvancedChatAppRunner()
+ runner.run(
+ application_generate_entity=application_generate_entity,
+ queue_manager=queue_manager,
+ conversation=conversation,
+ message=message
+ )
except GenerateTaskStoppedException:
pass
except InvokeAuthorizationError:
diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py
index d858dcac12..de3632894d 100644
--- a/api/core/app/apps/advanced_chat/app_runner.py
+++ b/api/core/app/apps/advanced_chat/app_runner.py
@@ -102,6 +102,7 @@ class AdvancedChatAppRunner(AppRunner):
user_from=UserFrom.ACCOUNT
if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
else UserFrom.END_USER,
+ invoke_from=application_generate_entity.invoke_from,
user_inputs=inputs,
system_inputs={
SystemVariable.QUERY: query,
@@ -109,6 +110,35 @@ class AdvancedChatAppRunner(AppRunner):
SystemVariable.CONVERSATION_ID: conversation.id,
SystemVariable.USER_ID: user_id
},
+ callbacks=workflow_callbacks,
+ call_depth=application_generate_entity.call_depth
+ )
+
+ def single_iteration_run(self, app_id: str, workflow_id: str,
+ queue_manager: AppQueueManager,
+ inputs: dict, node_id: str, user_id: str) -> None:
+ """
+ Single iteration run
+ """
+ app_record: App = db.session.query(App).filter(App.id == app_id).first()
+ if not app_record:
+ raise ValueError("App not found")
+
+ workflow = self.get_workflow(app_model=app_record, workflow_id=workflow_id)
+ if not workflow:
+ raise ValueError("Workflow not initialized")
+
+ workflow_callbacks = [WorkflowEventTriggerCallback(
+ queue_manager=queue_manager,
+ workflow=workflow
+ )]
+
+ workflow_engine_manager = WorkflowEngineManager()
+ workflow_engine_manager.single_step_run_iteration_workflow_node(
+ workflow=workflow,
+ node_id=node_id,
+ user_id=user_id,
+ user_inputs=inputs,
callbacks=workflow_callbacks
)
diff --git a/api/core/app/apps/advanced_chat/generate_response_converter.py b/api/core/app/apps/advanced_chat/generate_response_converter.py
index 80e8e22e88..08069332ba 100644
--- a/api/core/app/apps/advanced_chat/generate_response_converter.py
+++ b/api/core/app/apps/advanced_chat/generate_response_converter.py
@@ -8,6 +8,8 @@ from core.app.entities.task_entities import (
ChatbotAppStreamResponse,
ErrorStreamResponse,
MessageEndStreamResponse,
+ NodeFinishStreamResponse,
+ NodeStartStreamResponse,
PingStreamResponse,
)
@@ -111,6 +113,8 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
if isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
+ elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
+ response_chunk.update(sub_stream_response.to_ignore_detail_dict())
else:
response_chunk.update(sub_stream_response.to_dict())
diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py
index 67d8fe5cb1..7c70afc2ae 100644
--- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py
+++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py
@@ -12,6 +12,9 @@ from core.app.entities.queue_entities import (
QueueAdvancedChatMessageEndEvent,
QueueAnnotationReplyEvent,
QueueErrorEvent,
+ QueueIterationCompletedEvent,
+ QueueIterationNextEvent,
+ QueueIterationStartEvent,
QueueMessageReplaceEvent,
QueueNodeFailedEvent,
QueueNodeStartedEvent,
@@ -64,6 +67,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
_workflow: Workflow
_user: Union[Account, EndUser]
_workflow_system_variables: dict[SystemVariable, Any]
+ _iteration_nested_relations: dict[str, list[str]]
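+ # maps an iteration/loop node id to the ids of the nodes nested inside it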
def __init__(self, application_generate_entity: AdvancedChatAppGenerateEntity,
workflow: Workflow,
@@ -103,6 +107,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
usage=LLMUsage.empty_usage()
)
+ self._iteration_nested_relations = self._get_iteration_nested_relations(self._workflow.graph_dict)
self._stream_generate_routes = self._get_stream_generate_routes()
self._conversation_name_generate_thread = None
@@ -204,6 +209,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
# search stream_generate_routes if node id is answer start at node
if not self._task_state.current_stream_generate_state and event.node_id in self._stream_generate_routes:
self._task_state.current_stream_generate_state = self._stream_generate_routes[event.node_id]
+ # reset current route position to 0
+ self._task_state.current_stream_generate_state.current_route_position = 0
# generate stream outputs when node started
yield from self._generate_stream_outputs_when_node_started()
@@ -225,6 +232,22 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution
)
+
+ if isinstance(event, QueueNodeFailedEvent):
+ yield from self._handle_iteration_exception(
+ task_id=self._application_generate_entity.task_id,
+ error=f'Child node failed: {event.error}'
+ )
+ elif isinstance(event, QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent):
+ if isinstance(event, QueueIterationNextEvent):
+ # clear ran node execution infos of current iteration
+ iteration_relations = self._iteration_nested_relations.get(event.node_id)
+ if iteration_relations:
+ for node_id in iteration_relations:
+ self._task_state.ran_node_execution_infos.pop(node_id, None)
+
+ yield self._handle_iteration_to_stream_response(self._application_generate_entity.task_id, event)
+ self._handle_iteration_operation(event)
elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent):
workflow_run = self._handle_workflow_finished(event)
if workflow_run:
@@ -263,10 +286,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
self._handle_retriever_resources(event)
elif isinstance(event, QueueAnnotationReplyEvent):
self._handle_annotation_reply(event)
- # elif isinstance(event, QueueMessageFileEvent):
- # response = self._message_file_to_stream_response(event)
- # if response:
- # yield response
elif isinstance(event, QueueTextChunkEvent):
delta_text = event.text
if delta_text is None:
@@ -342,7 +361,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
id=self._message.id,
**extras
)
-
+
def _get_stream_generate_routes(self) -> dict[str, ChatflowStreamGenerateRoute]:
"""
Get stream generate routes.
@@ -372,7 +391,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
)
return stream_generate_routes
-
+
def _get_answer_start_at_node_ids(self, graph: dict, target_node_id: str) \
-> list[str]:
"""
@@ -391,6 +410,18 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
ingoing_edges.append(edge)
if not ingoing_edges:
+ # check if it's the first node in the iteration
+ target_node = next((node for node in nodes if node.get('id') == target_node_id), None)
+ if not target_node:
+ return []
+
+ node_iteration_id = target_node.get('data', {}).get('iteration_id')
+ # get iteration start node id
+ for node in nodes:
+ if node.get('id') == node_iteration_id:
+ if node.get('data', {}).get('start_node_id') == target_node_id:
+ return [target_node_id]
+
return []
start_node_ids = []
@@ -401,14 +432,23 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
continue
node_type = source_node.get('data', {}).get('type')
+ node_iteration_id = source_node.get('data', {}).get('iteration_id')
+ iteration_start_node_id = None
+ if node_iteration_id:
+ iteration_node = next((node for node in nodes if node.get('id') == node_iteration_id), None)
+ iteration_start_node_id = iteration_node.get('data', {}).get('start_node_id')
+
if node_type in [
NodeType.ANSWER.value,
NodeType.IF_ELSE.value,
- NodeType.QUESTION_CLASSIFIER.value
+ NodeType.QUESTION_CLASSIFIER.value,
+ NodeType.ITERATION.value,
+ NodeType.LOOP.value
]:
start_node_id = target_node_id
start_node_ids.append(start_node_id)
- elif node_type == NodeType.START.value:
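+ # the first node inside an iteration is treated like a START node for stream routing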
+ elif node_type == NodeType.START.value or \
+ (node_iteration_id is not None and iteration_start_node_id == source_node.get('id')):
start_node_id = source_node_id
start_node_ids.append(start_node_id)
else:
@@ -417,7 +457,27 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
start_node_ids.extend(sub_start_node_ids)
return start_node_ids
+
+ def _get_iteration_nested_relations(self, graph: dict) -> dict[str, list[str]]:
+ """
+ Get iteration nested relations.
+ :param graph: graph
+ :return:
+ """
+ nodes = graph.get('nodes')
+ iteration_ids = [node.get('id') for node in nodes
+ if node.get('data', {}).get('type') in [
+ NodeType.ITERATION.value,
+ NodeType.LOOP.value,
+ ]]
+
+ return {
+ iteration_id: [
+ node.get('id') for node in nodes if node.get('data', {}).get('iteration_id') == iteration_id
+ ] for iteration_id in iteration_ids
+ }
+
def _generate_stream_outputs_when_node_started(self) -> Generator:
"""
Generate stream outputs.
@@ -425,7 +485,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
"""
if self._task_state.current_stream_generate_state:
route_chunks = self._task_state.current_stream_generate_state.generate_route[
- self._task_state.current_stream_generate_state.current_route_position:]
+ self._task_state.current_stream_generate_state.current_route_position:
+ ]
for route_chunk in route_chunks:
if route_chunk.type == 'text':
@@ -458,13 +519,14 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
route_chunks = self._task_state.current_stream_generate_state.generate_route[
self._task_state.current_stream_generate_state.current_route_position:]
-
+
for route_chunk in route_chunks:
if route_chunk.type == 'text':
route_chunk = cast(TextGenerateRouteChunk, route_chunk)
self._task_state.answer += route_chunk.text
yield self._message_to_stream_response(route_chunk.text, self._message.id)
else:
+ value = None
route_chunk = cast(VarGenerateRouteChunk, route_chunk)
value_selector = route_chunk.value_selector
if not value_selector:
@@ -476,6 +538,20 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if route_chunk_node_id == 'sys':
# system variable
value = self._workflow_system_variables.get(SystemVariable.value_of(value_selector[1]))
+ elif route_chunk_node_id in self._iteration_nested_relations:
+ # it's an iteration variable
+ if not self._iteration_state or route_chunk_node_id not in self._iteration_state.current_iterations:
+ continue
+ iteration_state = self._iteration_state.current_iterations[route_chunk_node_id]
+ iterator = iteration_state.inputs
+ if not iterator:
+ continue
+ iterator_selector = iterator.get('iterator_selector', [])
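+ # 'index' resolves to the current loop index, 'item' to the current element of the iterator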
+ if value_selector[1] == 'index':
+ value = iteration_state.current_index
+ elif value_selector[1] == 'item':
+ value = (iterator_selector[iteration_state.current_index]
+ if iteration_state.current_index < len(iterator_selector) else None)
else:
# check chunk node id is before current node id or equal to current node id
if route_chunk_node_id not in self._task_state.ran_node_execution_infos:
@@ -505,7 +581,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
else:
value = value.get(key)
- if value:
+ if value is not None:
text = ''
if isinstance(value, str | int | float):
text = str(value)
diff --git a/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py b/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py
index fef719a086..78fe077e6b 100644
--- a/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py
+++ b/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py
@@ -1,8 +1,11 @@
-from typing import Optional
+from typing import Any, Optional
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.queue_entities import (
AppQueueEvent,
+ QueueIterationCompletedEvent,
+ QueueIterationNextEvent,
+ QueueIterationStartEvent,
QueueNodeFailedEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
@@ -130,6 +133,66 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback):
), PublishFrom.APPLICATION_MANAGER
)
+ def on_workflow_iteration_started(self,
+ node_id: str,
+ node_type: NodeType,
+ node_run_index: int = 1,
+ node_data: Optional[BaseNodeData] = None,
+ inputs: Optional[dict] = None,
+ predecessor_node_id: Optional[str] = None,
+ metadata: Optional[dict] = None) -> None:
+ """
+ Publish iteration started
+ """
+ self._queue_manager.publish(
+ QueueIterationStartEvent(
+ node_id=node_id,
+ node_type=node_type,
+ node_run_index=node_run_index,
+ node_data=node_data,
+ inputs=inputs,
+ predecessor_node_id=predecessor_node_id,
+ metadata=metadata
+ ),
+ PublishFrom.APPLICATION_MANAGER
+ )
+
+ def on_workflow_iteration_next(self, node_id: str,
+ node_type: NodeType,
+ index: int,
+ node_run_index: int,
+ output: Optional[Any]) -> None:
+ """
+ Publish iteration next
+ """
+ self._queue_manager.publish(
+ QueueIterationNextEvent(
+ node_id=node_id,
+ node_type=node_type,
+ index=index,
+ node_run_index=node_run_index,
+ output=output
+ ),
+ PublishFrom.APPLICATION_MANAGER
+ )
+
+ def on_workflow_iteration_completed(self, node_id: str,
+ node_type: NodeType,
+ node_run_index: int,
+ outputs: dict) -> None:
+ """
+ Publish iteration completed
+ """
+ self._queue_manager.publish(
+ QueueIterationCompletedEvent(
+ node_id=node_id,
+ node_type=node_type,
+ node_run_index=node_run_index,
+ outputs=outputs
+ ),
+ PublishFrom.APPLICATION_MANAGER
+ )
+
def on_event(self, event: AppQueueEvent) -> None:
"""
Publish event
diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py
index 847d314409..fb4c28a855 100644
--- a/api/core/app/apps/agent_chat/app_generator.py
+++ b/api/core/app/apps/agent_chat/app_generator.py
@@ -83,7 +83,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
)
# parse files
- files = args['files'] if 'files' in args and args['files'] else []
+ files = args['files'] if args.get('files') else []
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
if file_extra_config:
@@ -115,7 +115,8 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
user_id=user.id,
stream=stream,
invoke_from=invoke_from,
- extras=extras
+ extras=extras,
+ call_depth=0
)
# init generate records
diff --git a/api/core/app/apps/base_app_generator.py b/api/core/app/apps/base_app_generator.py
index 9d88c834e6..20ae6ff676 100644
--- a/api/core/app/apps/base_app_generator.py
+++ b/api/core/app/apps/base_app_generator.py
@@ -13,7 +13,9 @@ class BaseAppGenerator:
for variable_config in variables:
variable = variable_config.variable
- if variable not in user_inputs or not user_inputs[variable]:
+ if (variable not in user_inputs
+ or user_inputs[variable] is None
+ or (isinstance(user_inputs[variable], str) and user_inputs[variable] == '')):
if variable_config.required:
raise ValueError(f"{variable} is required in input form")
else:
@@ -22,7 +24,7 @@ class BaseAppGenerator:
value = user_inputs[variable]
- if value:
+ if value is not None:
if variable_config.type != VariableEntity.Type.NUMBER and not isinstance(value, str):
raise ValueError(f"{variable} in input form must be a string")
elif variable_config.type == VariableEntity.Type.NUMBER and isinstance(value, str):
@@ -44,7 +46,7 @@ class BaseAppGenerator:
if value and isinstance(value, str):
filtered_inputs[variable] = value.replace('\x00', '')
else:
- filtered_inputs[variable] = value if value else None
+ filtered_inputs[variable] = value
return filtered_inputs
diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py
index f1f426b27e..545463c8bd 100644
--- a/api/core/app/apps/base_app_runner.py
+++ b/api/core/app/apps/base_app_runner.py
@@ -1,6 +1,6 @@
import time
from collections.abc import Generator
-from typing import Optional, Union, cast
+from typing import Optional, Union
from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
@@ -16,11 +16,11 @@ from core.app.features.hosting_moderation.hosting_moderation import HostingModer
from core.external_data_tool.external_data_fetch import ExternalDataFetch
from core.file.file_obj import FileVar
from core.memory.token_buffer_memory import TokenBufferMemory
+from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.errors.invoke import InvokeBadRequestError
-from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.moderation.input_moderation import InputModeration
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
@@ -45,8 +45,11 @@ class AppRunner:
:param query: query
:return:
"""
- model_type_instance = model_config.provider_model_bundle.model_type_instance
- model_type_instance = cast(LargeLanguageModel, model_type_instance)
+ # Invoke model
+ model_instance = ModelInstance(
+ provider_model_bundle=model_config.provider_model_bundle,
+ model=model_config.model
+ )
model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
@@ -73,9 +76,7 @@ class AppRunner:
query=query
)
- prompt_tokens = model_type_instance.get_num_tokens(
- model_config.model,
- model_config.credentials,
+ prompt_tokens = model_instance.get_llm_num_tokens(
prompt_messages
)
@@ -89,8 +90,10 @@ class AppRunner:
def recalc_llm_max_tokens(self, model_config: ModelConfigWithCredentialsEntity,
prompt_messages: list[PromptMessage]):
# recalc max_tokens if sum(prompt_token + max_tokens) over model token limit
- model_type_instance = model_config.provider_model_bundle.model_type_instance
- model_type_instance = cast(LargeLanguageModel, model_type_instance)
+ model_instance = ModelInstance(
+ provider_model_bundle=model_config.provider_model_bundle,
+ model=model_config.model
+ )
model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
@@ -107,9 +110,7 @@ class AppRunner:
if max_tokens is None:
max_tokens = 0
- prompt_tokens = model_type_instance.get_num_tokens(
- model_config.model,
- model_config.credentials,
+ prompt_tokens = model_instance.get_llm_num_tokens(
prompt_messages
)
diff --git a/api/core/app/apps/chat/app_generator.py b/api/core/app/apps/chat/app_generator.py
index e67901cca8..4ad26e8506 100644
--- a/api/core/app/apps/chat/app_generator.py
+++ b/api/core/app/apps/chat/app_generator.py
@@ -80,7 +80,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
)
# parse files
- files = args['files'] if 'files' in args and args['files'] else []
+ files = args['files'] if args.get('files') else []
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
if file_extra_config:
diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py
index 5f93afcad7..31ce4d70b6 100644
--- a/api/core/app/apps/completion/app_generator.py
+++ b/api/core/app/apps/completion/app_generator.py
@@ -75,7 +75,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
)
# parse files
- files = args['files'] if 'files' in args and args['files'] else []
+ files = args['files'] if args.get('files') else []
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
if file_extra_config:
diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py
index a9b038ab51..c4324978d8 100644
--- a/api/core/app/apps/workflow/app_generator.py
+++ b/api/core/app/apps/workflow/app_generator.py
@@ -34,7 +34,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
user: Union[Account, EndUser],
args: dict,
invoke_from: InvokeFrom,
- stream: bool = True) \
+ stream: bool = True,
+ call_depth: int = 0) \
-> Union[dict, Generator[dict, None, None]]:
"""
Generate App response.
@@ -49,7 +50,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
inputs = args['inputs']
# parse files
- files = args['files'] if 'files' in args and args['files'] else []
+ files = args['files'] if args.get('files') else []
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
if file_extra_config:
@@ -75,9 +76,38 @@ class WorkflowAppGenerator(BaseAppGenerator):
files=file_objs,
user_id=user.id,
stream=stream,
- invoke_from=invoke_from
+ invoke_from=invoke_from,
+ call_depth=call_depth
)
+ return self._generate(
+ app_model=app_model,
+ workflow=workflow,
+ user=user,
+ application_generate_entity=application_generate_entity,
+ invoke_from=invoke_from,
+ stream=stream,
+ call_depth=call_depth
+ )
+
+ def _generate(self, app_model: App,
+ workflow: Workflow,
+ user: Union[Account, EndUser],
+ application_generate_entity: WorkflowAppGenerateEntity,
+ invoke_from: InvokeFrom,
+ stream: bool = True,
+ call_depth: int = 0) \
+ -> Union[dict, Generator[dict, None, None]]:
+ """
+ Generate App response.
+
+ :param app_model: App
+ :param workflow: Workflow
+ :param user: account or end user
+ :param application_generate_entity: application generate entity
+ :param invoke_from: invoke from source
+ :param stream: is stream
+ """
# init queue manager
queue_manager = WorkflowAppQueueManager(
task_id=application_generate_entity.task_id,
@@ -109,6 +139,64 @@ class WorkflowAppGenerator(BaseAppGenerator):
invoke_from=invoke_from
)
+ def single_iteration_generate(self, app_model: App,
+ workflow: Workflow,
+ node_id: str,
+ user: Account,
+ args: dict,
+ stream: bool = True) \
+ -> Union[dict, Generator[dict, None, None]]:
+ """
+ Generate App response for a single iteration run.
+
+ :param app_model: App
+ :param workflow: Workflow
+ :param node_id: id of the iteration node to run
+ :param user: account or end user
+ :param args: request args
+ :param stream: is stream
+ """
+ if not node_id:
+ raise ValueError('node_id is required')
+
+ if args.get('inputs') is None:
+ raise ValueError('inputs is required')
+
+ extras = {
+ "auto_generate_conversation_name": False
+ }
+
+ # convert to app config
+ app_config = WorkflowAppConfigManager.get_app_config(
+ app_model=app_model,
+ workflow=workflow
+ )
+
+ # init application generate entity
+ application_generate_entity = WorkflowAppGenerateEntity(
+ task_id=str(uuid.uuid4()),
+ app_config=app_config,
+ inputs={},
+ files=[],
+ user_id=user.id,
+ stream=stream,
+ invoke_from=InvokeFrom.DEBUGGER,
+ extras=extras,
+ single_iteration_run=WorkflowAppGenerateEntity.SingleIterationRunEntity(
+ node_id=node_id,
+ inputs=args['inputs']
+ )
+ )
+
+ return self._generate(
+ app_model=app_model,
+ workflow=workflow,
+ user=user,
+ invoke_from=InvokeFrom.DEBUGGER,
+ application_generate_entity=application_generate_entity,
+ stream=stream
+ )
+
def _generate_worker(self, flask_app: Flask,
application_generate_entity: WorkflowAppGenerateEntity,
queue_manager: AppQueueManager) -> None:
@@ -123,10 +211,21 @@ class WorkflowAppGenerator(BaseAppGenerator):
try:
# workflow app
runner = WorkflowAppRunner()
- runner.run(
- application_generate_entity=application_generate_entity,
- queue_manager=queue_manager
- )
+ if application_generate_entity.single_iteration_run:
+ single_iteration_run = application_generate_entity.single_iteration_run
+ runner.single_iteration_run(
+ app_id=application_generate_entity.app_config.app_id,
+ workflow_id=application_generate_entity.app_config.workflow_id,
+ queue_manager=queue_manager,
+ inputs=single_iteration_run.inputs,
+ node_id=single_iteration_run.node_id,
+ user_id=application_generate_entity.user_id
+ )
+ else:
+ runner.run(
+ application_generate_entity=application_generate_entity,
+ queue_manager=queue_manager
+ )
except GenerateTaskStoppedException:
pass
except InvokeAuthorizationError:
diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py
index 9d854afe35..050319e552 100644
--- a/api/core/app/apps/workflow/app_runner.py
+++ b/api/core/app/apps/workflow/app_runner.py
@@ -73,11 +73,44 @@ class WorkflowAppRunner:
user_from=UserFrom.ACCOUNT
if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
else UserFrom.END_USER,
+ invoke_from=application_generate_entity.invoke_from,
user_inputs=inputs,
system_inputs={
SystemVariable.FILES: files,
SystemVariable.USER_ID: user_id
},
+ callbacks=workflow_callbacks,
+ call_depth=application_generate_entity.call_depth
+ )
+
+ def single_iteration_run(self, app_id: str, workflow_id: str,
+ queue_manager: AppQueueManager,
+ inputs: dict, node_id: str, user_id: str) -> None:
+ """
+ Single iteration run
+ """
+ app_record: App = db.session.query(App).filter(App.id == app_id).first()
+ if not app_record:
+ raise ValueError("App not found")
+
+ if not app_record.workflow_id:
+ raise ValueError("Workflow not initialized")
+
+ workflow = self.get_workflow(app_model=app_record, workflow_id=workflow_id)
+ if not workflow:
+ raise ValueError("Workflow not initialized")
+
+ workflow_callbacks = [WorkflowEventTriggerCallback(
+ queue_manager=queue_manager,
+ workflow=workflow
+ )]
+
+ workflow_engine_manager = WorkflowEngineManager()
+ workflow_engine_manager.single_step_run_iteration_workflow_node(
+ workflow=workflow,
+ node_id=node_id,
+ user_id=user_id,
+ user_inputs=inputs,
callbacks=workflow_callbacks
)
diff --git a/api/core/app/apps/workflow/generate_response_converter.py b/api/core/app/apps/workflow/generate_response_converter.py
index d907b82c99..88bde58ba0 100644
--- a/api/core/app/apps/workflow/generate_response_converter.py
+++ b/api/core/app/apps/workflow/generate_response_converter.py
@@ -5,6 +5,8 @@ from typing import cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
ErrorStreamResponse,
+ NodeFinishStreamResponse,
+ NodeStartStreamResponse,
PingStreamResponse,
WorkflowAppBlockingResponse,
WorkflowAppStreamResponse,
@@ -68,4 +70,24 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
:param stream_response: stream response
:return:
"""
- return cls.convert_stream_full_response(stream_response)
+ for chunk in stream_response:
+ chunk = cast(WorkflowAppStreamResponse, chunk)
+ sub_stream_response = chunk.stream_response
+
+ if isinstance(sub_stream_response, PingStreamResponse):
+ yield 'ping'
+ continue
+
+ response_chunk = {
+ 'event': sub_stream_response.event.value,
+ 'workflow_run_id': chunk.workflow_run_id,
+ }
+
+ if isinstance(sub_stream_response, ErrorStreamResponse):
+ data = cls._error_to_stream_response(sub_stream_response.err)
+ response_chunk.update(data)
+ elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
+ response_chunk.update(sub_stream_response.to_ignore_detail_dict())
+ else:
+ response_chunk.update(sub_stream_response.to_dict())
+ yield json.dumps(response_chunk)
diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py
index a7061a77bb..8d961e0993 100644
--- a/api/core/app/apps/workflow/generate_task_pipeline.py
+++ b/api/core/app/apps/workflow/generate_task_pipeline.py
@@ -9,6 +9,9 @@ from core.app.entities.app_invoke_entities import (
)
from core.app.entities.queue_entities import (
QueueErrorEvent,
+ QueueIterationCompletedEvent,
+ QueueIterationNextEvent,
+ QueueIterationStartEvent,
QueueMessageReplaceEvent,
QueueNodeFailedEvent,
QueueNodeStartedEvent,
@@ -58,6 +61,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
_task_state: WorkflowTaskState
_application_generate_entity: WorkflowAppGenerateEntity
_workflow_system_variables: dict[SystemVariable, Any]
+ _iteration_nested_relations: dict[str, list[str]]
def __init__(self, application_generate_entity: WorkflowAppGenerateEntity,
workflow: Workflow,
@@ -85,8 +89,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
SystemVariable.USER_ID: user_id
}
- self._task_state = WorkflowTaskState()
+ self._task_state = WorkflowTaskState(
+ iteration_nested_node_ids=[]
+ )
self._stream_generate_nodes = self._get_stream_generate_nodes()
+ self._iteration_nested_relations = self._get_iteration_nested_relations(self._workflow.graph_dict)
def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
"""
@@ -191,6 +198,22 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution
)
+
+ if isinstance(event, QueueNodeFailedEvent):
+ yield from self._handle_iteration_exception(
+ task_id=self._application_generate_entity.task_id,
+ error=f'Child node failed: {event.error}'
+ )
+ elif isinstance(event, QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent):
+ if isinstance(event, QueueIterationNextEvent):
+ # clear ran node execution infos of current iteration
+ iteration_relations = self._iteration_nested_relations.get(event.node_id)
+ if iteration_relations:
+ for node_id in iteration_relations:
+ self._task_state.ran_node_execution_infos.pop(node_id, None)
+
+ yield self._handle_iteration_to_stream_response(self._application_generate_entity.task_id, event)
+ self._handle_iteration_operation(event)
elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent):
workflow_run = self._handle_workflow_finished(event)
@@ -331,13 +354,20 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
continue
node_type = source_node.get('data', {}).get('type')
+ node_iteration_id = source_node.get('data', {}).get('iteration_id')
+ iteration_start_node_id = None
+ if node_iteration_id:
+ iteration_node = next((node for node in nodes if node.get('id') == node_iteration_id), None)
+ iteration_start_node_id = iteration_node.get('data', {}).get('start_node_id')
+
if node_type in [
NodeType.IF_ELSE.value,
NodeType.QUESTION_CLASSIFIER.value
]:
start_node_id = target_node_id
start_node_ids.append(start_node_id)
- elif node_type == NodeType.START.value:
+ elif node_type == NodeType.START.value or \
+ (node_iteration_id is not None and iteration_start_node_id == source_node.get('id')):
start_node_id = source_node_id
start_node_ids.append(start_node_id)
else:
@@ -411,3 +441,24 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
return False
return True
+
+ def _get_iteration_nested_relations(self, graph: dict) -> dict[str, list[str]]:
+ """
+ Get iteration nested relations.
+ :param graph: workflow graph dict
+ :return: mapping of iteration node id to the ids of the nodes nested in it
+ """
+ nodes = graph.get('nodes')
+
+ iteration_ids = [node.get('id') for node in nodes
+ if node.get('data', {}).get('type') in [
+ NodeType.ITERATION.value,
+ NodeType.LOOP.value,
+ ]]
+
+ return {
+ iteration_id: [
+ node.get('id') for node in nodes if node.get('data', {}).get('iteration_id') == iteration_id
+ ] for iteration_id in iteration_ids
+ }
+
\ No newline at end of file
diff --git a/api/core/app/apps/workflow/workflow_event_trigger_callback.py b/api/core/app/apps/workflow/workflow_event_trigger_callback.py
index 2048abe464..e423a40bcb 100644
--- a/api/core/app/apps/workflow/workflow_event_trigger_callback.py
+++ b/api/core/app/apps/workflow/workflow_event_trigger_callback.py
@@ -1,8 +1,11 @@
-from typing import Optional
+from typing import Any, Optional
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.queue_entities import (
AppQueueEvent,
+ QueueIterationCompletedEvent,
+ QueueIterationNextEvent,
+ QueueIterationStartEvent,
QueueNodeFailedEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
@@ -130,6 +133,66 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback):
), PublishFrom.APPLICATION_MANAGER
)
+ def on_workflow_iteration_started(self,
+ node_id: str,
+ node_type: NodeType,
+ node_run_index: int = 1,
+ node_data: Optional[BaseNodeData] = None,
+ inputs: Optional[dict] = None,
+ predecessor_node_id: Optional[str] = None,
+ metadata: Optional[dict] = None) -> None:
+ """
+ Publish iteration started
+ """
+ self._queue_manager.publish(
+ QueueIterationStartEvent(
+ node_id=node_id,
+ node_type=node_type,
+ node_run_index=node_run_index,
+ node_data=node_data,
+ inputs=inputs,
+ predecessor_node_id=predecessor_node_id,
+ metadata=metadata
+ ),
+ PublishFrom.APPLICATION_MANAGER
+ )
+
+ def on_workflow_iteration_next(self, node_id: str,
+ node_type: NodeType,
+ index: int,
+ node_run_index: int,
+ output: Optional[Any]) -> None:
+ """
+ Publish iteration next
+ """
+ self._queue_manager.publish(
+ QueueIterationNextEvent(
+ node_id=node_id,
+ node_type=node_type,
+ index=index,
+ node_run_index=node_run_index,
+ output=output
+ ),
+ PublishFrom.APPLICATION_MANAGER
+ )
+
+ def on_workflow_iteration_completed(self, node_id: str,
+ node_type: NodeType,
+ node_run_index: int,
+ outputs: dict) -> None:
+ """
+ Publish iteration completed
+ """
+ self._queue_manager.publish(
+ QueueIterationCompletedEvent(
+ node_id=node_id,
+ node_type=node_type,
+ node_run_index=node_run_index,
+ outputs=outputs
+ ),
+ PublishFrom.APPLICATION_MANAGER
+ )
+
def on_event(self, event: AppQueueEvent) -> None:
"""
Publish event
diff --git a/api/core/app/apps/workflow_logging_callback.py b/api/core/app/apps/workflow_logging_callback.py
index 4627c21c7a..f617c671e9 100644
--- a/api/core/app/apps/workflow_logging_callback.py
+++ b/api/core/app/apps/workflow_logging_callback.py
@@ -102,6 +102,39 @@ class WorkflowLoggingCallback(BaseWorkflowCallback):
self.print_text(text, color="pink", end="")
+ def on_workflow_iteration_started(self,
+ node_id: str,
+ node_type: NodeType,
+ node_run_index: int = 1,
+ node_data: Optional[BaseNodeData] = None,
+ inputs: Optional[dict] = None,
+ predecessor_node_id: Optional[str] = None,
+ metadata: Optional[dict] = None) -> None:
+ """
+ Print iteration started
+ """
+ self.print_text("\n[on_workflow_iteration_started]", color='blue')
+ self.print_text(f"Node ID: {node_id}", color='blue')
+
+ def on_workflow_iteration_next(self, node_id: str,
+ node_type: NodeType,
+ index: int,
+ node_run_index: int,
+ output: Optional[dict]) -> None:
+ """
+ Print iteration next
+ """
+ self.print_text("\n[on_workflow_iteration_next]", color='blue')
+
+ def on_workflow_iteration_completed(self, node_id: str,
+ node_type: NodeType,
+ node_run_index: int,
+ outputs: dict) -> None:
+ """
+ Print iteration completed
+ """
+ self.print_text("\n[on_workflow_iteration_completed]", color='blue')
+
def on_event(self, event: AppQueueEvent) -> None:
"""
Publish event
diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py
index 09c62c802c..cc63fa4684 100644
--- a/api/core/app/entities/app_invoke_entities.py
+++ b/api/core/app/entities/app_invoke_entities.py
@@ -80,6 +80,9 @@ class AppGenerateEntity(BaseModel):
stream: bool
invoke_from: InvokeFrom
+ # invoke call depth
+ call_depth: int = 0
+
# extra parameters, like: auto_generate_conversation_name
extras: dict[str, Any] = {}
@@ -126,6 +129,14 @@ class AdvancedChatAppGenerateEntity(AppGenerateEntity):
conversation_id: Optional[str] = None
query: Optional[str] = None
+ class SingleIterationRunEntity(BaseModel):
+ """
+ Single Iteration Run Entity.
+ """
+ node_id: str
+ inputs: dict
+
+ single_iteration_run: Optional[SingleIterationRunEntity] = None
class WorkflowAppGenerateEntity(AppGenerateEntity):
"""
@@ -133,3 +144,12 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
"""
# app config
app_config: WorkflowUIBasedAppConfig
+
+ class SingleIterationRunEntity(BaseModel):
+ """
+ Single Iteration Run Entity.
+ """
+ node_id: str
+ inputs: dict
+
+ single_iteration_run: Optional[SingleIterationRunEntity] = None
\ No newline at end of file
diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py
index bf174e30e1..47fa2ac19d 100644
--- a/api/core/app/entities/queue_entities.py
+++ b/api/core/app/entities/queue_entities.py
@@ -1,7 +1,7 @@
from enum import Enum
from typing import Any, Optional
-from pydantic import BaseModel
+from pydantic import BaseModel, validator
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from core.workflow.entities.base_node_data_entities import BaseNodeData
@@ -21,6 +21,9 @@ class QueueEvent(Enum):
WORKFLOW_STARTED = "workflow_started"
WORKFLOW_SUCCEEDED = "workflow_succeeded"
WORKFLOW_FAILED = "workflow_failed"
+ ITERATION_START = "iteration_start"
+ ITERATION_NEXT = "iteration_next"
+ ITERATION_COMPLETED = "iteration_completed"
NODE_STARTED = "node_started"
NODE_SUCCEEDED = "node_succeeded"
NODE_FAILED = "node_failed"
@@ -47,6 +50,55 @@ class QueueLLMChunkEvent(AppQueueEvent):
event = QueueEvent.LLM_CHUNK
chunk: LLMResultChunk
+class QueueIterationStartEvent(AppQueueEvent):
+ """
+ QueueIterationStartEvent entity
+ """
+ event = QueueEvent.ITERATION_START
+ node_id: str
+ node_type: NodeType
+ node_data: BaseNodeData
+
+ node_run_index: int
+ inputs: Optional[dict] = None
+ predecessor_node_id: Optional[str] = None
+ metadata: Optional[dict] = None
+
+class QueueIterationNextEvent(AppQueueEvent):
+ """
+ QueueIterationNextEvent entity
+ """
+ event = QueueEvent.ITERATION_NEXT
+
+ index: int
+ node_id: str
+ node_type: NodeType
+
+ node_run_index: int
+ output: Optional[Any] # output for the current iteration
+
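+ # accept only JSON-serializable primitives and containers as the per-iteration output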
+ @validator('output', pre=True, always=True)
+ def set_output(cls, v):
+ """
+ Set output
+ """
+ if v is None:
+ return None
+ if isinstance(v, int | float | str | bool | dict | list):
+ return v
+ raise ValueError('output must be a valid type')
+
+class QueueIterationCompletedEvent(AppQueueEvent):
+ """
+ QueueIterationCompletedEvent entity
+ """
+ event = QueueEvent.ITERATION_COMPLETED
+
+ node_id: str
+ node_type: NodeType
+
+ node_run_index: int
+ outputs: dict
class QueueTextChunkEvent(AppQueueEvent):
"""
diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py
index 4994efe2e9..5956bc35fa 100644
--- a/api/core/app/entities/task_entities.py
+++ b/api/core/app/entities/task_entities.py
@@ -1,12 +1,14 @@
from enum import Enum
-from typing import Optional
+from typing import Any, Optional
from pydantic import BaseModel
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.utils.encoders import jsonable_encoder
+from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeType
from core.workflow.nodes.answer.entities import GenerateRouteChunk
+from models.workflow import WorkflowNodeExecutionStatus
class WorkflowStreamGenerateNodes(BaseModel):
@@ -65,6 +67,7 @@ class WorkflowTaskState(TaskState):
current_stream_generate_state: Optional[WorkflowStreamGenerateNodes] = None
+ iteration_nested_node_ids: Optional[list[str]] = None
class AdvancedChatTaskState(WorkflowTaskState):
"""
@@ -91,6 +94,9 @@ class StreamEvent(Enum):
WORKFLOW_FINISHED = "workflow_finished"
NODE_STARTED = "node_started"
NODE_FINISHED = "node_finished"
+ ITERATION_STARTED = "iteration_started"
+ ITERATION_NEXT = "iteration_next"
+ ITERATION_COMPLETED = "iteration_completed"
TEXT_CHUNK = "text_chunk"
TEXT_REPLACE = "text_replace"
@@ -246,6 +252,24 @@ class NodeStartStreamResponse(StreamResponse):
workflow_run_id: str
data: Data
+ def to_ignore_detail_dict(self):
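+ # same event payload with inputs and extras blanked out, for responses that hide node details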
+ return {
+ "event": self.event.value,
+ "task_id": self.task_id,
+ "workflow_run_id": self.workflow_run_id,
+ "data": {
+ "id": self.data.id,
+ "node_id": self.data.node_id,
+ "node_type": self.data.node_type,
+ "title": self.data.title,
+ "index": self.data.index,
+ "predecessor_node_id": self.data.predecessor_node_id,
+ "inputs": None,
+ "created_at": self.data.created_at,
+ "extras": {}
+ }
+ }
+
class NodeFinishStreamResponse(StreamResponse):
"""
@@ -276,6 +300,99 @@ class NodeFinishStreamResponse(StreamResponse):
workflow_run_id: str
data: Data
+ def to_ignore_detail_dict(self):
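+ # detail-stripped variant: inputs, outputs, process data and execution metadata are omitted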
+ return {
+ "event": self.event.value,
+ "task_id": self.task_id,
+ "workflow_run_id": self.workflow_run_id,
+ "data": {
+ "id": self.data.id,
+ "node_id": self.data.node_id,
+ "node_type": self.data.node_type,
+ "title": self.data.title,
+ "index": self.data.index,
+ "predecessor_node_id": self.data.predecessor_node_id,
+ "inputs": None,
+ "process_data": None,
+ "outputs": None,
+ "status": self.data.status,
+ "error": None,
+ "elapsed_time": self.data.elapsed_time,
+ "execution_metadata": None,
+ "created_at": self.data.created_at,
+ "finished_at": self.data.finished_at,
+ "files": []
+ }
+ }
+
+class IterationNodeStartStreamResponse(StreamResponse):
+ """
+ IterationNodeStartStreamResponse entity
+ """
+ class Data(BaseModel):
+ """
+ Data entity
+ """
+ id: str
+ node_id: str
+ node_type: str
+ title: str
+ created_at: int
+ extras: dict = {}
+ metadata: dict = {}
+ inputs: dict = {}
+
+ event: StreamEvent = StreamEvent.ITERATION_STARTED
+ workflow_run_id: str
+ data: Data
+
+class IterationNodeNextStreamResponse(StreamResponse):
+ """
+ IterationNodeNextStreamResponse entity
+ """
+ class Data(BaseModel):
+ """
+ Data entity
+ """
+ id: str
+ node_id: str
+ node_type: str
+ title: str
+ index: int
+ created_at: int
+ pre_iteration_output: Optional[Any]
+ extras: dict = {}
+
+ event: StreamEvent = StreamEvent.ITERATION_NEXT
+ workflow_run_id: str
+ data: Data
+
+class IterationNodeCompletedStreamResponse(StreamResponse):
+ """
+ IterationNodeCompletedStreamResponse entity
+ """
+ class Data(BaseModel):
+ """
+ Data entity
+ """
+ id: str
+ node_id: str
+ node_type: str
+ title: str
+ outputs: Optional[dict]
+ created_at: int
+ extras: dict = {}
+ inputs: Optional[dict] = None
+ status: WorkflowNodeExecutionStatus
+ error: Optional[str]
+ elapsed_time: float
+ total_tokens: int
+ finished_at: int
+ steps: int
+
+ event: StreamEvent = StreamEvent.ITERATION_COMPLETED
+ workflow_run_id: str
+ data: Data
class TextChunkStreamResponse(StreamResponse):
"""
@@ -411,3 +528,23 @@ class WorkflowAppBlockingResponse(AppBlockingResponse):
workflow_run_id: str
data: Data
+
+class WorkflowIterationState(BaseModel):
+ """
+ WorkflowIterationState entity
+ """
+ class Data(BaseModel):
+ """
+ Data entity
+ """
+ parent_iteration_id: Optional[str] = None
+ iteration_id: str
+ current_index: int
+ iteration_steps_boundary: list[int] = []
+ node_execution_id: str
+ started_at: float
+ inputs: Optional[dict] = None
+ total_tokens: int = 0
+ node_data: BaseNodeData
+
+ current_iterations: dict[str, Data] = {}
\ No newline at end of file
diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py
index a7dbb4754c..f71470edb2 100644
--- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py
+++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py
@@ -37,6 +37,7 @@ from core.app.entities.task_entities import (
)
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.app.task_pipeline.message_cycle_manage import MessageCycleManage
+from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
@@ -317,29 +318,30 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
"""
model_config = self._model_config
model = model_config.model
- model_type_instance = model_config.provider_model_bundle.model_type_instance
- model_type_instance = cast(LargeLanguageModel, model_type_instance)
+
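+ # count tokens through ModelInstance so credential resolution matches the normal invocation path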
+ model_instance = ModelInstance(
+ provider_model_bundle=model_config.provider_model_bundle,
+ model=model_config.model
+ )
# calculate num tokens
prompt_tokens = 0
if event.stopped_by != QueueStopEvent.StopBy.ANNOTATION_REPLY:
- prompt_tokens = model_type_instance.get_num_tokens(
- model,
- model_config.credentials,
+ prompt_tokens = model_instance.get_llm_num_tokens(
self._task_state.llm_result.prompt_messages
)
completion_tokens = 0
if event.stopped_by == QueueStopEvent.StopBy.USER_MANUAL:
- completion_tokens = model_type_instance.get_num_tokens(
- model,
- model_config.credentials,
+ completion_tokens = model_instance.get_llm_num_tokens(
[self._task_state.llm_result.message]
)
credentials = model_config.credentials
# transform usage
+ model_type_instance = model_config.provider_model_bundle.model_type_instance
+ model_type_instance = cast(LargeLanguageModel, model_type_instance)
self._task_state.llm_result.usage = model_type_instance._calc_response_usage(
model,
credentials,
diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py
index 48ff34fef9..978a318279 100644
--- a/api/core/app/task_pipeline/workflow_cycle_manage.py
+++ b/api/core/app/task_pipeline/workflow_cycle_manage.py
@@ -1,9 +1,9 @@
import json
import time
from datetime import datetime, timezone
-from typing import Any, Optional, Union, cast
+from typing import Optional, Union, cast
-from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
+from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import (
QueueNodeFailedEvent,
QueueNodeStartedEvent,
@@ -13,18 +13,17 @@ from core.app.entities.queue_entities import (
QueueWorkflowSucceededEvent,
)
from core.app.entities.task_entities import (
- AdvancedChatTaskState,
NodeExecutionInfo,
NodeFinishStreamResponse,
NodeStartStreamResponse,
WorkflowFinishStreamResponse,
WorkflowStartStreamResponse,
- WorkflowTaskState,
)
+from core.app.task_pipeline.workflow_iteration_cycle_manage import WorkflowIterationCycleManage
from core.file.file_obj import FileVar
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.tool_manager import ToolManager
-from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType, SystemVariable
+from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType
from core.workflow.nodes.tool.entities import ToolNodeData
from core.workflow.workflow_engine_manager import WorkflowEngineManager
from extensions.ext_database import db
@@ -42,13 +41,7 @@ from models.workflow import (
)
-class WorkflowCycleManage:
- _application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity]
- _workflow: Workflow
- _user: Union[Account, EndUser]
- _task_state: Union[AdvancedChatTaskState, WorkflowTaskState]
- _workflow_system_variables: dict[SystemVariable, Any]
-
+class WorkflowCycleManage(WorkflowIterationCycleManage):
def _init_workflow_run(self, workflow: Workflow,
triggered_from: WorkflowRunTriggeredFrom,
user: Union[Account, EndUser],
@@ -237,6 +230,7 @@ class WorkflowCycleManage:
inputs: Optional[dict] = None,
process_data: Optional[dict] = None,
outputs: Optional[dict] = None,
+ execution_metadata: Optional[dict] = None
) -> WorkflowNodeExecution:
"""
Workflow node execution failed
@@ -255,6 +249,8 @@ class WorkflowCycleManage:
workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
workflow_node_execution.process_data = json.dumps(process_data) if process_data else None
workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
+ workflow_node_execution.execution_metadata = json.dumps(jsonable_encoder(execution_metadata)) \
+ if execution_metadata else None
db.session.commit()
db.session.refresh(workflow_node_execution)
@@ -444,6 +440,23 @@ class WorkflowCycleManage:
current_node_execution = self._task_state.ran_node_execution_infos[event.node_id]
workflow_node_execution = db.session.query(WorkflowNodeExecution).filter(
WorkflowNodeExecution.id == current_node_execution.workflow_node_execution_id).first()
+
+ execution_metadata = event.execution_metadata if isinstance(event, QueueNodeSucceededEvent) else None
+
+ if self._iteration_state and self._iteration_state.current_iterations:
+ if not execution_metadata:
+ execution_metadata = {}
+ current_iteration_data = None
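+ # find the root iteration (no parent) so the node run is attributed to the outermost loop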
+ for iteration_node_id in self._iteration_state.current_iterations:
+ data = self._iteration_state.current_iterations[iteration_node_id]
+ if data.parent_iteration_id is None:
+ current_iteration_data = data
+ break
+
+ if current_iteration_data:
+ execution_metadata[NodeRunMetadataKey.ITERATION_ID] = current_iteration_data.iteration_id
+ execution_metadata[NodeRunMetadataKey.ITERATION_INDEX] = current_iteration_data.current_index
+
if isinstance(event, QueueNodeSucceededEvent):
workflow_node_execution = self._workflow_node_execution_success(
workflow_node_execution=workflow_node_execution,
@@ -451,12 +464,18 @@ class WorkflowCycleManage:
inputs=event.inputs,
process_data=event.process_data,
outputs=event.outputs,
- execution_metadata=event.execution_metadata
+ execution_metadata=execution_metadata
)
- if event.execution_metadata and event.execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
+ if execution_metadata and execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
self._task_state.total_tokens += (
- int(event.execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS)))
+ int(execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS)))
+
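+ # also credit the tokens to every iteration currently in flight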
+ if self._iteration_state:
+ for iteration_node_id in self._iteration_state.current_iterations:
+ data = self._iteration_state.current_iterations[iteration_node_id]
+ if execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
+ data.total_tokens += int(execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS))
if workflow_node_execution.node_type == NodeType.LLM.value:
outputs = workflow_node_execution.outputs_dict
@@ -469,7 +488,8 @@ class WorkflowCycleManage:
error=event.error,
inputs=event.inputs,
process_data=event.process_data,
- outputs=event.outputs
+ outputs=event.outputs,
+ execution_metadata=execution_metadata
)
db.session.close()
diff --git a/api/core/app/task_pipeline/workflow_cycle_state_manager.py b/api/core/app/task_pipeline/workflow_cycle_state_manager.py
new file mode 100644
index 0000000000..545f31fddf
--- /dev/null
+++ b/api/core/app/task_pipeline/workflow_cycle_state_manager.py
@@ -0,0 +1,16 @@
+from typing import Any, Union
+
+from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
+from core.app.entities.task_entities import AdvancedChatTaskState, WorkflowTaskState
+from core.workflow.entities.node_entities import SystemVariable
+from models.account import Account
+from models.model import EndUser
+from models.workflow import Workflow
+
+
+class WorkflowCycleStateManager:
+ _application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity]
+ _workflow: Workflow
+ _user: Union[Account, EndUser]
+ _task_state: Union[AdvancedChatTaskState, WorkflowTaskState]
+ _workflow_system_variables: dict[SystemVariable, Any]
\ No newline at end of file
diff --git a/api/core/app/task_pipeline/workflow_iteration_cycle_manage.py b/api/core/app/task_pipeline/workflow_iteration_cycle_manage.py
new file mode 100644
index 0000000000..55e3e03173
--- /dev/null
+++ b/api/core/app/task_pipeline/workflow_iteration_cycle_manage.py
@@ -0,0 +1,281 @@
+import json
+import time
+from collections.abc import Generator
+from typing import Optional, Union
+
+from core.app.entities.queue_entities import (
+ QueueIterationCompletedEvent,
+ QueueIterationNextEvent,
+ QueueIterationStartEvent,
+)
+from core.app.entities.task_entities import (
+ IterationNodeCompletedStreamResponse,
+ IterationNodeNextStreamResponse,
+ IterationNodeStartStreamResponse,
+ NodeExecutionInfo,
+ WorkflowIterationState,
+)
+from core.app.task_pipeline.workflow_cycle_state_manager import WorkflowCycleStateManager
+from core.workflow.entities.node_entities import NodeType
+from extensions.ext_database import db
+from models.workflow import (
+ WorkflowNodeExecution,
+ WorkflowNodeExecutionStatus,
+ WorkflowNodeExecutionTriggeredFrom,
+ WorkflowRun,
+)
+
+
+class WorkflowIterationCycleManage(WorkflowCycleStateManager):
+ _iteration_state: Optional[WorkflowIterationState] = None
+
+ def _init_iteration_state(self) -> WorkflowIterationState:
+ if not self._iteration_state:
+ self._iteration_state = WorkflowIterationState(
+ current_iterations={}
+ )
+ return self._iteration_state
+
+ def _handle_iteration_to_stream_response(self, task_id: str, event: QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent) \
+ -> Union[IterationNodeStartStreamResponse, IterationNodeNextStreamResponse, IterationNodeCompletedStreamResponse]:
+ """
+ Handle iteration to stream response
+ :param task_id: task id
+ :param event: iteration event
+ :return:
+ """
+ if isinstance(event, QueueIterationStartEvent):
+ return IterationNodeStartStreamResponse(
+ task_id=task_id,
+ workflow_run_id=self._task_state.workflow_run_id,
+ data=IterationNodeStartStreamResponse.Data(
+ id=event.node_id,
+ node_id=event.node_id,
+ node_type=event.node_type.value,
+ title=event.node_data.title,
+ created_at=int(time.time()),
+ extras={},
+ inputs=event.inputs,
+ metadata=event.metadata
+ )
+ )
+ elif isinstance(event, QueueIterationNextEvent):
+ current_iteration = self._iteration_state.current_iterations[event.node_id]
+
+ return IterationNodeNextStreamResponse(
+ task_id=task_id,
+ workflow_run_id=self._task_state.workflow_run_id,
+ data=IterationNodeNextStreamResponse.Data(
+ id=event.node_id,
+ node_id=event.node_id,
+ node_type=event.node_type.value,
+ title=current_iteration.node_data.title,
+ index=event.index,
+ pre_iteration_output=event.output,
+ created_at=int(time.time()),
+ extras={}
+ )
+ )
+ elif isinstance(event, QueueIterationCompletedEvent):
+ current_iteration = self._iteration_state.current_iterations[event.node_id]
+
+ return IterationNodeCompletedStreamResponse(
+ task_id=task_id,
+ workflow_run_id=self._task_state.workflow_run_id,
+ data=IterationNodeCompletedStreamResponse.Data(
+ id=event.node_id,
+ node_id=event.node_id,
+ node_type=event.node_type.value,
+ title=current_iteration.node_data.title,
+ outputs=event.outputs,
+ created_at=int(time.time()),
+ extras={},
+ inputs=current_iteration.inputs,
+ status=WorkflowNodeExecutionStatus.SUCCEEDED,
+ error=None,
+ elapsed_time=time.perf_counter() - current_iteration.started_at,
+ total_tokens=current_iteration.total_tokens,
+ finished_at=int(time.time()),
+ steps=current_iteration.current_index
+ )
+ )
+
+ def _init_iteration_execution_from_workflow_run(self,
+ workflow_run: WorkflowRun,
+ node_id: str,
+ node_type: NodeType,
+ node_title: str,
+ node_run_index: int = 1,
+ inputs: Optional[dict] = None,
+ predecessor_node_id: Optional[str] = None
+ ) -> WorkflowNodeExecution:
+ workflow_node_execution = WorkflowNodeExecution(
+ tenant_id=workflow_run.tenant_id,
+ app_id=workflow_run.app_id,
+ workflow_id=workflow_run.workflow_id,
+ triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
+ workflow_run_id=workflow_run.id,
+ predecessor_node_id=predecessor_node_id,
+ index=node_run_index,
+ node_id=node_id,
+ node_type=node_type.value,
+ inputs=json.dumps(inputs) if inputs else None,
+ title=node_title,
+ status=WorkflowNodeExecutionStatus.RUNNING.value,
+ created_by_role=workflow_run.created_by_role,
+ created_by=workflow_run.created_by,
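+ # bookkeeping: the run index where the iteration's steps begin, plus progress counters updated on each step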
+ execution_metadata=json.dumps({
+ 'started_run_index': node_run_index + 1,
+ 'current_index': 0,
+ 'steps_boundary': [],
+ })
+ )
+
+ db.session.add(workflow_node_execution)
+ db.session.commit()
+ db.session.refresh(workflow_node_execution)
+ db.session.close()
+
+ return workflow_node_execution
+
+ def _handle_iteration_operation(self, event: QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent) -> Optional[WorkflowNodeExecution]:
+ if isinstance(event, QueueIterationStartEvent):
+ return self._handle_iteration_started(event)
+ elif isinstance(event, QueueIterationNextEvent):
+ return self._handle_iteration_next(event)
+ elif isinstance(event, QueueIterationCompletedEvent):
+ return self._handle_iteration_completed(event)
+
+ def _handle_iteration_started(self, event: QueueIterationStartEvent) -> WorkflowNodeExecution:
+ self._init_iteration_state()
+
+ workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self._task_state.workflow_run_id).first()
+ workflow_node_execution = self._init_iteration_execution_from_workflow_run(
+ workflow_run=workflow_run,
+ node_id=event.node_id,
+ node_type=NodeType.ITERATION,
+ node_title=event.node_data.title,
+ node_run_index=event.node_run_index,
+ inputs=event.inputs,
+ predecessor_node_id=event.predecessor_node_id
+ )
+
+ latest_node_execution_info = NodeExecutionInfo(
+ workflow_node_execution_id=workflow_node_execution.id,
+ node_type=NodeType.ITERATION,
+ start_at=time.perf_counter()
+ )
+
+ self._task_state.ran_node_execution_infos[event.node_id] = latest_node_execution_info
+ self._task_state.latest_node_execution_info = latest_node_execution_info
+
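+ # register the iteration as in-flight; parent_iteration_id is None for a top-level iteration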
+ self._iteration_state.current_iterations[event.node_id] = WorkflowIterationState.Data(
+ parent_iteration_id=None,
+ iteration_id=event.node_id,
+ current_index=0,
+ iteration_steps_boundary=[],
+ node_execution_id=workflow_node_execution.id,
+ started_at=time.perf_counter(),
+ inputs=event.inputs,
+ total_tokens=0,
+ node_data=event.node_data
+ )
+
+ db.session.close()
+
+ return workflow_node_execution
+
+ def _handle_iteration_next(self, event: QueueIterationNextEvent) -> None:
+ if event.node_id not in self._iteration_state.current_iterations:
+ return
+ current_iteration = self._iteration_state.current_iterations[event.node_id]
+ current_iteration.current_index = event.index
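+ # remember where this step ended so the run can later be sliced into per-iteration segments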
+ current_iteration.iteration_steps_boundary.append(event.node_run_index)
+ workflow_node_execution: WorkflowNodeExecution = db.session.query(WorkflowNodeExecution).filter(
+ WorkflowNodeExecution.id == current_iteration.node_execution_id
+ ).first()
+
+ original_node_execution_metadata = workflow_node_execution.execution_metadata_dict
+ if original_node_execution_metadata:
+ original_node_execution_metadata['current_index'] = event.index
+ original_node_execution_metadata['steps_boundary'] = current_iteration.iteration_steps_boundary
+ original_node_execution_metadata['total_tokens'] = current_iteration.total_tokens
+ workflow_node_execution.execution_metadata = json.dumps(original_node_execution_metadata)
+
+ db.session.commit()
+
+ db.session.close()
+
+ def _handle_iteration_completed(self, event: QueueIterationCompletedEvent) -> None:
+ if event.node_id not in self._iteration_state.current_iterations:
+ return
+
+ current_iteration = self._iteration_state.current_iterations[event.node_id]
+ workflow_node_execution: WorkflowNodeExecution = db.session.query(WorkflowNodeExecution).filter(
+ WorkflowNodeExecution.id == current_iteration.node_execution_id
+ ).first()
+
+ workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
+ workflow_node_execution.outputs = json.dumps(event.outputs) if event.outputs else None
+ workflow_node_execution.elapsed_time = time.perf_counter() - current_iteration.started_at
+
+ original_node_execution_metadata = workflow_node_execution.execution_metadata_dict
+ if original_node_execution_metadata:
+ original_node_execution_metadata['steps_boundary'] = current_iteration.iteration_steps_boundary
+ original_node_execution_metadata['total_tokens'] = current_iteration.total_tokens
+ workflow_node_execution.execution_metadata = json.dumps(original_node_execution_metadata)
+
+ db.session.commit()
+
+ # remove current iteration
+ self._iteration_state.current_iterations.pop(event.node_id, None)
+
+ # set latest node execution info
+ latest_node_execution_info = NodeExecutionInfo(
+ workflow_node_execution_id=workflow_node_execution.id,
+ node_type=NodeType.ITERATION,
+ start_at=time.perf_counter()
+ )
+
+ self._task_state.latest_node_execution_info = latest_node_execution_info
+
+ db.session.close()
+
+ def _handle_iteration_exception(self, task_id: str, error: str) -> Generator[IterationNodeCompletedStreamResponse, None, None]:
+ """
+ Handle iteration exception
+ """
+ if not self._iteration_state or not self._iteration_state.current_iterations:
+ return
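+ # fail every in-flight iteration and emit a terminal event carrying the error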
+
+ for node_id, current_iteration in self._iteration_state.current_iterations.items():
+ workflow_node_execution: WorkflowNodeExecution = db.session.query(WorkflowNodeExecution).filter(
+ WorkflowNodeExecution.id == current_iteration.node_execution_id
+ ).first()
+
+ workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
+ workflow_node_execution.error = error
+ workflow_node_execution.elapsed_time = time.perf_counter() - current_iteration.started_at
+
+ db.session.commit()
+ db.session.close()
+
+ yield IterationNodeCompletedStreamResponse(
+ task_id=task_id,
+ workflow_run_id=self._task_state.workflow_run_id,
+ data=IterationNodeCompletedStreamResponse.Data(
+ id=node_id,
+ node_id=node_id,
+ node_type=NodeType.ITERATION.value,
+ title=current_iteration.node_data.title,
+ outputs={},
+ created_at=int(time.time()),
+ extras={},
+ inputs=current_iteration.inputs,
+ status=WorkflowNodeExecutionStatus.FAILED,
+ error=error,
+ elapsed_time=time.perf_counter() - current_iteration.started_at,
+ total_tokens=current_iteration.total_tokens,
+ finished_at=int(time.time()),
+ steps=current_iteration.current_index
+ )
+ )
\ No newline at end of file
diff --git a/api/core/entities/model_entities.py b/api/core/entities/model_entities.py
index 05719e5b8d..9a797c1c95 100644
--- a/api/core/entities/model_entities.py
+++ b/api/core/entities/model_entities.py
@@ -16,6 +16,7 @@ class ModelStatus(Enum):
NO_CONFIGURE = "no-configure"
QUOTA_EXCEEDED = "quota-exceeded"
NO_PERMISSION = "no-permission"
+ DISABLED = "disabled"
class SimpleModelProviderEntity(BaseModel):
@@ -43,12 +44,19 @@ class SimpleModelProviderEntity(BaseModel):
)
-class ModelWithProviderEntity(ProviderModel):
+class ProviderModelWithStatusEntity(ProviderModel):
+ """
+ Model class for a provider model with status.
+ """
+ status: ModelStatus
+ load_balancing_enabled: bool = False
+
+
+class ModelWithProviderEntity(ProviderModelWithStatusEntity):
"""
Model with provider entity.
"""
provider: SimpleModelProviderEntity
- status: ModelStatus
class DefaultModelProviderEntity(BaseModel):
diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py
index 303034693d..ec1b2d0d48 100644
--- a/api/core/entities/provider_configuration.py
+++ b/api/core/entities/provider_configuration.py
@@ -1,6 +1,7 @@
import datetime
import json
import logging
+from collections import defaultdict
from collections.abc import Iterator
from json import JSONDecodeError
from typing import Optional
@@ -8,7 +9,12 @@ from typing import Optional
from pydantic import BaseModel
from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
-from core.entities.provider_entities import CustomConfiguration, SystemConfiguration, SystemConfigurationStatus
+from core.entities.provider_entities import (
+ CustomConfiguration,
+ ModelSettings,
+ SystemConfiguration,
+ SystemConfigurationStatus,
+)
from core.helper import encrypter
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
from core.model_runtime.entities.model_entities import FetchFrom, ModelType
@@ -22,7 +28,14 @@ from core.model_runtime.model_providers import model_provider_factory
from core.model_runtime.model_providers.__base.ai_model import AIModel
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
from extensions.ext_database import db
-from models.provider import Provider, ProviderModel, ProviderType, TenantPreferredModelProvider
+from models.provider import (
+ LoadBalancingModelConfig,
+ Provider,
+ ProviderModel,
+ ProviderModelSetting,
+ ProviderType,
+ TenantPreferredModelProvider,
+)
logger = logging.getLogger(__name__)
@@ -39,6 +52,7 @@ class ProviderConfiguration(BaseModel):
using_provider_type: ProviderType
system_configuration: SystemConfiguration
custom_configuration: CustomConfiguration
+ model_settings: list[ModelSettings]
def __init__(self, **data):
super().__init__(**data)
@@ -62,6 +76,14 @@ class ProviderConfiguration(BaseModel):
:param model: model name
:return:
"""
+ if self.model_settings:
+ # check if model is disabled by admin
+ for model_setting in self.model_settings:
+ if (model_setting.model_type == model_type
+ and model_setting.model == model):
+ if not model_setting.enabled:
+ raise ValueError(f'Model {model} is disabled.')
+
if self.using_provider_type == ProviderType.SYSTEM:
restrict_models = []
for quota_configuration in self.system_configuration.quota_configurations:
@@ -80,15 +102,17 @@ class ProviderConfiguration(BaseModel):
return copy_credentials
else:
+ credentials = None
if self.custom_configuration.models:
for model_configuration in self.custom_configuration.models:
if model_configuration.model_type == model_type and model_configuration.model == model:
- return model_configuration.credentials
+ credentials = model_configuration.credentials
+ break
if self.custom_configuration.provider:
- return self.custom_configuration.provider.credentials
- else:
- return None
+ credentials = self.custom_configuration.provider.credentials
+
+ return credentials
def get_system_configuration_status(self) -> SystemConfigurationStatus:
"""
@@ -130,7 +154,7 @@ class ProviderConfiguration(BaseModel):
return credentials
# Obfuscate credentials
- return self._obfuscated_credentials(
+ return self.obfuscated_credentials(
credentials=credentials,
credential_form_schemas=self.provider.provider_credential_schema.credential_form_schemas
if self.provider.provider_credential_schema else []
@@ -151,7 +175,7 @@ class ProviderConfiguration(BaseModel):
).first()
# Get provider credential secret variables
- provider_credential_secret_variables = self._extract_secret_variables(
+ provider_credential_secret_variables = self.extract_secret_variables(
self.provider.provider_credential_schema.credential_form_schemas
if self.provider.provider_credential_schema else []
)
@@ -274,7 +298,7 @@ class ProviderConfiguration(BaseModel):
return credentials
# Obfuscate credentials
- return self._obfuscated_credentials(
+ return self.obfuscated_credentials(
credentials=credentials,
credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas
if self.provider.model_credential_schema else []
@@ -302,7 +326,7 @@ class ProviderConfiguration(BaseModel):
).first()
# Get provider credential secret variables
- provider_credential_secret_variables = self._extract_secret_variables(
+ provider_credential_secret_variables = self.extract_secret_variables(
self.provider.model_credential_schema.credential_form_schemas
if self.provider.model_credential_schema else []
)
@@ -402,6 +426,160 @@ class ProviderConfiguration(BaseModel):
provider_model_credentials_cache.delete()
+ def enable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting:
+ """
+ Enable model.
+ :param model_type: model type
+ :param model: model name
+ :return:
+ """
+ model_setting = db.session.query(ProviderModelSetting) \
+ .filter(
+ ProviderModelSetting.tenant_id == self.tenant_id,
+ ProviderModelSetting.provider_name == self.provider.provider,
+ ProviderModelSetting.model_type == model_type.to_origin_model_type(),
+ ProviderModelSetting.model_name == model
+ ).first()
+
+ if model_setting:
+ model_setting.enabled = True
+ model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+ db.session.commit()
+ else:
+ model_setting = ProviderModelSetting(
+ tenant_id=self.tenant_id,
+ provider_name=self.provider.provider,
+ model_type=model_type.to_origin_model_type(),
+ model_name=model,
+ enabled=True
+ )
+ db.session.add(model_setting)
+ db.session.commit()
+
+ return model_setting
+
+ def disable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting:
+ """
+ Disable model.
+ :param model_type: model type
+ :param model: model name
+ :return:
+ """
+ model_setting = db.session.query(ProviderModelSetting) \
+ .filter(
+ ProviderModelSetting.tenant_id == self.tenant_id,
+ ProviderModelSetting.provider_name == self.provider.provider,
+ ProviderModelSetting.model_type == model_type.to_origin_model_type(),
+ ProviderModelSetting.model_name == model
+ ).first()
+
+ if model_setting:
+ model_setting.enabled = False
+ model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+ db.session.commit()
+ else:
+ model_setting = ProviderModelSetting(
+ tenant_id=self.tenant_id,
+ provider_name=self.provider.provider,
+ model_type=model_type.to_origin_model_type(),
+ model_name=model,
+ enabled=False
+ )
+ db.session.add(model_setting)
+ db.session.commit()
+
+ return model_setting
+
+ def get_provider_model_setting(self, model_type: ModelType, model: str) -> Optional[ProviderModelSetting]:
+ """
+ Get provider model setting.
+ :param model_type: model type
+ :param model: model name
+ :return:
+ """
+ return db.session.query(ProviderModelSetting) \
+ .filter(
+ ProviderModelSetting.tenant_id == self.tenant_id,
+ ProviderModelSetting.provider_name == self.provider.provider,
+ ProviderModelSetting.model_type == model_type.to_origin_model_type(),
+ ProviderModelSetting.model_name == model
+ ).first()
+
+ def enable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting:
+ """
+ Enable model load balancing.
+ :param model_type: model type
+ :param model: model name
+ :return:
+ """
+ load_balancing_config_count = db.session.query(LoadBalancingModelConfig) \
+ .filter(
+ LoadBalancingModelConfig.tenant_id == self.tenant_id,
+ LoadBalancingModelConfig.provider_name == self.provider.provider,
+ LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
+ LoadBalancingModelConfig.model_name == model
+ ).count()
+
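+ # load balancing needs at least two credential configurations to be meaningful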
+ if load_balancing_config_count <= 1:
+ raise ValueError('Model load balancing requires more than one configuration.')
+
+ model_setting = db.session.query(ProviderModelSetting) \
+ .filter(
+ ProviderModelSetting.tenant_id == self.tenant_id,
+ ProviderModelSetting.provider_name == self.provider.provider,
+ ProviderModelSetting.model_type == model_type.to_origin_model_type(),
+ ProviderModelSetting.model_name == model
+ ).first()
+
+ if model_setting:
+ model_setting.load_balancing_enabled = True
+ model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+ db.session.commit()
+ else:
+ model_setting = ProviderModelSetting(
+ tenant_id=self.tenant_id,
+ provider_name=self.provider.provider,
+ model_type=model_type.to_origin_model_type(),
+ model_name=model,
+ load_balancing_enabled=True
+ )
+ db.session.add(model_setting)
+ db.session.commit()
+
+ return model_setting
+
+ def disable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting:
+ """
+ Disable model load balancing.
+ :param model_type: model type
+ :param model: model name
+ :return:
+ """
+ model_setting = db.session.query(ProviderModelSetting) \
+ .filter(
+ ProviderModelSetting.tenant_id == self.tenant_id,
+ ProviderModelSetting.provider_name == self.provider.provider,
+ ProviderModelSetting.model_type == model_type.to_origin_model_type(),
+ ProviderModelSetting.model_name == model
+ ).first()
+
+ if model_setting:
+ model_setting.load_balancing_enabled = False
+ model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+ db.session.commit()
+ else:
+ model_setting = ProviderModelSetting(
+ tenant_id=self.tenant_id,
+ provider_name=self.provider.provider,
+ model_type=model_type.to_origin_model_type(),
+ model_name=model,
+ load_balancing_enabled=False
+ )
+ db.session.add(model_setting)
+ db.session.commit()
+
+ return model_setting
+
def get_provider_instance(self) -> ModelProvider:
"""
Get provider instance.
@@ -453,7 +631,7 @@ class ProviderConfiguration(BaseModel):
db.session.commit()
- def _extract_secret_variables(self, credential_form_schemas: list[CredentialFormSchema]) -> list[str]:
+ def extract_secret_variables(self, credential_form_schemas: list[CredentialFormSchema]) -> list[str]:
"""
Extract secret input form variables.
@@ -467,7 +645,7 @@ class ProviderConfiguration(BaseModel):
return secret_input_form_variables
- def _obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]) -> dict:
+ def obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]) -> dict:
"""
Obfuscated credentials.
@@ -476,7 +654,7 @@ class ProviderConfiguration(BaseModel):
:return:
"""
# Get provider credential secret variables
- credential_secret_variables = self._extract_secret_variables(
+ credential_secret_variables = self.extract_secret_variables(
credential_form_schemas
)
@@ -522,15 +700,22 @@ class ProviderConfiguration(BaseModel):
else:
model_types = provider_instance.get_provider_schema().supported_model_types
+ # Group model settings by model type and model
+ model_setting_map = defaultdict(dict)
+ for model_setting in self.model_settings:
+ model_setting_map[model_setting.model_type][model_setting.model] = model_setting
+
if self.using_provider_type == ProviderType.SYSTEM:
provider_models = self._get_system_provider_models(
model_types=model_types,
- provider_instance=provider_instance
+ provider_instance=provider_instance,
+ model_setting_map=model_setting_map
)
else:
provider_models = self._get_custom_provider_models(
model_types=model_types,
- provider_instance=provider_instance
+ provider_instance=provider_instance,
+ model_setting_map=model_setting_map
)
if only_active:
@@ -541,18 +726,27 @@ class ProviderConfiguration(BaseModel):
def _get_system_provider_models(self,
model_types: list[ModelType],
- provider_instance: ModelProvider) -> list[ModelWithProviderEntity]:
+ provider_instance: ModelProvider,
+ model_setting_map: dict[ModelType, dict[str, ModelSettings]]) \
+ -> list[ModelWithProviderEntity]:
"""
Get system provider models.
:param model_types: model types
:param provider_instance: provider instance
+ :param model_setting_map: model setting map
:return:
"""
provider_models = []
for model_type in model_types:
- provider_models.extend(
- [
+ for m in provider_instance.models(model_type):
+ status = ModelStatus.ACTIVE
+ if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]:
+ model_setting = model_setting_map[m.model_type][m.model]
+ if model_setting.enabled is False:
+ status = ModelStatus.DISABLED
+
+ provider_models.append(
ModelWithProviderEntity(
model=m.model,
label=m.label,
@@ -562,11 +756,9 @@ class ProviderConfiguration(BaseModel):
model_properties=m.model_properties,
deprecated=m.deprecated,
provider=SimpleModelProviderEntity(self.provider),
- status=ModelStatus.ACTIVE
+ status=status
)
- for m in provider_instance.models(model_type)
- ]
- )
+ )
if self.provider.provider not in original_provider_configurate_methods:
original_provider_configurate_methods[self.provider.provider] = []
@@ -586,7 +778,8 @@ class ProviderConfiguration(BaseModel):
break
if should_use_custom_model:
- if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
+ if original_provider_configurate_methods[self.provider.provider] == [
+ ConfigurateMethod.CUSTOMIZABLE_MODEL]:
# only customizable model
for restrict_model in restrict_models:
copy_credentials = self.system_configuration.credentials.copy()
@@ -611,6 +804,13 @@ class ProviderConfiguration(BaseModel):
if custom_model_schema.model_type not in model_types:
continue
+ status = ModelStatus.ACTIVE
+ if (custom_model_schema.model_type in model_setting_map
+ and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]):
+ model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model]
+ if model_setting.enabled is False:
+ status = ModelStatus.DISABLED
+
provider_models.append(
ModelWithProviderEntity(
model=custom_model_schema.model,
@@ -621,7 +821,7 @@ class ProviderConfiguration(BaseModel):
model_properties=custom_model_schema.model_properties,
deprecated=custom_model_schema.deprecated,
provider=SimpleModelProviderEntity(self.provider),
- status=ModelStatus.ACTIVE
+ status=status
)
)
@@ -632,16 +832,20 @@ class ProviderConfiguration(BaseModel):
m.status = ModelStatus.NO_PERMISSION
elif not quota_configuration.is_valid:
m.status = ModelStatus.QUOTA_EXCEEDED
+
return provider_models
def _get_custom_provider_models(self,
model_types: list[ModelType],
- provider_instance: ModelProvider) -> list[ModelWithProviderEntity]:
+ provider_instance: ModelProvider,
+ model_setting_map: dict[ModelType, dict[str, ModelSettings]]) \
+ -> list[ModelWithProviderEntity]:
"""
Get custom provider models.
:param model_types: model types
:param provider_instance: provider instance
+ :param model_setting_map: model setting map
:return:
"""
provider_models = []
@@ -656,6 +860,16 @@ class ProviderConfiguration(BaseModel):
models = provider_instance.models(model_type)
for m in models:
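+ # a model is active only if credentials exist and an admin has not disabled it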
+ status = ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE
+ load_balancing_enabled = False
+ if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]:
+ model_setting = model_setting_map[m.model_type][m.model]
+ if model_setting.enabled is False:
+ status = ModelStatus.DISABLED
+
+ if len(model_setting.load_balancing_configs) > 1:
+ load_balancing_enabled = True
+
provider_models.append(
ModelWithProviderEntity(
model=m.model,
@@ -666,7 +880,8 @@ class ProviderConfiguration(BaseModel):
model_properties=m.model_properties,
deprecated=m.deprecated,
provider=SimpleModelProviderEntity(self.provider),
- status=ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE
+ status=status,
+ load_balancing_enabled=load_balancing_enabled
)
)
@@ -690,6 +905,17 @@ class ProviderConfiguration(BaseModel):
if not custom_model_schema:
continue
+ status = ModelStatus.ACTIVE
+ load_balancing_enabled = False
+ if (custom_model_schema.model_type in model_setting_map
+ and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]):
+ model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model]
+ if model_setting.enabled is False:
+ status = ModelStatus.DISABLED
+
+ if len(model_setting.load_balancing_configs) > 1:
+ load_balancing_enabled = True
+
provider_models.append(
ModelWithProviderEntity(
model=custom_model_schema.model,
@@ -700,7 +926,8 @@ class ProviderConfiguration(BaseModel):
model_properties=custom_model_schema.model_properties,
deprecated=custom_model_schema.deprecated,
provider=SimpleModelProviderEntity(self.provider),
- status=ModelStatus.ACTIVE
+ status=status,
+ load_balancing_enabled=load_balancing_enabled
)
)
diff --git a/api/core/entities/provider_entities.py b/api/core/entities/provider_entities.py
index 114dfaf911..1eaa6ea02c 100644
--- a/api/core/entities/provider_entities.py
+++ b/api/core/entities/provider_entities.py
@@ -72,3 +72,22 @@ class CustomConfiguration(BaseModel):
"""
provider: Optional[CustomProviderConfiguration] = None
models: list[CustomModelConfiguration] = []
+
+
+class ModelLoadBalancingConfiguration(BaseModel):
+ """
+ Class for model load balancing configuration.
+ """
+ id: str
+ name: str
+ credentials: dict
+
+
+class ModelSettings(BaseModel):
+ """
+ Model class for model settings.
+ """
+ model: str
+ model_type: ModelType
+ enabled: bool = True
+ load_balancing_configs: list[ModelLoadBalancingConfiguration] = []
diff --git a/api/core/extension/extensible.py b/api/core/extension/extensible.py
index d2ec555d6c..3a37c6492e 100644
--- a/api/core/extension/extensible.py
+++ b/api/core/extension/extensible.py
@@ -7,7 +7,7 @@ from typing import Any, Optional
from pydantic import BaseModel
-from core.utils.position_helper import sort_to_dict_by_position_map
+from core.helper.position_helper import sort_to_dict_by_position_map
class ExtensionModule(enum.Enum):
diff --git a/api/core/file/message_file_parser.py b/api/core/file/message_file_parser.py
index 06f21c880a..52498eb871 100644
--- a/api/core/file/message_file_parser.py
+++ b/api/core/file/message_file_parser.py
@@ -42,6 +42,8 @@ class MessageFileParser:
raise ValueError('Invalid file url')
if file.get('transfer_method') == FileTransferMethod.LOCAL_FILE.value and not file.get('upload_file_id'):
raise ValueError('Missing file upload_file_id')
+ if file.get('transfer_method') == FileTransferMethod.TOOL_FILE.value and not file.get('tool_file_id'):
+ raise ValueError('Missing file tool_file_id')
# transform files to file objs
type_file_objs = self._to_file_objs(files, file_extra_config)
@@ -149,12 +151,21 @@ class MessageFileParser:
"""
if isinstance(file, dict):
transfer_method = FileTransferMethod.value_of(file.get('transfer_method'))
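+ # tool files carry a tool_file_id instead of a URL or an upload_file_id, so they are mapped separately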
+ if transfer_method != FileTransferMethod.TOOL_FILE:
+ return FileVar(
+ tenant_id=self.tenant_id,
+ type=FileType.value_of(file.get('type')),
+ transfer_method=transfer_method,
+ url=file.get('url') if transfer_method == FileTransferMethod.REMOTE_URL else None,
+ related_id=file.get('upload_file_id') if transfer_method == FileTransferMethod.LOCAL_FILE else None,
+ extra_config=file_extra_config
+ )
return FileVar(
tenant_id=self.tenant_id,
type=FileType.value_of(file.get('type')),
transfer_method=transfer_method,
- url=file.get('url') if transfer_method == FileTransferMethod.REMOTE_URL else None,
- related_id=file.get('upload_file_id') if transfer_method == FileTransferMethod.LOCAL_FILE else None,
+ url=None,
+ related_id=file.get('tool_file_id'),
extra_config=file_extra_config
)
else:
diff --git a/api/core/file/upload_file_parser.py b/api/core/file/upload_file_parser.py
index 974fde178b..9e454f08d4 100644
--- a/api/core/file/upload_file_parser.py
+++ b/api/core/file/upload_file_parser.py
@@ -77,4 +77,4 @@ class UploadFileParser:
return False
current_time = int(time.time())
- return current_time - int(timestamp) <= 300 # expired after 5 minutes
+ return current_time - int(timestamp) <= current_app.config.get('FILES_ACCESS_TIMEOUT')
diff --git a/api/controllers/console/enterprise/__init__.py b/api/core/helper/code_executor/__init__.py
similarity index 100%
rename from api/controllers/console/enterprise/__init__.py
rename to api/core/helper/code_executor/__init__.py
diff --git a/api/core/helper/code_executor/code_executor.py b/api/core/helper/code_executor/code_executor.py
index 063a21b192..37573c88ce 100644
--- a/api/core/helper/code_executor/code_executor.py
+++ b/api/core/helper/code_executor/code_executor.py
@@ -1,13 +1,21 @@
+import logging
+import time
+from enum import Enum
+from threading import Lock
from typing import Literal, Optional
-from httpx import post
+from httpx import get, post
from pydantic import BaseModel
from yarl import URL
from config import get_env
-from core.helper.code_executor.javascript_transformer import NodeJsTemplateTransformer
-from core.helper.code_executor.jinja2_transformer import Jinja2TemplateTransformer
-from core.helper.code_executor.python_transformer import PythonTemplateTransformer
+from core.helper.code_executor.entities import CodeDependency
+from core.helper.code_executor.javascript.javascript_transformer import NodeJsTemplateTransformer
+from core.helper.code_executor.jinja2.jinja2_transformer import Jinja2TemplateTransformer
+from core.helper.code_executor.python3.python3_transformer import Python3TemplateTransformer
+from core.helper.code_executor.template_transformer import TemplateTransformer
+
+logger = logging.getLogger(__name__)
# Code Executor
CODE_EXECUTION_ENDPOINT = get_env('CODE_EXECUTION_ENDPOINT')
@@ -28,9 +36,38 @@ class CodeExecutionResponse(BaseModel):
data: Data
+class CodeLanguage(str, Enum):
+ PYTHON3 = 'python3'
+ JINJA2 = 'jinja2'
+ JAVASCRIPT = 'javascript'
+
+
class CodeExecutor:
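+ # cache of sandbox dependency listings shared across threads; entries expire after 60 seconds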
+ dependencies_cache = {}
+ dependencies_cache_lock = Lock()
+
+ code_template_transformers: dict[CodeLanguage, type[TemplateTransformer]] = {
+ CodeLanguage.PYTHON3: Python3TemplateTransformer,
+ CodeLanguage.JINJA2: Jinja2TemplateTransformer,
+ CodeLanguage.JAVASCRIPT: NodeJsTemplateTransformer,
+ }
+
+ code_language_to_running_language = {
+ CodeLanguage.JAVASCRIPT: 'nodejs',
+ CodeLanguage.JINJA2: CodeLanguage.PYTHON3,
+ CodeLanguage.PYTHON3: CodeLanguage.PYTHON3,
+ }
+
+ supported_dependencies_languages: set[CodeLanguage] = {
+ CodeLanguage.PYTHON3
+ }
+
@classmethod
- def execute_code(cls, language: Literal['python3', 'javascript', 'jinja2'], preload: str, code: str) -> str:
+ def execute_code(cls,
+ language: Literal['python3', 'javascript', 'jinja2'],
+ preload: str,
+ code: str,
+ dependencies: Optional[list[CodeDependency]] = None) -> str:
"""
Execute code
:param language: code language
@@ -44,13 +81,15 @@ class CodeExecutor:
}
data = {
- 'language': 'python3' if language == 'jinja2' else
- 'nodejs' if language == 'javascript' else
- 'python3' if language == 'python3' else None,
+ 'language': cls.code_language_to_running_language.get(language),
'code': code,
- 'preload': preload
+ 'preload': preload,
+ 'enable_network': True
}
+ if dependencies:
+ data['dependencies'] = [dependency.dict() for dependency in dependencies]
+
try:
response = post(str(url), json=data, headers=headers, timeout=CODE_EXECUTION_TIMEOUT)
if response.status_code == 503:
@@ -60,7 +99,9 @@ class CodeExecutor:
except CodeExecutionException as e:
raise e
except Exception as e:
- raise CodeExecutionException('Failed to execute code, this is likely a network issue, please check if the sandbox service is running')
+ raise CodeExecutionException('Failed to execute code; this is likely a network issue,'
+ ' please check if the sandbox service is running.'
+ f' (Error: {str(e)})')
try:
response = response.json()
@@ -78,7 +119,7 @@ class CodeExecutor:
return response.data.stdout
@classmethod
- def execute_workflow_code_template(cls, language: Literal['python3', 'javascript', 'jinja2'], code: str, inputs: dict) -> dict:
+ def execute_workflow_code_template(cls, language: Literal['python3', 'javascript', 'jinja2'], code: str, inputs: dict, dependencies: Optional[list[CodeDependency]] = None) -> dict:
"""
Execute code
:param language: code language
@@ -86,21 +127,71 @@ class CodeExecutor:
:param inputs: inputs
:return:
"""
- template_transformer = None
- if language == 'python3':
- template_transformer = PythonTemplateTransformer
- elif language == 'jinja2':
- template_transformer = Jinja2TemplateTransformer
- elif language == 'javascript':
- template_transformer = NodeJsTemplateTransformer
- else:
- raise CodeExecutionException('Unsupported language')
+ template_transformer = cls.code_template_transformers.get(language)
+ if not template_transformer:
+ raise CodeExecutionException(f'Unsupported language {language}')
- runner, preload = template_transformer.transform_caller(code, inputs)
+ runner, preload, dependencies = template_transformer.transform_caller(code, inputs, dependencies)
try:
- response = cls.execute_code(language, preload, runner)
+ response = cls.execute_code(language, preload, runner, dependencies)
except CodeExecutionException as e:
raise e
- return template_transformer.transform_response(response)
\ No newline at end of file
+ return template_transformer.transform_response(response)
+
+ @classmethod
+ def list_dependencies(cls, language: str) -> list[CodeDependency]:
+ if language not in cls.supported_dependencies_languages:
+ return []
+
+ with cls.dependencies_cache_lock:
+ if language in cls.dependencies_cache:
+ # check expiration
+ dependencies = cls.dependencies_cache[language]
+ if dependencies['expiration'] > time.time():
+ return dependencies['data']
+ # remove expired cache
+ del cls.dependencies_cache[language]
+
+ dependencies = cls._get_dependencies(language)
+ with cls.dependencies_cache_lock:
+ cls.dependencies_cache[language] = {
+ 'data': dependencies,
+ 'expiration': time.time() + 60
+ }
+
+ return dependencies
+
+ @classmethod
+ def _get_dependencies(cls, language: Literal['python3']) -> list[CodeDependency]:
+ """
+ List dependencies
+ """
+ url = URL(CODE_EXECUTION_ENDPOINT) / 'v1' / 'sandbox' / 'dependencies'
+
+ headers = {
+ 'X-Api-Key': CODE_EXECUTION_API_KEY
+ }
+
+ running_language = cls.code_language_to_running_language.get(language)
+ if isinstance(running_language, Enum):
+ running_language = running_language.value
+
+ data = {
+ 'language': running_language,
+ }
+
+ try:
+ response = get(str(url), params=data, headers=headers, timeout=CODE_EXECUTION_TIMEOUT)
+ if response.status_code != 200:
+ raise Exception(f'Failed to list dependencies: got status code {response.status_code}; please check if the sandbox service is running')
+ response = response.json()
+ dependencies = response.get('data', {}).get('dependencies', [])
+ return [
+ CodeDependency(**dependency) for dependency in dependencies
+ if dependency.get('name') not in Python3TemplateTransformer.get_standard_packages()
+ ]
+ except Exception as e:
+ logger.exception(f'Failed to list dependencies: {e}')
+ return []
\ No newline at end of file
diff --git a/api/core/helper/code_executor/code_node_provider.py b/api/core/helper/code_executor/code_node_provider.py
new file mode 100644
index 0000000000..b76e15eeab
--- /dev/null
+++ b/api/core/helper/code_executor/code_node_provider.py
@@ -0,0 +1,55 @@
+from abc import abstractmethod
+
+from pydantic import BaseModel
+
+from core.helper.code_executor.code_executor import CodeExecutor
+
+
+class CodeNodeProvider(BaseModel):
+ @staticmethod
+ @abstractmethod
+ def get_language() -> str:
+ pass
+
+ @classmethod
+ def is_accept_language(cls, language: str) -> bool:
+ return language == cls.get_language()
+
+ @classmethod
+ @abstractmethod
+ def get_default_code(cls) -> str:
+ """
+ get default code in specific programming language for the code node
+ """
+ pass
+
+ @classmethod
+ def get_default_available_packages(cls) -> list[dict]:
+ return [p.dict() for p in CodeExecutor.list_dependencies(cls.get_language())]
+
+ @classmethod
+ def get_default_config(cls) -> dict:
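+ # default code-node config: two input variables, the language's sample code, and a single string output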
+ return {
+ "type": "code",
+ "config": {
+ "variables": [
+ {
+ "variable": "arg1",
+ "value_selector": []
+ },
+ {
+ "variable": "arg2",
+ "value_selector": []
+ }
+ ],
+ "code_language": cls.get_language(),
+ "code": cls.get_default_code(),
+ "outputs": {
+ "result": {
+ "type": "string",
+ "children": None
+ }
+ }
+ },
+ "available_dependencies": cls.get_default_available_packages(),
+ }
diff --git a/api/core/helper/code_executor/entities.py b/api/core/helper/code_executor/entities.py
new file mode 100644
index 0000000000..cc10288521
--- /dev/null
+++ b/api/core/helper/code_executor/entities.py
@@ -0,0 +1,6 @@
+from pydantic import BaseModel
+
+
+class CodeDependency(BaseModel):
+ name: str
+ version: str
diff --git a/api/core/rerank/__init__.py b/api/core/helper/code_executor/javascript/__init__.py
similarity index 100%
rename from api/core/rerank/__init__.py
rename to api/core/helper/code_executor/javascript/__init__.py
diff --git a/api/core/helper/code_executor/javascript/javascript_code_provider.py b/api/core/helper/code_executor/javascript/javascript_code_provider.py
new file mode 100644
index 0000000000..a157fcc6d1
--- /dev/null
+++ b/api/core/helper/code_executor/javascript/javascript_code_provider.py
@@ -0,0 +1,21 @@
+from textwrap import dedent
+
+from core.helper.code_executor.code_executor import CodeLanguage
+from core.helper.code_executor.code_node_provider import CodeNodeProvider
+
+
+class JavascriptCodeProvider(CodeNodeProvider):
+ @staticmethod
+ def get_language() -> str:
+ return CodeLanguage.JAVASCRIPT
+
+ @classmethod
+ def get_default_code(cls) -> str:
+ return dedent(
+ """
+ function main({arg1, arg2}) {
+ return {
+ result: arg1 + arg2
+ }
+ }
+ """)
diff --git a/api/core/helper/code_executor/javascript/javascript_transformer.py b/api/core/helper/code_executor/javascript/javascript_transformer.py
new file mode 100644
index 0000000000..a4d2551972
--- /dev/null
+++ b/api/core/helper/code_executor/javascript/javascript_transformer.py
@@ -0,0 +1,25 @@
+from textwrap import dedent
+
+from core.helper.code_executor.template_transformer import TemplateTransformer
+
+
+class NodeJsTemplateTransformer(TemplateTransformer):
+ @classmethod
+ def get_runner_script(cls) -> str:
+ runner_script = dedent(
+ f"""
+ // declare main function
+ {cls._code_placeholder}
+
+ // decode and prepare input object
+ var inputs_obj = JSON.parse(Buffer.from('{cls._inputs_placeholder}', 'base64').toString('utf-8'))
+
+ // execute main function
+ var output_obj = main(inputs_obj)
+
+ // convert output to json and print
+ var output_json = JSON.stringify(output_obj)
+ var result = `<>${{output_json}}<>`
+ console.log(result)
+ """)
+ return runner_script
diff --git a/api/core/helper/code_executor/javascript_transformer.py b/api/core/helper/code_executor/javascript_transformer.py
deleted file mode 100644
index 29b8e06e86..0000000000
--- a/api/core/helper/code_executor/javascript_transformer.py
+++ /dev/null
@@ -1,54 +0,0 @@
-import json
-import re
-
-from core.helper.code_executor.template_transformer import TemplateTransformer
-
-NODEJS_RUNNER = """// declare main function here
-{{code}}
-
-// execute main function, and return the result
-// inputs is a dict, unstructured inputs
-output = main({{inputs}})
-
-// convert output to json and print
-output = JSON.stringify(output)
-
-result = `<>${output}<>`
-
-console.log(result)
-"""
-
-NODEJS_PRELOAD = """"""
-
-class NodeJsTemplateTransformer(TemplateTransformer):
- @classmethod
- def transform_caller(cls, code: str, inputs: dict) -> tuple[str, str]:
- """
- Transform code to python runner
- :param code: code
- :param inputs: inputs
- :return:
- """
-
- # transform inputs to json string
- inputs_str = json.dumps(inputs, indent=4, ensure_ascii=False)
-
- # replace code and inputs
- runner = NODEJS_RUNNER.replace('{{code}}', code)
- runner = runner.replace('{{inputs}}', inputs_str)
-
- return runner, NODEJS_PRELOAD
-
- @classmethod
- def transform_response(cls, response: str) -> dict:
- """
- Transform response to dict
- :param response: response
- :return:
- """
- # extract result
- result = re.search(r'<>(.*)<>', response, re.DOTALL)
- if not result:
- raise ValueError('Failed to parse result')
- result = result.group(1)
- return json.loads(result)
diff --git a/api/core/workflow/nodes/variable_assigner/__init__.py b/api/core/helper/code_executor/jinja2/__init__.py
similarity index 100%
rename from api/core/workflow/nodes/variable_assigner/__init__.py
rename to api/core/helper/code_executor/jinja2/__init__.py
diff --git a/api/core/helper/code_executor/jinja2/jinja2_formatter.py b/api/core/helper/code_executor/jinja2/jinja2_formatter.py
new file mode 100644
index 0000000000..63f48a56e2
--- /dev/null
+++ b/api/core/helper/code_executor/jinja2/jinja2_formatter.py
@@ -0,0 +1,17 @@
+from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage
+
+
+class Jinja2Formatter:
+ @classmethod
+ def format(cls, template: str, inputs: str) -> str:
+ """
+ Format template
+ :param template: template
+ :param inputs: inputs
+ :return:
+ """
+ result = CodeExecutor.execute_workflow_code_template(
+ language=CodeLanguage.JINJA2, code=template, inputs=inputs
+ )
+
+ return result['result']
diff --git a/api/core/helper/code_executor/jinja2/jinja2_transformer.py b/api/core/helper/code_executor/jinja2/jinja2_transformer.py
new file mode 100644
index 0000000000..a8f8095d52
--- /dev/null
+++ b/api/core/helper/code_executor/jinja2/jinja2_transformer.py
@@ -0,0 +1,64 @@
+from textwrap import dedent
+
+from core.helper.code_executor.python3.python3_transformer import Python3TemplateTransformer
+from core.helper.code_executor.template_transformer import TemplateTransformer
+
+
+class Jinja2TemplateTransformer(TemplateTransformer):
+ @classmethod
+ def get_standard_packages(cls) -> set[str]:
+ return {'jinja2'} | Python3TemplateTransformer.get_standard_packages()
+
+ @classmethod
+ def transform_response(cls, response: str) -> dict:
+ """
+ Transform response to dict
+ :param response: response
+ :return:
+ """
+ return {
+ 'result': cls.extract_result_str_from_response(response)
+ }
+
+ @classmethod
+ def get_runner_script(cls) -> str:
+ runner_script = dedent(f"""
+ # declare main function
+ def main(**inputs):
+ import jinja2
+ template = jinja2.Template('''{cls._code_placeholder}''')
+ return template.render(**inputs)
+
+ import json
+ from base64 import b64decode
+
+ # decode and prepare input dict
+ inputs_obj = json.loads(b64decode('{cls._inputs_placeholder}').decode('utf-8'))
+
+ # execute main function
+ output = main(**inputs_obj)
+
+ # convert output and print
+ result = f'''<>{{output}}<>'''
+ print(result)
+
+ """)
+ return runner_script
+
+ @classmethod
+ def get_preload_script(cls) -> str:
+ preload_script = dedent("""
+ import jinja2
+ from base64 import b64decode
+
+ def _jinja2_preload_():
+            # prepare the jinja2 environment: load and render a template ahead of time to avoid sandbox issues
+ template = jinja2.Template('{{s}}')
+ template.render(s='a')
+
+ if __name__ == '__main__':
+ _jinja2_preload_()
+
+ """)
+
+ return preload_script
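
Note: a minimal local sketch (assuming the jinja2 package is installed) of what the generated Jinja2 runner executes inside the sandbox: the user template is wrapped in main(**inputs), the base64/JSON inputs are decoded, and the rendered text is printed between the '<>' result tags.

    import json
    from base64 import b64decode, b64encode

    import jinja2

    # host side: serialize_inputs() JSON-encodes then base64-encodes the inputs
    inputs_b64 = b64encode(json.dumps({'name': 'Dify'}).encode()).decode('utf-8')

    # sandbox side: the runner wraps the template in main(**inputs)
    def main(**inputs):
        template = jinja2.Template('Hello, {{ name }}!')
        return template.render(**inputs)

    inputs_obj = json.loads(b64decode(inputs_b64).decode('utf-8'))
    print(f"<>{main(**inputs_obj)}<>")  # -> <>Hello, Dify!<>
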
diff --git a/api/core/helper/code_executor/jinja2_transformer.py b/api/core/helper/code_executor/jinja2_transformer.py
deleted file mode 100644
index 27a3579493..0000000000
--- a/api/core/helper/code_executor/jinja2_transformer.py
+++ /dev/null
@@ -1,92 +0,0 @@
-import json
-import re
-from base64 import b64encode
-
-from core.helper.code_executor.template_transformer import TemplateTransformer
-
-PYTHON_RUNNER = """
-import jinja2
-from json import loads
-from base64 import b64decode
-
-template = jinja2.Template('''{{code}}''')
-
-def main(**inputs):
- return template.render(**inputs)
-
-# execute main function, and return the result
-inputs = b64decode('{{inputs}}').decode('utf-8')
-output = main(**loads(inputs))
-
-result = f'''<>{output}<>'''
-
-print(result)
-
-"""
-
-JINJA2_PRELOAD_TEMPLATE = """{% set fruits = ['Apple'] %}
-{{ 'a' }}
-{% for fruit in fruits %}
-
-{{ fruit }}
-{% endfor %}
-{% if fruits|length > 1 %}
-1
-{% endif %}
-{% for i in range(5) %}
- {% if i == 3 %}{{ i }}{% else %}{% endif %}
-{% endfor %}
- {% for i in range(3) %}
- {{ i + 1 }}
- {% endfor %}
-{% macro say_hello() %}a{{ 'b' }}{% endmacro %}
-{{ s }}{{ say_hello() }}"""
-
-JINJA2_PRELOAD = f"""
-import jinja2
-from base64 import b64decode
-
-def _jinja2_preload_():
- # prepare jinja2 environment, load template and render before to avoid sandbox issue
- template = jinja2.Template('''{JINJA2_PRELOAD_TEMPLATE}''')
- template.render(s='a')
-
-if __name__ == '__main__':
- _jinja2_preload_()
-
-"""
-
-
-class Jinja2TemplateTransformer(TemplateTransformer):
- @classmethod
- def transform_caller(cls, code: str, inputs: dict) -> tuple[str, str]:
- """
- Transform code to python runner
- :param code: code
- :param inputs: inputs
- :return:
- """
-
- inputs_str = b64encode(json.dumps(inputs, ensure_ascii=False).encode()).decode('utf-8')
-
- # transform jinja2 template to python code
- runner = PYTHON_RUNNER.replace('{{code}}', code)
- runner = runner.replace('{{inputs}}', inputs_str)
-
- return runner, JINJA2_PRELOAD
-
- @classmethod
- def transform_response(cls, response: str) -> dict:
- """
- Transform response to dict
- :param response: response
- :return:
- """
- # extract result
- result = re.search(r'<>(.*)<>', response, re.DOTALL)
- if not result:
- raise ValueError('Failed to parse result')
- result = result.group(1)
-
- return {
- 'result': result
- }
diff --git a/api/core/application_manager.py b/api/core/helper/code_executor/python3/__init__.py
similarity index 100%
rename from api/core/application_manager.py
rename to api/core/helper/code_executor/python3/__init__.py
diff --git a/api/core/helper/code_executor/python3/python3_code_provider.py b/api/core/helper/code_executor/python3/python3_code_provider.py
new file mode 100644
index 0000000000..efcb8a9d1e
--- /dev/null
+++ b/api/core/helper/code_executor/python3/python3_code_provider.py
@@ -0,0 +1,20 @@
+from textwrap import dedent
+
+from core.helper.code_executor.code_executor import CodeLanguage
+from core.helper.code_executor.code_node_provider import CodeNodeProvider
+
+
+class Python3CodeProvider(CodeNodeProvider):
+ @staticmethod
+ def get_language() -> str:
+ return CodeLanguage.PYTHON3
+
+ @classmethod
+ def get_default_code(cls) -> str:
+ return dedent(
+ """
+ def main(arg1: int, arg2: int) -> dict:
+ return {
+ "result": arg1 + arg2,
+ }
+ """)
diff --git a/api/core/helper/code_executor/python3/python3_transformer.py b/api/core/helper/code_executor/python3/python3_transformer.py
new file mode 100644
index 0000000000..4a5fa35093
--- /dev/null
+++ b/api/core/helper/code_executor/python3/python3_transformer.py
@@ -0,0 +1,51 @@
+from textwrap import dedent
+
+from core.helper.code_executor.template_transformer import TemplateTransformer
+
+
+class Python3TemplateTransformer(TemplateTransformer):
+ @classmethod
+ def get_standard_packages(cls) -> set[str]:
+ return {
+ 'base64',
+ 'binascii',
+ 'collections',
+ 'datetime',
+ 'functools',
+ 'hashlib',
+ 'hmac',
+ 'itertools',
+ 'json',
+ 'math',
+ 'operator',
+ 'os',
+ 'random',
+ 're',
+ 'string',
+ 'sys',
+ 'time',
+ 'traceback',
+ 'uuid',
+ }
+
+ @classmethod
+ def get_runner_script(cls) -> str:
+ runner_script = dedent(f"""
+ # declare main function
+ {cls._code_placeholder}
+
+ import json
+ from base64 import b64decode
+
+ # decode and prepare input dict
+ inputs_obj = json.loads(b64decode('{cls._inputs_placeholder}').decode('utf-8'))
+
+ # execute main function
+ output_obj = main(**inputs_obj)
+
+ # convert output to json and print
+ output_json = json.dumps(output_obj, indent=4)
+ result = f'''<>{{output_json}}<>'''
+ print(result)
+ """)
+ return runner_script
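
Note: the inputs round-trip the Python3 runner relies on, as a self-contained sketch: the host base64-encodes a JSON dict (serialize_inputs in template_transformer.py below), and the generated script decodes it before calling main(**inputs).

    import json
    from base64 import b64decode, b64encode

    inputs = {'arg1': 1, 'arg2': 2}
    encoded = b64encode(json.dumps(inputs, ensure_ascii=False).encode()).decode('utf-8')

    def main(arg1: int, arg2: int) -> dict:  # the default code shown above
        return {'result': arg1 + arg2}

    decoded = json.loads(b64decode(encoded).decode('utf-8'))
    assert main(**decoded) == {'result': 3}
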
diff --git a/api/core/helper/code_executor/python_transformer.py b/api/core/helper/code_executor/python_transformer.py
deleted file mode 100644
index f44acbb9bf..0000000000
--- a/api/core/helper/code_executor/python_transformer.py
+++ /dev/null
@@ -1,82 +0,0 @@
-import json
-import re
-from base64 import b64encode
-
-from core.helper.code_executor.template_transformer import TemplateTransformer
-
-PYTHON_RUNNER = """# declare main function here
-{{code}}
-
-from json import loads, dumps
-from base64 import b64decode
-
-# execute main function, and return the result
-# inputs is a dict, and it
-inputs = b64decode('{{inputs}}').decode('utf-8')
-output = main(**json.loads(inputs))
-
-# convert output to json and print
-output = dumps(output, indent=4)
-
-result = f'''<>
-{output}
-<>'''
-
-print(result)
-"""
-
-PYTHON_PRELOAD = """
-# prepare general imports
-import json
-import datetime
-import math
-import random
-import re
-import string
-import sys
-import time
-import traceback
-import uuid
-import os
-import base64
-import hashlib
-import hmac
-import binascii
-import collections
-import functools
-import operator
-import itertools
-"""
-
-class PythonTemplateTransformer(TemplateTransformer):
- @classmethod
- def transform_caller(cls, code: str, inputs: dict) -> tuple[str, str]:
- """
- Transform code to python runner
- :param code: code
- :param inputs: inputs
- :return:
- """
-
- # transform inputs to json string
- inputs_str = b64encode(json.dumps(inputs, ensure_ascii=False).encode()).decode('utf-8')
-
- # replace code and inputs
- runner = PYTHON_RUNNER.replace('{{code}}', code)
- runner = runner.replace('{{inputs}}', inputs_str)
-
- return runner, PYTHON_PRELOAD
-
- @classmethod
- def transform_response(cls, response: str) -> dict:
- """
- Transform response to dict
- :param response: response
- :return:
- """
- # extract result
- result = re.search(r'<>(.*?)<>', response, re.DOTALL)
- if not result:
- raise ValueError('Failed to parse result')
- result = result.group(1)
- return json.loads(result)
diff --git a/api/core/helper/code_executor/template_transformer.py b/api/core/helper/code_executor/template_transformer.py
index c3564afd04..39af803f6e 100644
--- a/api/core/helper/code_executor/template_transformer.py
+++ b/api/core/helper/code_executor/template_transformer.py
@@ -1,24 +1,87 @@
+import json
+import re
from abc import ABC, abstractmethod
+from base64 import b64encode
+from typing import Optional
+
+from pydantic import BaseModel
+
+from core.helper.code_executor.entities import CodeDependency
-class TemplateTransformer(ABC):
+class TemplateTransformer(ABC, BaseModel):
+ _code_placeholder: str = '{{code}}'
+ _inputs_placeholder: str = '{{inputs}}'
+ _result_tag: str = '<>'
+
@classmethod
- @abstractmethod
- def transform_caller(cls, code: str, inputs: dict) -> tuple[str, str]:
+ def get_standard_packages(cls) -> set[str]:
+ return set()
+
+ @classmethod
+ def transform_caller(cls, code: str, inputs: dict,
+ dependencies: Optional[list[CodeDependency]] = None) -> tuple[str, str, list[CodeDependency]]:
"""
-        Transform code to python runner
+        Transform code into a runnable script
         :param code: code
         :param inputs: inputs
-        :return: runner, preload
+        :param dependencies: code dependencies
+        :return: runner, preload, dependencies
"""
- pass
-
+ runner_script = cls.assemble_runner_script(code, inputs)
+ preload_script = cls.get_preload_script()
+
+ packages = dependencies or []
+ standard_packages = cls.get_standard_packages()
+ for package in standard_packages:
+ if package not in packages:
+ packages.append(CodeDependency(name=package, version=''))
+ packages = list({dep.name: dep for dep in packages if dep.name}.values())
+
+ return runner_script, preload_script, packages
+
+ @classmethod
+ def extract_result_str_from_response(cls, response: str) -> str:
+ result = re.search(rf'{cls._result_tag}(.*){cls._result_tag}', response, re.DOTALL)
+ if not result:
+ raise ValueError('Failed to parse result')
+ result = result.group(1)
+ return result
+
@classmethod
- @abstractmethod
def transform_response(cls, response: str) -> dict:
"""
Transform response to dict
:param response: response
:return:
"""
- pass
\ No newline at end of file
+ return json.loads(cls.extract_result_str_from_response(response))
+
+ @classmethod
+ @abstractmethod
+ def get_runner_script(cls) -> str:
+ """
+ Get runner script
+ """
+ pass
+
+ @classmethod
+ def serialize_inputs(cls, inputs: dict) -> str:
+ inputs_json_str = json.dumps(inputs, ensure_ascii=False).encode()
+ input_base64_encoded = b64encode(inputs_json_str).decode('utf-8')
+ return input_base64_encoded
+
+ @classmethod
+ def assemble_runner_script(cls, code: str, inputs: dict) -> str:
+ # assemble runner script
+ script = cls.get_runner_script()
+ script = script.replace(cls._code_placeholder, code)
+ inputs_str = cls.serialize_inputs(inputs)
+ script = script.replace(cls._inputs_placeholder, inputs_str)
+ return script
+
+ @classmethod
+ def get_preload_script(cls) -> str:
+ """
+ Get preload script
+ """
+ return ''
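
Note: an end-to-end sketch of the new TemplateTransformer flow: substitute the {{code}} placeholder into a runner template, execute it, then extract the payload from between the '<>' result tags. The sandbox is simulated here with a local subprocess; placeholder names and the result tag come from the class attributes above.

    import json
    import re
    import subprocess
    import sys

    runner_template = "{{code}}\nimport json\nprint(f'<>{json.dumps(main())}<>')"
    code = "def main():\n    return {'ok': True}"
    script = runner_template.replace('{{code}}', code)

    # stand-in for the sandbox: run the assembled script, capture stdout
    response = subprocess.run([sys.executable, '-c', script],
                              capture_output=True, text=True).stdout

    match = re.search(r'<>(.*)<>', response, re.DOTALL)
    print(json.loads(match.group(1)))  # -> {'ok': True}
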
diff --git a/api/core/helper/model_provider_cache.py b/api/core/helper/model_provider_cache.py
index 81e589f65b..29cb4acc7d 100644
--- a/api/core/helper/model_provider_cache.py
+++ b/api/core/helper/model_provider_cache.py
@@ -9,6 +9,7 @@ from extensions.ext_redis import redis_client
class ProviderCredentialsCacheType(Enum):
PROVIDER = "provider"
MODEL = "provider_model"
+ LOAD_BALANCING_MODEL = "load_balancing_provider_model"
class ProviderCredentialsCache:
diff --git a/api/core/utils/module_import_helper.py b/api/core/helper/module_import_helper.py
similarity index 100%
rename from api/core/utils/module_import_helper.py
rename to api/core/helper/module_import_helper.py
diff --git a/api/core/utils/position_helper.py b/api/core/helper/position_helper.py
similarity index 75%
rename from api/core/utils/position_helper.py
rename to api/core/helper/position_helper.py
index e038390e09..689ab194a7 100644
--- a/api/core/utils/position_helper.py
+++ b/api/core/helper/position_helper.py
@@ -1,10 +1,9 @@
-import logging
import os
from collections import OrderedDict
from collections.abc import Callable
from typing import Any, AnyStr
-import yaml
+from core.tools.utils.yaml_utils import load_yaml_file
def get_position_map(
@@ -17,21 +16,15 @@ def get_position_map(
:param file_name: the YAML file name, default to '_position.yaml'
:return: a dict with name as key and index as value
"""
- try:
- position_file_name = os.path.join(folder_path, file_name)
- if not os.path.exists(position_file_name):
- return {}
-
- with open(position_file_name, encoding='utf-8') as f:
- positions = yaml.safe_load(f)
- position_map = {}
- for index, name in enumerate(positions):
- if name and isinstance(name, str):
- position_map[name.strip()] = index
- return position_map
- except:
- logging.warning(f'Failed to load the YAML position file {folder_path}/{file_name}.')
- return {}
+ position_file_name = os.path.join(folder_path, file_name)
+ positions = load_yaml_file(position_file_name, ignore_error=True)
+ position_map = {}
+ index = 0
+    for name in positions:
+ if name and isinstance(name, str):
+ position_map[name.strip()] = index
+ index += 1
+ return position_map
def sort_by_position_map(
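
Note: behavior sketch for the rewritten get_position_map: indices now count only the valid string entries, so blanks and non-strings no longer consume positions.

    positions = ['anthropic', '', None, 'google ']
    position_map = {}
    index = 0
    for name in positions:
        if name and isinstance(name, str):
            position_map[name.strip()] = index
            index += 1
    assert position_map == {'anthropic': 0, 'google': 1}
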
diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py
index c44d4717e6..276c8a34e7 100644
--- a/api/core/helper/ssrf_proxy.py
+++ b/api/core/helper/ssrf_proxy.py
@@ -42,6 +42,20 @@ def delete(url, *args, **kwargs):
if kwargs['follow_redirects']:
kwargs['allow_redirects'] = kwargs['follow_redirects']
kwargs.pop('follow_redirects')
+ if 'timeout' in kwargs:
+ timeout = kwargs['timeout']
+ if timeout is None:
+ kwargs.pop('timeout')
+ elif isinstance(timeout, tuple):
+ # check length of tuple
+ if len(timeout) == 2:
+ kwargs['timeout'] = timeout
+ elif len(timeout) == 1:
+ kwargs['timeout'] = timeout[0]
+ elif len(timeout) > 2:
+ kwargs['timeout'] = (timeout[0], timeout[1])
+ else:
+ kwargs['timeout'] = (timeout, timeout)
return _delete(url=url, *args, proxies=requests_proxies, **kwargs)
def head(url, *args, **kwargs):
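
Note: the timeout handling above, factored into a hypothetical standalone helper (not part of the diff) so the branches are easier to audit: None drops the kwarg, a 1-tuple collapses to a scalar, anything longer than two elements is truncated to (connect, read), and a bare scalar fans out to both.

    def normalize_timeout(kwargs: dict) -> dict:
        if 'timeout' in kwargs:
            timeout = kwargs['timeout']
            if timeout is None:
                kwargs.pop('timeout')  # fall back to the client default
            elif isinstance(timeout, tuple):
                if len(timeout) == 1:
                    kwargs['timeout'] = timeout[0]
                elif len(timeout) > 2:
                    kwargs['timeout'] = (timeout[0], timeout[1])
                # a 2-tuple is already (connect, read) and is kept as-is
            else:
                kwargs['timeout'] = (timeout, timeout)
        return kwargs

    assert normalize_timeout({'timeout': 5}) == {'timeout': (5, 5)}
    assert normalize_timeout({'timeout': (1, 2, 3)}) == {'timeout': (1, 2)}
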
diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py
index 51e77393d6..7fa1d7d4b9 100644
--- a/api/core/indexing_runner.py
+++ b/api/core/indexing_runner.py
@@ -12,7 +12,6 @@ from flask import Flask, current_app
from flask_login import current_user
from sqlalchemy.orm.exc import ObjectDeletedError
-from core.docstore.dataset_docstore import DatasetDocumentStore
from core.errors.error import ProviderTokenNotInitError
from core.llm_generator.llm_generator import LLMGenerator
from core.model_manager import ModelInstance, ModelManager
@@ -20,12 +19,16 @@ from core.model_runtime.entities.model_entities import ModelType, PriceType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.rag.datasource.keyword.keyword_factory import Keyword
+from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import Document
-from core.splitter.fixed_text_splitter import EnhanceRecursiveCharacterTextSplitter, FixedRecursiveCharacterTextSplitter
-from core.splitter.text_splitter import TextSplitter
+from core.rag.splitter.fixed_text_splitter import (
+ EnhanceRecursiveCharacterTextSplitter,
+ FixedRecursiveCharacterTextSplitter,
+)
+from core.rag.splitter.text_splitter import TextSplitter
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
@@ -283,11 +286,7 @@ class IndexingRunner:
if len(preview_texts) < 5:
preview_texts.append(document.page_content)
if indexing_technique == 'high_quality' or embedding_model_instance:
- embedding_model_type_instance = embedding_model_instance.model_type_instance
- embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)
- tokens += embedding_model_type_instance.get_num_tokens(
- model=embedding_model_instance.model,
- credentials=embedding_model_instance.credentials,
+ tokens += embedding_model_instance.get_text_embedding_num_tokens(
texts=[self.filter_string(document.page_content)]
)
@@ -411,14 +410,15 @@ class IndexingRunner:
# The user-defined segmentation rule
rules = json.loads(processing_rule.rules)
segmentation = rules["segmentation"]
- if segmentation["max_tokens"] < 50 or segmentation["max_tokens"] > 1000:
- raise ValueError("Custom segment length should be between 50 and 1000.")
+ max_segmentation_tokens_length = int(current_app.config['INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH'])
+ if segmentation["max_tokens"] < 50 or segmentation["max_tokens"] > max_segmentation_tokens_length:
+ raise ValueError(f"Custom segment length should be between 50 and {max_segmentation_tokens_length}.")
separator = segmentation["separator"]
if separator:
separator = separator.replace('\\n', '\n')
- if 'chunk_overlap' in segmentation and segmentation['chunk_overlap']:
+ if segmentation.get('chunk_overlap'):
chunk_overlap = segmentation['chunk_overlap']
else:
chunk_overlap = 0
@@ -427,7 +427,7 @@ class IndexingRunner:
chunk_size=segmentation["max_tokens"],
chunk_overlap=chunk_overlap,
fixed_separator=separator,
- separators=["\n\n", "。", ".", " ", ""],
+ separators=["\n\n", "。", ". ", " ", ""],
embedding_model_instance=embedding_model_instance
)
else:
@@ -435,7 +435,7 @@ class IndexingRunner:
character_splitter = EnhanceRecursiveCharacterTextSplitter.from_encoder(
chunk_size=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens'],
chunk_overlap=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['chunk_overlap'],
- separators=["\n\n", "。", ".", " ", ""],
+ separators=["\n\n", "。", ". ", " ", ""],
embedding_model_instance=embedding_model_instance
)
@@ -654,10 +654,6 @@ class IndexingRunner:
tokens = 0
chunk_size = 10
- embedding_model_type_instance = None
- if embedding_model_instance:
- embedding_model_type_instance = embedding_model_instance.model_type_instance
- embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)
# create keyword index
create_keyword_thread = threading.Thread(target=self._process_keyword_index,
args=(current_app._get_current_object(),
@@ -670,8 +666,7 @@ class IndexingRunner:
chunk_documents = documents[i:i + chunk_size]
futures.append(executor.submit(self._process_chunk, current_app._get_current_object(), index_processor,
chunk_documents, dataset,
- dataset_document, embedding_model_instance,
- embedding_model_type_instance))
+ dataset_document, embedding_model_instance))
for future in futures:
tokens += future.result()
@@ -712,7 +707,7 @@ class IndexingRunner:
db.session.commit()
def _process_chunk(self, flask_app, index_processor, chunk_documents, dataset, dataset_document,
- embedding_model_instance, embedding_model_type_instance):
+ embedding_model_instance):
with flask_app.app_context():
# check document is paused
self._check_document_paused_status(dataset_document.id)
@@ -720,9 +715,7 @@ class IndexingRunner:
tokens = 0
-            if dataset.indexing_technique == 'high_quality' or embedding_model_type_instance:
+            if dataset.indexing_technique == 'high_quality' or embedding_model_instance:
tokens += sum(
- embedding_model_type_instance.get_num_tokens(
- embedding_model_instance.model,
- embedding_model_instance.credentials,
+ embedding_model_instance.get_text_embedding_num_tokens(
[document.page_content]
)
for document in chunk_documents
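
Note: a stubbed sketch of the simplified token counting: the model/credentials plumbing now lives inside ModelInstance (see get_text_embedding_num_tokens in model_manager.py below), so call sites shrink to a single method call. The tokenizer here is a stand-in, not the real one.

    class FakeEmbeddingModelInstance:
        def get_text_embedding_num_tokens(self, texts: list[str]) -> int:
            return sum(len(text.split()) for text in texts)  # stand-in tokenizer

    embedding_model_instance = FakeEmbeddingModelInstance()
    tokens = sum(
        embedding_model_instance.get_text_embedding_num_tokens(texts=[text])
        for text in ['hello world', 'dify indexing runner']
    )
    print(tokens)  # -> 5
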
diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py
index cd0b2508d4..6b53104c70 100644
--- a/api/core/memory/token_buffer_memory.py
+++ b/api/core/memory/token_buffer_memory.py
@@ -9,8 +9,6 @@ from core.model_runtime.entities.message_entities import (
TextPromptMessageContent,
UserPromptMessage,
)
-from core.model_runtime.entities.model_entities import ModelType
-from core.model_runtime.model_providers import model_provider_factory
from extensions.ext_database import db
from models.model import AppMode, Conversation, Message
@@ -78,12 +76,7 @@ class TokenBufferMemory:
return []
# prune the chat message if it exceeds the max token limit
- provider_instance = model_provider_factory.get_provider_instance(self.model_instance.provider)
- model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
-
- curr_message_tokens = model_type_instance.get_num_tokens(
- self.model_instance.model,
- self.model_instance.credentials,
+ curr_message_tokens = self.model_instance.get_llm_num_tokens(
prompt_messages
)
@@ -91,9 +84,7 @@ class TokenBufferMemory:
pruned_memory = []
while curr_message_tokens > max_token_limit and prompt_messages:
pruned_memory.append(prompt_messages.pop(0))
- curr_message_tokens = model_type_instance.get_num_tokens(
- self.model_instance.model,
- self.model_instance.credentials,
+ curr_message_tokens = self.model_instance.get_llm_num_tokens(
prompt_messages
)
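
Note: the pruning loop above, as a self-contained sketch with a stand-in token counter: the oldest messages are popped until the running count fits the limit.

    def get_llm_num_tokens(prompt_messages: list[str]) -> int:
        return sum(len(message.split()) for message in prompt_messages)  # stand-in

    prompt_messages = ['a b c', 'd e', 'f']
    max_token_limit = 3
    pruned_memory = []
    curr_message_tokens = get_llm_num_tokens(prompt_messages)
    while curr_message_tokens > max_token_limit and prompt_messages:
        pruned_memory.append(prompt_messages.pop(0))
        curr_message_tokens = get_llm_num_tokens(prompt_messages)
    print(prompt_messages)  # -> ['d e', 'f']
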
diff --git a/api/core/model_manager.py b/api/core/model_manager.py
index 8c06339927..8da8442e60 100644
--- a/api/core/model_manager.py
+++ b/api/core/model_manager.py
@@ -1,7 +1,10 @@
+import logging
+import os
from collections.abc import Generator
from typing import IO, Optional, Union, cast
-from core.entities.provider_configuration import ProviderModelBundle
+from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
+from core.entities.provider_entities import ModelLoadBalancingConfiguration
from core.errors.error import ProviderTokenNotInitError
from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.llm_entities import LLMResult
@@ -9,6 +12,7 @@ from core.model_runtime.entities.message_entities import PromptMessage, PromptMe
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.rerank_entities import RerankResult
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
+from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.__base.moderation_model import ModerationModel
from core.model_runtime.model_providers.__base.rerank_model import RerankModel
@@ -16,6 +20,10 @@ from core.model_runtime.model_providers.__base.speech2text_model import Speech2T
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.model_runtime.model_providers.__base.tts_model import TTSModel
from core.provider_manager import ProviderManager
+from extensions.ext_redis import redis_client
+from models.provider import ProviderType
+
+logger = logging.getLogger(__name__)
class ModelInstance:
@@ -29,6 +37,12 @@ class ModelInstance:
self.provider = provider_model_bundle.configuration.provider.provider
self.credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model)
self.model_type_instance = self.provider_model_bundle.model_type_instance
+ self.load_balancing_manager = self._get_load_balancing_manager(
+ configuration=provider_model_bundle.configuration,
+ model_type=provider_model_bundle.model_type_instance.model_type,
+ model=model,
+ credentials=self.credentials
+ )
def _fetch_credentials_from_bundle(self, provider_model_bundle: ProviderModelBundle, model: str) -> dict:
"""
@@ -37,8 +51,10 @@ class ModelInstance:
:param model: model name
:return:
"""
- credentials = provider_model_bundle.configuration.get_current_credentials(
- model_type=provider_model_bundle.model_type_instance.model_type,
+ configuration = provider_model_bundle.configuration
+ model_type = provider_model_bundle.model_type_instance.model_type
+ credentials = configuration.get_current_credentials(
+ model_type=model_type,
model=model
)
@@ -47,6 +63,43 @@ class ModelInstance:
return credentials
+ def _get_load_balancing_manager(self, configuration: ProviderConfiguration,
+ model_type: ModelType,
+ model: str,
+ credentials: dict) -> Optional["LBModelManager"]:
+ """
+        Get the load balancing model manager
+ :param configuration: provider configuration
+ :param model_type: model type
+ :param model: model name
+ :param credentials: model credentials
+ :return:
+ """
+ if configuration.model_settings and configuration.using_provider_type == ProviderType.CUSTOM:
+ current_model_setting = None
+ # check if model is disabled by admin
+ for model_setting in configuration.model_settings:
+ if (model_setting.model_type == model_type
+ and model_setting.model == model):
+ current_model_setting = model_setting
+ break
+
+ # check if load balancing is enabled
+ if current_model_setting and current_model_setting.load_balancing_configs:
+ # use load balancing proxy to choose credentials
+ lb_model_manager = LBModelManager(
+ tenant_id=configuration.tenant_id,
+ provider=configuration.provider.provider,
+ model_type=model_type,
+ model=model,
+ load_balancing_configs=current_model_setting.load_balancing_configs,
+ managed_credentials=credentials if configuration.custom_configuration.provider else None
+ )
+
+ return lb_model_manager
+
+ return None
+
def invoke_llm(self, prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None, callbacks: list[Callback] = None) \
@@ -67,7 +120,8 @@ class ModelInstance:
raise Exception("Model type instance is not LargeLanguageModel")
self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
- return self.model_type_instance.invoke(
+ return self._round_robin_invoke(
+ function=self.model_type_instance.invoke,
model=self.model,
credentials=self.credentials,
prompt_messages=prompt_messages,
@@ -79,6 +133,27 @@ class ModelInstance:
callbacks=callbacks
)
+ def get_llm_num_tokens(self, prompt_messages: list[PromptMessage],
+ tools: Optional[list[PromptMessageTool]] = None) -> int:
+ """
+        Get the number of tokens for the LLM
+
+ :param prompt_messages: prompt messages
+ :param tools: tools for tool calling
+ :return:
+ """
+ if not isinstance(self.model_type_instance, LargeLanguageModel):
+ raise Exception("Model type instance is not LargeLanguageModel")
+
+ self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
+ return self._round_robin_invoke(
+ function=self.model_type_instance.get_num_tokens,
+ model=self.model,
+ credentials=self.credentials,
+ prompt_messages=prompt_messages,
+ tools=tools
+ )
+
def invoke_text_embedding(self, texts: list[str], user: Optional[str] = None) \
-> TextEmbeddingResult:
"""
@@ -92,13 +167,32 @@ class ModelInstance:
raise Exception("Model type instance is not TextEmbeddingModel")
self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance)
- return self.model_type_instance.invoke(
+ return self._round_robin_invoke(
+ function=self.model_type_instance.invoke,
model=self.model,
credentials=self.credentials,
texts=texts,
user=user
)
+ def get_text_embedding_num_tokens(self, texts: list[str]) -> int:
+ """
+        Get the number of tokens for text embedding
+
+ :param texts: texts to embed
+ :return:
+ """
+ if not isinstance(self.model_type_instance, TextEmbeddingModel):
+ raise Exception("Model type instance is not TextEmbeddingModel")
+
+ self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance)
+ return self._round_robin_invoke(
+ function=self.model_type_instance.get_num_tokens,
+ model=self.model,
+ credentials=self.credentials,
+ texts=texts
+ )
+
def invoke_rerank(self, query: str, docs: list[str], score_threshold: Optional[float] = None,
top_n: Optional[int] = None,
user: Optional[str] = None) \
@@ -117,7 +211,8 @@ class ModelInstance:
raise Exception("Model type instance is not RerankModel")
self.model_type_instance = cast(RerankModel, self.model_type_instance)
- return self.model_type_instance.invoke(
+ return self._round_robin_invoke(
+ function=self.model_type_instance.invoke,
model=self.model,
credentials=self.credentials,
query=query,
@@ -140,7 +235,8 @@ class ModelInstance:
raise Exception("Model type instance is not ModerationModel")
self.model_type_instance = cast(ModerationModel, self.model_type_instance)
- return self.model_type_instance.invoke(
+ return self._round_robin_invoke(
+ function=self.model_type_instance.invoke,
model=self.model,
credentials=self.credentials,
text=text,
@@ -160,7 +256,8 @@ class ModelInstance:
raise Exception("Model type instance is not Speech2TextModel")
self.model_type_instance = cast(Speech2TextModel, self.model_type_instance)
- return self.model_type_instance.invoke(
+ return self._round_robin_invoke(
+ function=self.model_type_instance.invoke,
model=self.model,
credentials=self.credentials,
file=file,
@@ -183,7 +280,8 @@ class ModelInstance:
raise Exception("Model type instance is not TTSModel")
self.model_type_instance = cast(TTSModel, self.model_type_instance)
- return self.model_type_instance.invoke(
+ return self._round_robin_invoke(
+ function=self.model_type_instance.invoke,
model=self.model,
credentials=self.credentials,
content_text=content_text,
@@ -193,6 +291,43 @@ class ModelInstance:
streaming=streaming
)
+ def _round_robin_invoke(self, function: callable, *args, **kwargs):
+ """
+ Round-robin invoke
+ :param function: function to invoke
+ :param args: function args
+ :param kwargs: function kwargs
+ :return:
+ """
+ if not self.load_balancing_manager:
+ return function(*args, **kwargs)
+
+ last_exception = None
+ while True:
+ lb_config = self.load_balancing_manager.fetch_next()
+ if not lb_config:
+ if not last_exception:
+                    raise ProviderTokenNotInitError("Model credentials are not initialized.")
+ else:
+ raise last_exception
+
+ try:
+ if 'credentials' in kwargs:
+ del kwargs['credentials']
+ return function(*args, **kwargs, credentials=lb_config.credentials)
+ except InvokeRateLimitError as e:
+ # expire in 60 seconds
+ self.load_balancing_manager.cooldown(lb_config, expire=60)
+ last_exception = e
+ continue
+ except (InvokeAuthorizationError, InvokeConnectionError) as e:
+ # expire in 10 seconds
+ self.load_balancing_manager.cooldown(lb_config, expire=10)
+ last_exception = e
+ continue
+ except Exception as e:
+ raise e
+
def get_tts_voices(self, language: str) -> list:
"""
Invoke large language tts model voices
@@ -226,6 +361,7 @@ class ModelManager:
"""
if not provider:
return self.get_default_model_instance(tenant_id, model_type)
+
provider_model_bundle = self._provider_manager.get_provider_model_bundle(
tenant_id=tenant_id,
provider=provider,
@@ -255,3 +391,141 @@ class ModelManager:
model_type=model_type,
model=default_model_entity.model
)
+
+
+class LBModelManager:
+ def __init__(self, tenant_id: str,
+ provider: str,
+ model_type: ModelType,
+ model: str,
+ load_balancing_configs: list[ModelLoadBalancingConfiguration],
+ managed_credentials: Optional[dict] = None) -> None:
+ """
+ Load balancing model manager
+ :param load_balancing_configs: all load balancing configurations
+ :param managed_credentials: credentials if load balancing configuration name is __inherit__
+ """
+ self._tenant_id = tenant_id
+ self._provider = provider
+ self._model_type = model_type
+ self._model = model
+ self._load_balancing_configs = load_balancing_configs
+
+ for load_balancing_config in self._load_balancing_configs:
+ if load_balancing_config.name == "__inherit__":
+ if not managed_credentials:
+ # remove __inherit__ if managed credentials is not provided
+ self._load_balancing_configs.remove(load_balancing_config)
+ else:
+ load_balancing_config.credentials = managed_credentials
+
+ def fetch_next(self) -> Optional[ModelLoadBalancingConfiguration]:
+ """
+ Get next model load balancing config
+ Strategy: Round Robin
+ :return:
+ """
+ cache_key = "model_lb_index:{}:{}:{}:{}".format(
+ self._tenant_id,
+ self._provider,
+ self._model_type.value,
+ self._model
+ )
+
+ cooldown_load_balancing_configs = []
+ max_index = len(self._load_balancing_configs)
+
+ while True:
+ current_index = redis_client.incr(cache_key)
+ if current_index >= 10000000:
+ current_index = 1
+ redis_client.set(cache_key, current_index)
+
+ redis_client.expire(cache_key, 3600)
+ if current_index > max_index:
+ current_index = current_index % max_index
+
+ real_index = current_index - 1
+ if real_index > max_index:
+ real_index = 0
+
+ config = self._load_balancing_configs[real_index]
+
+ if self.in_cooldown(config):
+ cooldown_load_balancing_configs.append(config)
+ if len(cooldown_load_balancing_configs) >= len(self._load_balancing_configs):
+ # all configs are in cooldown
+ return None
+
+ continue
+
+            if os.environ.get('DEBUG', 'false').lower() == 'true':
+ logger.info(f"Model LB\nid: {config.id}\nname:{config.name}\n"
+ f"tenant_id: {self._tenant_id}\nprovider: {self._provider}\n"
+ f"model_type: {self._model_type.value}\nmodel: {self._model}")
+
+ return config
+
+ return None
+
+ def cooldown(self, config: ModelLoadBalancingConfiguration, expire: int = 60) -> None:
+ """
+ Cooldown model load balancing config
+ :param config: model load balancing config
+ :param expire: cooldown time
+ :return:
+ """
+ cooldown_cache_key = "model_lb_index:cooldown:{}:{}:{}:{}:{}".format(
+ self._tenant_id,
+ self._provider,
+ self._model_type.value,
+ self._model,
+ config.id
+ )
+
+ redis_client.setex(cooldown_cache_key, expire, 'true')
+
+ def in_cooldown(self, config: ModelLoadBalancingConfiguration) -> bool:
+ """
+ Check if model load balancing config is in cooldown
+ :param config: model load balancing config
+ :return:
+ """
+ cooldown_cache_key = "model_lb_index:cooldown:{}:{}:{}:{}:{}".format(
+ self._tenant_id,
+ self._provider,
+ self._model_type.value,
+ self._model,
+ config.id
+ )
+
+ return redis_client.exists(cooldown_cache_key)
+
+ @classmethod
+ def get_config_in_cooldown_and_ttl(cls, tenant_id: str,
+ provider: str,
+ model_type: ModelType,
+ model: str,
+ config_id: str) -> tuple[bool, int]:
+ """
+        Check whether a model load balancing config is in cooldown, and get its remaining TTL
+ :param tenant_id: workspace id
+ :param provider: provider name
+ :param model_type: model type
+ :param model: model name
+ :param config_id: model load balancing config id
+ :return:
+ """
+ cooldown_cache_key = "model_lb_index:cooldown:{}:{}:{}:{}:{}".format(
+ tenant_id,
+ provider,
+ model_type.value,
+ model,
+ config_id
+ )
+
+ ttl = redis_client.ttl(cooldown_cache_key)
+ if ttl == -2:
+ return False, 0
+
+ return True, ttl
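
Note: a self-contained sketch of the round-robin selection in LBModelManager.fetch_next, with Redis swapped for an in-memory counter and the cooldown keys for a set; this mirrors the behavior only, not the production code path.

    class FakeLBManager:
        def __init__(self, configs: list[str]) -> None:
            self.configs = configs
            self.counter = 0               # redis INCR in the real implementation
            self.cooldown: set[str] = set()

        def fetch_next(self) -> str | None:
            skipped: list[str] = []
            while True:
                self.counter += 1
                config = self.configs[(self.counter - 1) % len(self.configs)]
                if config in self.cooldown:
                    skipped.append(config)
                    if len(skipped) >= len(self.configs):
                        return None        # every credential is cooling down
                    continue
                return config

    lb = FakeLBManager(['key-a', 'key-b', 'key-c'])
    lb.cooldown.add('key-b')
    print([lb.fetch_next() for _ in range(4)])  # -> ['key-a', 'key-c', 'key-a', 'key-c']
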
diff --git a/api/core/model_runtime/README.md b/api/core/model_runtime/README.md
index d7748a8c3c..b5de7ad412 100644
--- a/api/core/model_runtime/README.md
+++ b/api/core/model_runtime/README.md
@@ -20,7 +20,7 @@ This module provides the interface for invoking and authenticating various model

- Displays a list of all supported providers, including provider names, icons, supported model types list, predefined model list, configuration method, and credentials form rules, etc. For detailed rule design, see: [Schema](./schema.md).
+ Displays a list of all supported providers, including provider names, icons, supported model types list, predefined model list, configuration method, and credentials form rules, etc. For detailed rule design, see: [Schema](./docs/en_US/schema.md).
- Selectable model list display
diff --git a/api/core/model_runtime/callbacks/base_callback.py b/api/core/model_runtime/callbacks/base_callback.py
index 51af9786fd..bba004a32a 100644
--- a/api/core/model_runtime/callbacks/base_callback.py
+++ b/api/core/model_runtime/callbacks/base_callback.py
@@ -1,4 +1,3 @@
-from abc import ABC
from typing import Optional
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
@@ -14,7 +13,7 @@ _TEXT_COLOR_MAPPING = {
}
-class Callback(ABC):
+class Callback:
"""
Base class for callbacks.
Only for LLM.
diff --git a/api/core/model_runtime/docs/en_US/schema.md b/api/core/model_runtime/docs/en_US/schema.md
index 68f66330cd..2e55d05b0f 100644
--- a/api/core/model_runtime/docs/en_US/schema.md
+++ b/api/core/model_runtime/docs/en_US/schema.md
@@ -51,7 +51,7 @@
- `voices` (list) List of available voice.(available for model type `tts`)
- `mode` (string) voice model.(available for model type `tts`)
- `name` (string) voice model display name.(available for model type `tts`)
- - `lanuage` (string) the voice model supports languages.(available for model type `tts`)
+ - `language` (string) the voice model supports languages.(available for model type `tts`)
- `word_limit` (int) Single conversion word limit, paragraphwise by default(available for model type `tts`)
- `audio_type` (string) Support audio file extension format, e.g.:mp3,wav(available for model type `tts`)
- `max_workers` (int) Number of concurrent workers supporting text and audio conversion(available for model type`tts`)
diff --git a/api/core/model_runtime/docs/zh_Hans/schema.md b/api/core/model_runtime/docs/zh_Hans/schema.md
index fd672993bb..f40a3f8698 100644
--- a/api/core/model_runtime/docs/zh_Hans/schema.md
+++ b/api/core/model_runtime/docs/zh_Hans/schema.md
@@ -52,7 +52,7 @@
- `voices` (list) 可选音色列表。
- `mode` (string) 音色模型。(模型类型 `tts` 可用)
- `name` (string) 音色模型显示名称。(模型类型 `tts` 可用)
- - `lanuage` (string) 音色模型支持语言。(模型类型 `tts` 可用)
+ - `language` (string) 音色模型支持语言。(模型类型 `tts` 可用)
- `word_limit` (int) 单次转换字数限制,默认按段落分段(模型类型 `tts` 可用)
- `audio_type` (string) 支持音频文件扩展格式,如:mp3,wav(模型类型 `tts` 可用)
- `max_workers` (int) 支持文字音频转换并发任务数(模型类型 `tts` 可用)
diff --git a/api/core/model_runtime/model_providers/__base/ai_model.py b/api/core/model_runtime/model_providers/__base/ai_model.py
index 34a7375493..919e72554c 100644
--- a/api/core/model_runtime/model_providers/__base/ai_model.py
+++ b/api/core/model_runtime/model_providers/__base/ai_model.py
@@ -3,8 +3,7 @@ import os
from abc import ABC, abstractmethod
from typing import Optional
-import yaml
-
+from core.helper.position_helper import get_position_map, sort_by_position_map
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE
from core.model_runtime.entities.model_entities import (
@@ -18,7 +17,7 @@ from core.model_runtime.entities.model_entities import (
)
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
-from core.utils.position_helper import get_position_map, sort_by_position_map
+from core.tools.utils.yaml_utils import load_yaml_file
class AIModel(ABC):
@@ -154,8 +153,7 @@ class AIModel(ABC):
# traverse all model_schema_yaml_paths
for model_schema_yaml_path in model_schema_yaml_paths:
# read yaml data from yaml file
- with open(model_schema_yaml_path, encoding='utf-8') as f:
- yaml_data = yaml.safe_load(f)
+ yaml_data = load_yaml_file(model_schema_yaml_path, ignore_error=True)
new_parameter_rules = []
for parameter_rule in yaml_data.get('parameter_rules', []):
diff --git a/api/core/model_runtime/model_providers/__base/model_provider.py b/api/core/model_runtime/model_providers/__base/model_provider.py
index 7c839a9672..a893d023c0 100644
--- a/api/core/model_runtime/model_providers/__base/model_provider.py
+++ b/api/core/model_runtime/model_providers/__base/model_provider.py
@@ -1,12 +1,11 @@
import os
from abc import ABC, abstractmethod
-import yaml
-
+from core.helper.module_import_helper import get_subclasses_from_module, import_module_from_source
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
from core.model_runtime.entities.provider_entities import ProviderEntity
from core.model_runtime.model_providers.__base.ai_model import AIModel
-from core.utils.module_import_helper import get_subclasses_from_module, import_module_from_source
+from core.tools.utils.yaml_utils import load_yaml_file
class ModelProvider(ABC):
@@ -44,10 +43,7 @@ class ModelProvider(ABC):
# read provider schema from yaml file
yaml_path = os.path.join(current_path, f'{provider_name}.yaml')
- yaml_data = {}
- if os.path.exists(yaml_path):
- with open(yaml_path, encoding='utf-8') as f:
- yaml_data = yaml.safe_load(f)
+ yaml_data = load_yaml_file(yaml_path, ignore_error=True)
try:
# yaml_data to entity
diff --git a/api/core/model_runtime/model_providers/_position.yaml b/api/core/model_runtime/model_providers/_position.yaml
index c06f122984..b483303cad 100644
--- a/api/core/model_runtime/model_providers/_position.yaml
+++ b/api/core/model_runtime/model_providers/_position.yaml
@@ -2,7 +2,9 @@
- anthropic
- azure_openai
- google
+- vertex_ai
- nvidia
+- nvidia_nim
- cohere
- bedrock
- togetherai
@@ -26,4 +28,6 @@
- yi
- openllm
- localai
+- volcengine_maas
- openai_api_compatible
+- deepseek
diff --git a/api/core/model_runtime/model_providers/anthropic/llm/claude-3-haiku-20240307.yaml b/api/core/model_runtime/model_providers/anthropic/llm/claude-3-haiku-20240307.yaml
index 0dedb2ef38..cb2af1308a 100644
--- a/api/core/model_runtime/model_providers/anthropic/llm/claude-3-haiku-20240307.yaml
+++ b/api/core/model_runtime/model_providers/anthropic/llm/claude-3-haiku-20240307.yaml
@@ -6,6 +6,7 @@ features:
- agent-thought
- vision
- tool-call
+ - stream-tool-call
model_properties:
mode: chat
context_size: 200000
diff --git a/api/core/model_runtime/model_providers/anthropic/llm/claude-3-opus-20240229.yaml b/api/core/model_runtime/model_providers/anthropic/llm/claude-3-opus-20240229.yaml
index 60e56452eb..101f54c3f8 100644
--- a/api/core/model_runtime/model_providers/anthropic/llm/claude-3-opus-20240229.yaml
+++ b/api/core/model_runtime/model_providers/anthropic/llm/claude-3-opus-20240229.yaml
@@ -6,6 +6,7 @@ features:
- agent-thought
- vision
- tool-call
+ - stream-tool-call
model_properties:
mode: chat
context_size: 200000
diff --git a/api/core/model_runtime/model_providers/anthropic/llm/claude-3-sonnet-20240229.yaml b/api/core/model_runtime/model_providers/anthropic/llm/claude-3-sonnet-20240229.yaml
index 08c8375d45..daf55553f8 100644
--- a/api/core/model_runtime/model_providers/anthropic/llm/claude-3-sonnet-20240229.yaml
+++ b/api/core/model_runtime/model_providers/anthropic/llm/claude-3-sonnet-20240229.yaml
@@ -6,6 +6,7 @@ features:
- agent-thought
- vision
- tool-call
+ - stream-tool-call
model_properties:
mode: chat
context_size: 200000
diff --git a/api/core/model_runtime/model_providers/anthropic/llm/llm.py b/api/core/model_runtime/model_providers/anthropic/llm/llm.py
index e0e3514fd1..fbc0b722b1 100644
--- a/api/core/model_runtime/model_providers/anthropic/llm/llm.py
+++ b/api/core/model_runtime/model_providers/anthropic/llm/llm.py
@@ -146,7 +146,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
"""
Code block mode wrapper for invoking large language model
"""
- if 'response_format' in model_parameters and model_parameters['response_format']:
+ if model_parameters.get('response_format'):
stop = stop or []
# chat model
self._transform_chat_json_prompts(
@@ -324,10 +324,32 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
output_tokens = 0
finish_reason = None
index = 0
+
+ tool_calls: list[AssistantPromptMessage.ToolCall] = []
+
for chunk in response:
if isinstance(chunk, MessageStartEvent):
- return_model = chunk.message.model
- input_tokens = chunk.message.usage.input_tokens
+ if hasattr(chunk, 'content_block'):
+ content_block = chunk.content_block
+ if isinstance(content_block, dict):
+ if content_block.get('type') == 'tool_use':
+ tool_call = AssistantPromptMessage.ToolCall(
+ id=content_block.get('id'),
+ type='function',
+ function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+ name=content_block.get('name'),
+ arguments=''
+ )
+ )
+ tool_calls.append(tool_call)
+ elif hasattr(chunk, 'delta'):
+ delta = chunk.delta
+ if isinstance(delta, dict) and len(tool_calls) > 0:
+ if delta.get('type') == 'input_json_delta':
+ tool_calls[-1].function.arguments += delta.get('partial_json', '')
+ elif chunk.message:
+ return_model = chunk.message.model
+ input_tokens = chunk.message.usage.input_tokens
elif isinstance(chunk, MessageDeltaEvent):
output_tokens = chunk.usage.output_tokens
finish_reason = chunk.delta.stop_reason
@@ -335,13 +357,19 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
# transform usage
usage = self._calc_response_usage(model, credentials, input_tokens, output_tokens)
+ # transform empty tool call arguments to {}
+ for tool_call in tool_calls:
+ if not tool_call.function.arguments:
+ tool_call.function.arguments = '{}'
+
yield LLMResultChunk(
model=return_model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=index + 1,
message=AssistantPromptMessage(
- content=''
+ content='',
+ tool_calls=tool_calls
),
finish_reason=finish_reason,
usage=usage
@@ -380,7 +408,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
"max_retries": 1,
}
- if 'anthropic_api_url' in credentials and credentials['anthropic_api_url']:
+ if credentials.get('anthropic_api_url'):
credentials['anthropic_api_url'] = credentials['anthropic_api_url'].rstrip('/')
credentials_kwargs['base_url'] = credentials['anthropic_api_url']
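
Note: a sketch of the streamed tool-call assembly added above: 'input_json_delta' fragments are concatenated into the pending tool call's arguments, and empty arguments are normalized to '{}' before the chunk is yielded.

    deltas = [
        {'type': 'input_json_delta', 'partial_json': '{"city": '},
        {'type': 'input_json_delta', 'partial_json': '"Berlin"}'},
    ]
    arguments = ''
    for delta in deltas:
        if delta.get('type') == 'input_json_delta':
            arguments += delta.get('partial_json', '')
    arguments = arguments or '{}'  # empty arguments become '{}'
    print(arguments)  # -> {"city": "Berlin"}
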
diff --git a/api/core/model_runtime/model_providers/azure_openai/_constant.py b/api/core/model_runtime/model_providers/azure_openai/_constant.py
index e81a120fa0..63a0b5c8be 100644
--- a/api/core/model_runtime/model_providers/azure_openai/_constant.py
+++ b/api/core/model_runtime/model_providers/azure_openai/_constant.py
@@ -49,7 +49,7 @@ LLM_BASE_MODELS = [
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={
ModelPropertyKey.MODE: LLMMode.CHAT.value,
- ModelPropertyKey.CONTEXT_SIZE: 4096,
+ ModelPropertyKey.CONTEXT_SIZE: 16385,
},
parameter_rules=[
ParameterRule(
@@ -68,11 +68,25 @@ LLM_BASE_MODELS = [
name='frequency_penalty',
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY],
),
- _get_max_tokens(default=512, min_val=1, max_val=4096)
+ _get_max_tokens(default=512, min_val=1, max_val=4096),
+ ParameterRule(
+ name='response_format',
+ label=I18nObject(
+ zh_Hans='回复格式',
+ en_US='response_format'
+ ),
+ type='string',
+ help=I18nObject(
+ zh_Hans='指定模型必须输出的格式',
+ en_US='specifying the format that the model must output'
+ ),
+ required=False,
+ options=['text', 'json_object']
+ ),
],
pricing=PriceConfig(
- input=0.001,
- output=0.002,
+ input=0.0005,
+ output=0.0015,
unit=0.001,
currency='USD',
)
@@ -482,6 +496,310 @@ LLM_BASE_MODELS = [
)
)
),
+ AzureBaseModel(
+ base_model_name='gpt-4o',
+ entity=AIModelEntity(
+ model='fake-deployment-name',
+ label=I18nObject(
+ en_US='fake-deployment-name-label',
+ ),
+ model_type=ModelType.LLM,
+ features=[
+ ModelFeature.AGENT_THOUGHT,
+ ModelFeature.VISION,
+ ModelFeature.MULTI_TOOL_CALL,
+ ModelFeature.STREAM_TOOL_CALL,
+ ],
+ fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+ model_properties={
+ ModelPropertyKey.MODE: LLMMode.CHAT.value,
+ ModelPropertyKey.CONTEXT_SIZE: 128000,
+ },
+ parameter_rules=[
+ ParameterRule(
+ name='temperature',
+ **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE],
+ ),
+ ParameterRule(
+ name='top_p',
+ **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P],
+ ),
+ ParameterRule(
+ name='presence_penalty',
+ **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY],
+ ),
+ ParameterRule(
+ name='frequency_penalty',
+ **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY],
+ ),
+ _get_max_tokens(default=512, min_val=1, max_val=4096),
+ ParameterRule(
+ name='seed',
+ label=I18nObject(
+ zh_Hans='种子',
+ en_US='Seed'
+ ),
+ type='int',
+ help=I18nObject(
+ zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。',
+ en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.'
+ ),
+ required=False,
+ precision=2,
+ min=0,
+ max=1,
+ ),
+ ParameterRule(
+ name='response_format',
+ label=I18nObject(
+ zh_Hans='回复格式',
+ en_US='response_format'
+ ),
+ type='string',
+ help=I18nObject(
+ zh_Hans='指定模型必须输出的格式',
+ en_US='specifying the format that the model must output'
+ ),
+ required=False,
+ options=['text', 'json_object']
+ ),
+ ],
+ pricing=PriceConfig(
+ input=5.00,
+ output=15.00,
+ unit=0.000001,
+ currency='USD',
+ )
+ )
+ ),
+ AzureBaseModel(
+ base_model_name='gpt-4o-2024-05-13',
+ entity=AIModelEntity(
+ model='fake-deployment-name',
+ label=I18nObject(
+ en_US='fake-deployment-name-label',
+ ),
+ model_type=ModelType.LLM,
+ features=[
+ ModelFeature.AGENT_THOUGHT,
+ ModelFeature.VISION,
+ ModelFeature.MULTI_TOOL_CALL,
+ ModelFeature.STREAM_TOOL_CALL,
+ ],
+ fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+ model_properties={
+ ModelPropertyKey.MODE: LLMMode.CHAT.value,
+ ModelPropertyKey.CONTEXT_SIZE: 128000,
+ },
+ parameter_rules=[
+ ParameterRule(
+ name='temperature',
+ **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE],
+ ),
+ ParameterRule(
+ name='top_p',
+ **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P],
+ ),
+ ParameterRule(
+ name='presence_penalty',
+ **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY],
+ ),
+ ParameterRule(
+ name='frequency_penalty',
+ **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY],
+ ),
+ _get_max_tokens(default=512, min_val=1, max_val=4096),
+ ParameterRule(
+ name='seed',
+ label=I18nObject(
+ zh_Hans='种子',
+ en_US='Seed'
+ ),
+ type='int',
+ help=I18nObject(
+ zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。',
+ en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.'
+ ),
+ required=False,
+ precision=2,
+ min=0,
+ max=1,
+ ),
+ ParameterRule(
+ name='response_format',
+ label=I18nObject(
+ zh_Hans='回复格式',
+ en_US='response_format'
+ ),
+ type='string',
+ help=I18nObject(
+ zh_Hans='指定模型必须输出的格式',
+ en_US='specifying the format that the model must output'
+ ),
+ required=False,
+ options=['text', 'json_object']
+ ),
+ ],
+ pricing=PriceConfig(
+ input=5.00,
+ output=15.00,
+ unit=0.000001,
+ currency='USD',
+ )
+ )
+ ),
+ AzureBaseModel(
+ base_model_name='gpt-4-turbo',
+ entity=AIModelEntity(
+ model='fake-deployment-name',
+ label=I18nObject(
+ en_US='fake-deployment-name-label',
+ ),
+ model_type=ModelType.LLM,
+ features=[
+ ModelFeature.AGENT_THOUGHT,
+ ModelFeature.VISION,
+ ModelFeature.MULTI_TOOL_CALL,
+ ModelFeature.STREAM_TOOL_CALL,
+ ],
+ fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+ model_properties={
+ ModelPropertyKey.MODE: LLMMode.CHAT.value,
+ ModelPropertyKey.CONTEXT_SIZE: 128000,
+ },
+ parameter_rules=[
+ ParameterRule(
+ name='temperature',
+ **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE],
+ ),
+ ParameterRule(
+ name='top_p',
+ **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P],
+ ),
+ ParameterRule(
+ name='presence_penalty',
+ **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY],
+ ),
+ ParameterRule(
+ name='frequency_penalty',
+ **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY],
+ ),
+ _get_max_tokens(default=512, min_val=1, max_val=4096),
+ ParameterRule(
+ name='seed',
+ label=I18nObject(
+ zh_Hans='种子',
+ en_US='Seed'
+ ),
+ type='int',
+ help=I18nObject(
+ zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。',
+ en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.'
+ ),
+ required=False,
+ precision=2,
+ min=0,
+ max=1,
+ ),
+ ParameterRule(
+ name='response_format',
+ label=I18nObject(
+ zh_Hans='回复格式',
+ en_US='response_format'
+ ),
+ type='string',
+ help=I18nObject(
+ zh_Hans='指定模型必须输出的格式',
+ en_US='specifying the format that the model must output'
+ ),
+ required=False,
+ options=['text', 'json_object']
+ ),
+ ],
+ pricing=PriceConfig(
+ input=0.01,
+ output=0.03,
+ unit=0.001,
+ currency='USD',
+ )
+ )
+ ),
+ AzureBaseModel(
+ base_model_name='gpt-4-turbo-2024-04-09',
+ entity=AIModelEntity(
+ model='fake-deployment-name',
+ label=I18nObject(
+ en_US='fake-deployment-name-label',
+ ),
+ model_type=ModelType.LLM,
+ features=[
+ ModelFeature.AGENT_THOUGHT,
+ ModelFeature.VISION,
+ ModelFeature.MULTI_TOOL_CALL,
+ ModelFeature.STREAM_TOOL_CALL,
+ ],
+ fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+ model_properties={
+ ModelPropertyKey.MODE: LLMMode.CHAT.value,
+ ModelPropertyKey.CONTEXT_SIZE: 128000,
+ },
+ parameter_rules=[
+ ParameterRule(
+ name='temperature',
+ **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE],
+ ),
+ ParameterRule(
+ name='top_p',
+ **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P],
+ ),
+ ParameterRule(
+ name='presence_penalty',
+ **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY],
+ ),
+ ParameterRule(
+ name='frequency_penalty',
+ **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY],
+ ),
+ _get_max_tokens(default=512, min_val=1, max_val=4096),
+ ParameterRule(
+ name='seed',
+ label=I18nObject(
+ zh_Hans='种子',
+ en_US='Seed'
+ ),
+ type='int',
+ help=I18nObject(
+ zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。',
+ en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.'
+ ),
+ required=False,
+ precision=2,
+ min=0,
+ max=1,
+ ),
+ ParameterRule(
+ name='response_format',
+ label=I18nObject(
+ zh_Hans='回复格式',
+ en_US='response_format'
+ ),
+ type='string',
+ help=I18nObject(
+ zh_Hans='指定模型必须输出的格式',
+ en_US='specifying the format that the model must output'
+ ),
+ required=False,
+ options=['text', 'json_object']
+ ),
+ ],
+ pricing=PriceConfig(
+ input=0.01,
+ output=0.03,
+ unit=0.001,
+ currency='USD',
+ )
+ )
+ ),
AzureBaseModel(
base_model_name='gpt-4-vision-preview',
entity=AIModelEntity(
diff --git a/api/core/model_runtime/model_providers/azure_openai/azure_openai.yaml b/api/core/model_runtime/model_providers/azure_openai/azure_openai.yaml
index 828698acc7..b9f33a8ff2 100644
--- a/api/core/model_runtime/model_providers/azure_openai/azure_openai.yaml
+++ b/api/core/model_runtime/model_providers/azure_openai/azure_openai.yaml
@@ -59,6 +59,9 @@ model_credential_schema:
- label:
en_US: 2023-12-01-preview
value: 2023-12-01-preview
+ - label:
+ en_US: '2024-02-01'
+ value: '2024-02-01'
placeholder:
zh_Hans: 在此选择您的 API 版本
en_US: Select your API Version here
@@ -99,6 +102,24 @@ model_credential_schema:
show_on:
- variable: __model_type
value: llm
+ - label:
+ en_US: gpt-4o
+ value: gpt-4o
+ show_on:
+ - variable: __model_type
+ value: llm
+ - label:
+ en_US: gpt-4o-2024-05-13
+ value: gpt-4o-2024-05-13
+ show_on:
+ - variable: __model_type
+ value: llm
+ - label:
+ en_US: gpt-4-turbo
+ value: gpt-4-turbo
+ show_on:
+ - variable: __model_type
+ value: llm
- label:
en_US: gpt-4-turbo-2024-04-09
value: gpt-4-turbo-2024-04-09
diff --git a/api/core/model_runtime/model_providers/baichuan/llm/baichuan2-53b.yaml b/api/core/model_runtime/model_providers/baichuan/llm/baichuan2-53b.yaml
index e7cf059ea5..04849500dc 100644
--- a/api/core/model_runtime/model_providers/baichuan/llm/baichuan2-53b.yaml
+++ b/api/core/model_runtime/model_providers/baichuan/llm/baichuan2-53b.yaml
@@ -6,7 +6,7 @@ features:
- agent-thought
model_properties:
mode: chat
- context_size: 4000
+ context_size: 32000
parameter_rules:
- name: temperature
use_template: temperature
diff --git a/api/core/model_runtime/model_providers/baichuan/llm/baichuan2-turbo.yaml b/api/core/model_runtime/model_providers/baichuan/llm/baichuan2-turbo.yaml
index fb5d73b068..f91329c77a 100644
--- a/api/core/model_runtime/model_providers/baichuan/llm/baichuan2-turbo.yaml
+++ b/api/core/model_runtime/model_providers/baichuan/llm/baichuan2-turbo.yaml
@@ -6,7 +6,7 @@ features:
- agent-thought
model_properties:
mode: chat
- context_size: 192000
+ context_size: 32000
parameter_rules:
- name: temperature
use_template: temperature
diff --git a/api/core/model_runtime/model_providers/baichuan/llm/baichuan3-turbo-128k.yaml b/api/core/model_runtime/model_providers/baichuan/llm/baichuan3-turbo-128k.yaml
new file mode 100644
index 0000000000..bf72e82296
--- /dev/null
+++ b/api/core/model_runtime/model_providers/baichuan/llm/baichuan3-turbo-128k.yaml
@@ -0,0 +1,45 @@
+model: baichuan3-turbo-128k
+label:
+ en_US: Baichuan3-Turbo-128k
+model_type: llm
+features:
+ - agent-thought
+model_properties:
+ mode: chat
+ context_size: 128000
+parameter_rules:
+ - name: temperature
+ use_template: temperature
+ - name: top_p
+ use_template: top_p
+ - name: top_k
+ label:
+ zh_Hans: 取样数量
+ en_US: Top k
+ type: int
+ help:
+ zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+ en_US: Only sample from the top K options for each subsequent token.
+ required: false
+ - name: max_tokens
+ use_template: max_tokens
+ required: true
+ default: 8000
+ min: 1
+ max: 128000
+ - name: presence_penalty
+ use_template: presence_penalty
+ - name: frequency_penalty
+ use_template: frequency_penalty
+ default: 1
+ min: 1
+ max: 2
+ - name: with_search_enhance
+ label:
+ zh_Hans: 搜索增强
+ en_US: Search Enhance
+ type: boolean
+ help:
+ zh_Hans: 允许模型自行进行外部搜索,以增强生成结果。
+ en_US: Allow the model to perform external search to enhance the generation results.
+ required: false
diff --git a/api/core/model_runtime/model_providers/baichuan/llm/baichuan3-turbo.yaml b/api/core/model_runtime/model_providers/baichuan/llm/baichuan3-turbo.yaml
new file mode 100644
index 0000000000..85882519b8
--- /dev/null
+++ b/api/core/model_runtime/model_providers/baichuan/llm/baichuan3-turbo.yaml
@@ -0,0 +1,45 @@
+model: baichuan3-turbo
+label:
+ en_US: Baichuan3-Turbo
+model_type: llm
+features:
+ - agent-thought
+model_properties:
+ mode: chat
+ context_size: 32000
+parameter_rules:
+ - name: temperature
+ use_template: temperature
+ - name: top_p
+ use_template: top_p
+ - name: top_k
+ label:
+ zh_Hans: 取样数量
+ en_US: Top k
+ type: int
+ help:
+ zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+ en_US: Only sample from the top K options for each subsequent token.
+ required: false
+ - name: max_tokens
+ use_template: max_tokens
+ required: true
+ default: 8000
+ min: 1
+ max: 32000
+ - name: presence_penalty
+ use_template: presence_penalty
+ - name: frequency_penalty
+ use_template: frequency_penalty
+ default: 1
+ min: 1
+ max: 2
+ - name: with_search_enhance
+ label:
+ zh_Hans: 搜索增强
+ en_US: Search Enhance
+ type: boolean
+ help:
+ zh_Hans: 允许模型自行进行外部搜索,以增强生成结果。
+ en_US: Allow the model to perform external search to enhance the generation results.
+ required: false
diff --git a/api/core/model_runtime/model_providers/baichuan/llm/baichuan4.yaml b/api/core/model_runtime/model_providers/baichuan/llm/baichuan4.yaml
new file mode 100644
index 0000000000..f8c6566081
--- /dev/null
+++ b/api/core/model_runtime/model_providers/baichuan/llm/baichuan4.yaml
@@ -0,0 +1,45 @@
+model: baichuan4
+label:
+ en_US: Baichuan4
+model_type: llm
+features:
+ - agent-thought
+model_properties:
+ mode: chat
+ context_size: 32000
+parameter_rules:
+ - name: temperature
+ use_template: temperature
+ - name: top_p
+ use_template: top_p
+ - name: top_k
+ label:
+ zh_Hans: 取样数量
+ en_US: Top k
+ type: int
+ help:
+ zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+ en_US: Only sample from the top K options for each subsequent token.
+ required: false
+ - name: max_tokens
+ use_template: max_tokens
+ required: true
+ default: 8000
+ min: 1
+ max: 32000
+ - name: presence_penalty
+ use_template: presence_penalty
+ - name: frequency_penalty
+ use_template: frequency_penalty
+ default: 1
+ min: 1
+ max: 2
+ - name: with_search_enhance
+ label:
+ zh_Hans: 搜索增强
+ en_US: Search Enhance
+ type: boolean
+ help:
+ zh_Hans: 允许模型自行进行外部搜索,以增强生成结果。
+ en_US: Allow the model to perform external search to enhance the generation results.
+ required: false
diff --git a/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo.py b/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo.py
index 639f6a21ce..d7d8b7c91b 100644
--- a/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo.py
+++ b/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo.py
@@ -51,26 +51,29 @@ class BaichuanModel:
'baichuan2-turbo': 'Baichuan2-Turbo',
'baichuan2-turbo-192k': 'Baichuan2-Turbo-192k',
'baichuan2-53b': 'Baichuan2-53B',
+ 'baichuan3-turbo': 'Baichuan3-Turbo',
+ 'baichuan3-turbo-128k': 'Baichuan3-Turbo-128k',
+ 'baichuan4': 'Baichuan4',
}[model]
def _handle_chat_generate_response(self, response) -> BaichuanMessage:
- resp = response.json()
- choices = resp.get('choices', [])
- message = BaichuanMessage(content='', role='assistant')
- for choice in choices:
- message.content += choice['message']['content']
- message.role = choice['message']['role']
- if choice['finish_reason']:
- message.stop_reason = choice['finish_reason']
+ resp = response.json()
+ choices = resp.get('choices', [])
+ message = BaichuanMessage(content='', role='assistant')
+ for choice in choices:
+ message.content += choice['message']['content']
+ message.role = choice['message']['role']
+ if choice['finish_reason']:
+ message.stop_reason = choice['finish_reason']
- if 'usage' in resp:
- message.usage = {
- 'prompt_tokens': resp['usage']['prompt_tokens'],
- 'completion_tokens': resp['usage']['completion_tokens'],
- 'total_tokens': resp['usage']['total_tokens'],
- }
-
- return message
+ if 'usage' in resp:
+ message.usage = {
+ 'prompt_tokens': resp['usage']['prompt_tokens'],
+ 'completion_tokens': resp['usage']['completion_tokens'],
+ 'total_tokens': resp['usage']['total_tokens'],
+ }
+
+ return message
def _handle_chat_stream_generate_response(self, response) -> Generator:
for line in response.iter_lines():
@@ -89,7 +92,7 @@ class BaichuanModel:
# save stop reason temporarily
stop_reason = ''
for choice in choices:
- if 'finish_reason' in choice and choice['finish_reason']:
+ if choice.get('finish_reason'):
stop_reason = choice['finish_reason']
if len(choice['delta']['content']) == 0:
@@ -110,7 +113,8 @@ class BaichuanModel:
def _build_parameters(self, model: str, stream: bool, messages: list[BaichuanMessage],
parameters: dict[str, Any]) \
-> dict[str, Any]:
- if model == 'baichuan2-turbo' or model == 'baichuan2-turbo-192k' or model == 'baichuan2-53b':
+ if (model == 'baichuan2-turbo' or model == 'baichuan2-turbo-192k' or model == 'baichuan2-53b'
+ or model == 'baichuan3-turbo' or model == 'baichuan3-turbo-128k' or model == 'baichuan4'):
prompt_messages = []
for message in messages:
if message.role == BaichuanMessage.Role.USER.value or message.role == BaichuanMessage.Role._SYSTEM.value:
@@ -143,7 +147,8 @@ class BaichuanModel:
raise BadRequestError(f"Unknown model: {model}")
def _build_headers(self, model: str, data: dict[str, Any]) -> dict[str, Any]:
- if model == 'baichuan2-turbo' or model == 'baichuan2-turbo-192k' or model == 'baichuan2-53b':
+ if (model == 'baichuan2-turbo' or model == 'baichuan2-turbo-192k' or model == 'baichuan2-53b'
+ or model == 'baichuan3-turbo' or model == 'baichuan3-turbo-128k' or model == 'baichuan4'):
# there is no secret key for turbo api
return {
'Content-Type': 'application/json',
@@ -160,7 +165,8 @@ class BaichuanModel:
parameters: dict[str, Any], timeout: int) \
-> Union[Generator, BaichuanMessage]:
- if model == 'baichuan2-turbo' or model == 'baichuan2-turbo-192k' or model == 'baichuan2-53b':
+ if (model == 'baichuan2-turbo' or model == 'baichuan2-turbo-192k' or model == 'baichuan2-53b'
+ or model == 'baichuan3-turbo' or model == 'baichuan3-turbo-128k' or model == 'baichuan4'):
api_base = 'https://api.baichuan-ai.com/v1/chat/completions'
else:
raise BadRequestError(f"Unknown model: {model}")
diff --git a/api/core/model_runtime/model_providers/bedrock/llm/_position.yaml b/api/core/model_runtime/model_providers/bedrock/llm/_position.yaml
index 7f4d2035cc..732d1c2a93 100644
--- a/api/core/model_runtime/model_providers/bedrock/llm/_position.yaml
+++ b/api/core/model_runtime/model_providers/bedrock/llm/_position.yaml
@@ -12,6 +12,7 @@
- meta.llama3-70b-instruct-v1:0
- meta.llama2-13b-chat-v1
- meta.llama2-70b-chat-v1
+- mistral.mistral-small-2402-v1:0
- mistral.mistral-large-2402-v1:0
- mistral.mixtral-8x7b-instruct-v0:1
- mistral.mistral-7b-instruct-v0:2
diff --git a/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-haiku-v1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-haiku-v1.yaml
index 73fe5567fc..181b192769 100644
--- a/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-haiku-v1.yaml
+++ b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-haiku-v1.yaml
@@ -51,7 +51,7 @@ parameter_rules:
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
pricing:
- input: '0.003'
- output: '0.015'
+ input: '0.00025'
+ output: '0.00125'
unit: '0.001'
currency: USD
diff --git a/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-sonnet-v1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-sonnet-v1.yaml
index cb11df0b60..b782faddba 100644
--- a/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-sonnet-v1.yaml
+++ b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-sonnet-v1.yaml
@@ -50,7 +50,7 @@ parameter_rules:
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
pricing:
- input: '0.00025'
- output: '0.00125'
+ input: '0.003'
+ output: '0.015'
unit: '0.001'
currency: USD
diff --git a/api/core/model_runtime/model_providers/bedrock/llm/llm.py b/api/core/model_runtime/model_providers/bedrock/llm/llm.py
index 81a9ce2f00..1386d680a4 100644
--- a/api/core/model_runtime/model_providers/bedrock/llm/llm.py
+++ b/api/core/model_runtime/model_providers/bedrock/llm/llm.py
@@ -358,26 +358,25 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
return message_dict
- def get_num_tokens(self, model: str, credentials: dict, messages: list[PromptMessage] | str,
+ def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage] | str,
tools: Optional[list[PromptMessageTool]] = None) -> int:
"""
Get number of tokens for given prompt messages
:param model: model name
:param credentials: model credentials
- :param messages: prompt messages or message string
+ :param prompt_messages: prompt messages or message string
:param tools: tools for tool calling
:return: number of tokens
"""
prefix = model.split('.')[0]
model_name = model.split('.')[1]
- if isinstance(messages, str):
- prompt = messages
+ if isinstance(prompt_messages, str):
+ prompt = prompt_messages
else:
- prompt = self._convert_messages_to_prompt(messages, prefix, model_name)
+ prompt = self._convert_messages_to_prompt(prompt_messages, prefix, model_name)
return self._get_num_tokens_by_gpt2(prompt)
-
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
diff --git a/api/core/model_runtime/model_providers/bedrock/llm/mistral.mistral-small-2402-v1.0.yaml b/api/core/model_runtime/model_providers/bedrock/llm/mistral.mistral-small-2402-v1.0.yaml
new file mode 100644
index 0000000000..582f4a6d9f
--- /dev/null
+++ b/api/core/model_runtime/model_providers/bedrock/llm/mistral.mistral-small-2402-v1.0.yaml
@@ -0,0 +1,27 @@
+model: mistral.mistral-small-2402-v1:0
+label:
+ en_US: Mistral Small
+model_type: llm
+model_properties:
+ mode: completion
+ context_size: 32000
+parameter_rules:
+ - name: temperature
+ use_template: temperature
+ required: false
+ default: 0.7
+ - name: top_p
+ use_template: top_p
+ required: false
+ default: 1
+ - name: max_tokens
+ use_template: max_tokens
+ required: true
+ default: 512
+ min: 1
+ max: 4096
+pricing:
+ input: '0.001'
+ output: '0.003'
+ unit: '0.001'
+ currency: USD
diff --git a/api/core/model_runtime/model_providers/bedrock/text_embedding/_position.yaml b/api/core/model_runtime/model_providers/bedrock/text_embedding/_position.yaml
index 5419ff530b..afbea06a3e 100644
--- a/api/core/model_runtime/model_providers/bedrock/text_embedding/_position.yaml
+++ b/api/core/model_runtime/model_providers/bedrock/text_embedding/_position.yaml
@@ -1,3 +1,4 @@
- amazon.titan-embed-text-v1
+- amazon.titan-embed-text-v2:0
- cohere.embed-english-v3
- cohere.embed-multilingual-v3
diff --git a/api/core/model_runtime/model_providers/bedrock/text_embedding/amazon.titan-embed-text-v1.yaml b/api/core/model_runtime/model_providers/bedrock/text_embedding/amazon.titan-embed-text-v1.yaml
index 6a1cf75be1..e5a55971a1 100644
--- a/api/core/model_runtime/model_providers/bedrock/text_embedding/amazon.titan-embed-text-v1.yaml
+++ b/api/core/model_runtime/model_providers/bedrock/text_embedding/amazon.titan-embed-text-v1.yaml
@@ -4,5 +4,5 @@ model_properties:
context_size: 8192
pricing:
input: '0.0001'
- unit: '0.001'
+ unit: '0.0001'
currency: USD
diff --git a/api/core/model_runtime/model_providers/bedrock/text_embedding/amazon.titan-embed-text-v2.yaml b/api/core/model_runtime/model_providers/bedrock/text_embedding/amazon.titan-embed-text-v2.yaml
new file mode 100644
index 0000000000..5069efeb10
--- /dev/null
+++ b/api/core/model_runtime/model_providers/bedrock/text_embedding/amazon.titan-embed-text-v2.yaml
@@ -0,0 +1,8 @@
+model: amazon.titan-embed-text-v2:0
+model_type: text-embedding
+model_properties:
+ context_size: 8192
+pricing:
+ input: '0.00002'
+ unit: '0.00001'
+ currency: USD
diff --git a/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py
index 69436cd737..84b23d4a27 100644
--- a/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py
+++ b/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py
@@ -59,15 +59,15 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel):
model_prefix = model.split('.')[0]
if model_prefix == "amazon" :
- for text in texts:
- body = {
+ for text in texts:
+ body = {
"inputText": text,
- }
- response_body = self._invoke_bedrock_embedding(model, bedrock_runtime, body)
- embeddings.extend([response_body.get('embedding')])
- token_usage += response_body.get('inputTextTokenCount')
- logger.warning(f'Total Tokens: {token_usage}')
- result = TextEmbeddingResult(
+ }
+ response_body = self._invoke_bedrock_embedding(model, bedrock_runtime, body)
+ embeddings.extend([response_body.get('embedding')])
+ token_usage += response_body.get('inputTextTokenCount')
+ logger.warning(f'Total Tokens: {token_usage}')
+ result = TextEmbeddingResult(
model=model,
embeddings=embeddings,
usage=self._calc_response_usage(
@@ -75,20 +75,20 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel):
credentials=credentials,
tokens=token_usage
)
- )
- return result
-
+ )
+ return result
+
if model_prefix == "cohere" :
- input_type = 'search_document' if len(texts) > 1 else 'search_query'
- for text in texts:
- body = {
+ input_type = 'search_document' if len(texts) > 1 else 'search_query'
+ for text in texts:
+ body = {
"texts": [text],
"input_type": input_type,
- }
- response_body = self._invoke_bedrock_embedding(model, bedrock_runtime, body)
- embeddings.extend(response_body.get('embeddings'))
- token_usage += len(text)
- result = TextEmbeddingResult(
+ }
+ response_body = self._invoke_bedrock_embedding(model, bedrock_runtime, body)
+ embeddings.extend(response_body.get('embeddings'))
+ token_usage += len(text)
+ result = TextEmbeddingResult(
model=model,
embeddings=embeddings,
usage=self._calc_response_usage(
@@ -96,9 +96,9 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel):
credentials=credentials,
tokens=token_usage
)
- )
- return result
-
+ )
+ return result
+
#others
raise ValueError(f"Got unknown model prefix {model_prefix} when handling block response")
@@ -183,7 +183,7 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel):
)
return usage
-
+
def _map_client_to_invoke_error(self, error_code: str, error_msg: str) -> type[InvokeError]:
"""
Map client error to invoke error
@@ -212,9 +212,9 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel):
content_type = 'application/json'
try:
response = bedrock_runtime.invoke_model(
- body=json.dumps(body),
- modelId=model,
- accept=accept,
+ body=json.dumps(body),
+ modelId=model,
+ accept=accept,
contentType=content_type
)
response_body = json.loads(response.get('body').read().decode('utf-8'))
diff --git a/api/core/model_runtime/model_providers/chatglm/llm/llm.py b/api/core/model_runtime/model_providers/chatglm/llm/llm.py
index 12dc75aece..e83d08af71 100644
--- a/api/core/model_runtime/model_providers/chatglm/llm/llm.py
+++ b/api/core/model_runtime/model_providers/chatglm/llm/llm.py
@@ -1,6 +1,5 @@
import logging
from collections.abc import Generator
-from os.path import join
from typing import Optional, cast
from httpx import Timeout
@@ -19,6 +18,7 @@ from openai import (
)
from openai.types.chat import ChatCompletion, ChatCompletionChunk
from openai.types.chat.chat_completion_message import FunctionCall
+from yarl import URL
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
@@ -265,7 +265,7 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
client_kwargs = {
"timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0),
"api_key": "1",
- "base_url": join(credentials['api_base'], 'v1')
+ "base_url": str(URL(credentials['api_base']) / 'v1')
}
return client_kwargs
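A quick illustration of why the patch swaps `os.path.join` for `yarl.URL`: `os.path.join` is a filesystem-path helper, so its separator is platform-dependent and can corrupt URLs.

```python
from os.path import join
from yarl import URL

api_base = 'http://192.168.1.100:8080'
# os.path.join uses the OS path separator: on Windows it inserts a
# backslash, producing an invalid URL.
print(join(api_base, 'v1'))        # 'http://192.168.1.100:8080\\v1' on Windows
print(str(URL(api_base) / 'v1'))   # 'http://192.168.1.100:8080/v1' everywhere
```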
diff --git a/api/core/model_runtime/model_providers/cohere/cohere.yaml b/api/core/model_runtime/model_providers/cohere/cohere.yaml
index c889a6bfe0..bd40057fe9 100644
--- a/api/core/model_runtime/model_providers/cohere/cohere.yaml
+++ b/api/core/model_runtime/model_providers/cohere/cohere.yaml
@@ -32,6 +32,15 @@ provider_credential_schema:
zh_Hans: 在此输入您的 API Key
en_US: Enter your API Key
show_on: [ ]
+ - variable: base_url
+ label:
+ zh_Hans: API Base
+ en_US: API Base
+ type: text-input
+ required: false
+ placeholder:
+ zh_Hans: 在此输入您的 API Base,如 https://api.cohere.ai/v1
+ en_US: Enter your API Base, e.g. https://api.cohere.ai/v1
model_credential_schema:
model:
label:
@@ -70,3 +79,12 @@ model_credential_schema:
placeholder:
zh_Hans: 在此输入您的 API Key
en_US: Enter your API Key
+ - variable: base_url
+ label:
+ zh_Hans: API Base
+ en_US: API Base
+ type: text-input
+ required: false
+ placeholder:
+ zh_Hans: 在此输入您的 API Base,如 https://api.cohere.ai/v1
+ en_US: Enter your API Base, e.g. https://api.cohere.ai/v1
diff --git a/api/core/model_runtime/model_providers/cohere/llm/llm.py b/api/core/model_runtime/model_providers/cohere/llm/llm.py
index 6ace77b813..f9fae5e8ca 100644
--- a/api/core/model_runtime/model_providers/cohere/llm/llm.py
+++ b/api/core/model_runtime/model_providers/cohere/llm/llm.py
@@ -173,7 +173,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
:return: full response or stream response chunk generator result
"""
# initialize client
- client = cohere.Client(credentials.get('api_key'))
+ client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url'))
if stop:
model_parameters['end_sequences'] = stop
@@ -233,7 +233,8 @@ class CohereLargeLanguageModel(LargeLanguageModel):
return response
- def _handle_generate_stream_response(self, model: str, credentials: dict, response: Iterator[GenerateStreamedResponse],
+ def _handle_generate_stream_response(self, model: str, credentials: dict,
+ response: Iterator[GenerateStreamedResponse],
prompt_messages: list[PromptMessage]) -> Generator:
"""
Handle llm stream response
@@ -317,7 +318,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
:return: full response or stream response chunk generator result
"""
# initialize client
- client = cohere.Client(credentials.get('api_key'))
+ client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url'))
if stop:
model_parameters['stop_sequences'] = stop
@@ -636,7 +637,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
:return: number of tokens
"""
# initialize client
- client = cohere.Client(credentials.get('api_key'))
+ client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url'))
response = client.tokenize(
text=text,
diff --git a/api/core/model_runtime/model_providers/cohere/rerank/rerank.py b/api/core/model_runtime/model_providers/cohere/rerank/rerank.py
index 4194f27eb9..d2fdb30c6f 100644
--- a/api/core/model_runtime/model_providers/cohere/rerank/rerank.py
+++ b/api/core/model_runtime/model_providers/cohere/rerank/rerank.py
@@ -44,7 +44,7 @@ class CohereRerankModel(RerankModel):
)
# initialize client
- client = cohere.Client(credentials.get('api_key'))
+ client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url'))
response = client.rerank(
query=query,
documents=docs,
diff --git a/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py
index 8269a41810..0540fb740f 100644
--- a/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py
+++ b/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py
@@ -141,7 +141,7 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
return []
# initialize client
- client = cohere.Client(credentials.get('api_key'))
+ client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url'))
response = client.tokenize(
text=text,
@@ -180,7 +180,7 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
:return: embeddings and used tokens
"""
# initialize client
- client = cohere.Client(credentials.get('api_key'))
+ client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url'))
# call embedding model
response = client.embed(
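All four Cohere call sites now thread the optional `base_url` credential into the client, which makes proxy or gateway deployments possible. A minimal sketch, assuming the SDK falls back to its default endpoint when `base_url` is `None` (the proxy URL below is a placeholder, not a real service):

```python
import cohere

# base_url is optional; credentials.get('base_url') may be None.
client = cohere.Client('your-api-key', base_url='https://my-proxy.example.com/v1')
```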
diff --git a/api/core/model_runtime/model_providers/deepseek/__init__.py b/api/core/model_runtime/model_providers/deepseek/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/model_runtime/model_providers/deepseek/_assets/icon_l_en.svg b/api/core/model_runtime/model_providers/deepseek/_assets/icon_l_en.svg
new file mode 100644
index 0000000000..425494404f
--- /dev/null
+++ b/api/core/model_runtime/model_providers/deepseek/_assets/icon_l_en.svg
@@ -0,0 +1,22 @@
+
diff --git a/api/core/model_runtime/model_providers/deepseek/_assets/icon_s_en.svg b/api/core/model_runtime/model_providers/deepseek/_assets/icon_s_en.svg
new file mode 100644
index 0000000000..aa854a7504
--- /dev/null
+++ b/api/core/model_runtime/model_providers/deepseek/_assets/icon_s_en.svg
@@ -0,0 +1,3 @@
+
diff --git a/api/core/model_runtime/model_providers/deepseek/deepseek.py b/api/core/model_runtime/model_providers/deepseek/deepseek.py
new file mode 100644
index 0000000000..d61fd4ddc8
--- /dev/null
+++ b/api/core/model_runtime/model_providers/deepseek/deepseek.py
@@ -0,0 +1,33 @@
+import logging
+
+from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.model_provider import ModelProvider
+
+logger = logging.getLogger(__name__)
+
+
+
+class DeepSeekProvider(ModelProvider):
+
+ def validate_provider_credentials(self, credentials: dict) -> None:
+ """
+ Validate provider credentials
+ if validate failed, raise exception
+
+ :param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
+ """
+ try:
+ model_instance = self.get_model_instance(ModelType.LLM)
+
+ # Use the `deepseek-chat` model for validation,
+ # regardless of whether a text completion model or a chat model is passed in
+ model_instance.validate_credentials(
+ model='deepseek-chat',
+ credentials=credentials
+ )
+ except CredentialsValidateFailedError as ex:
+ raise ex
+ except Exception as ex:
+ logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
+ raise ex
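A hypothetical driver for the validator above (class and error names come from this patch; the API key is a placeholder):

```python
from core.model_runtime.errors.validate import CredentialsValidateFailedError

provider = DeepSeekProvider()
try:
    # Delegates to the LLM model instance and probes `deepseek-chat`.
    provider.validate_provider_credentials(credentials={'api_key': 'sk-placeholder'})
except CredentialsValidateFailedError as err:
    print(f'DeepSeek credentials rejected: {err}')
```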
diff --git a/api/core/model_runtime/model_providers/deepseek/deepseek.yaml b/api/core/model_runtime/model_providers/deepseek/deepseek.yaml
new file mode 100644
index 0000000000..16abd358d6
--- /dev/null
+++ b/api/core/model_runtime/model_providers/deepseek/deepseek.yaml
@@ -0,0 +1,41 @@
+provider: deepseek
+label:
+ en_US: deepseek
+ zh_Hans: 深度求索
+description:
+ en_US: Models provided by deepseek, such as deepseek-chat and deepseek-coder.
+ zh_Hans: 深度求索提供的模型,例如 deepseek-chat、deepseek-coder 。
+icon_small:
+ en_US: icon_s_en.svg
+icon_large:
+ en_US: icon_l_en.svg
+background: "#c0cdff"
+help:
+ title:
+ en_US: Get your API Key from deepseek
+ zh_Hans: 从深度求索获取 API Key
+ url:
+ en_US: https://platform.deepseek.com/api_keys
+supported_model_types:
+ - llm
+configurate_methods:
+ - predefined-model
+provider_credential_schema:
+ credential_form_schemas:
+ - variable: api_key
+ label:
+ en_US: API Key
+ type: secret-input
+ required: true
+ placeholder:
+ zh_Hans: 在此输入您的 API Key
+ en_US: Enter your API Key
+ - variable: endpoint_url
+ label:
+ zh_Hans: 自定义 API endpoint 地址
+ en_US: Custom API endpoint URL
+ type: text-input
+ required: false
+ placeholder:
+ zh_Hans: Base URL, e.g. https://api.deepseek.com/v1 or https://api.deepseek.com
+ en_US: Base URL, e.g. https://api.deepseek.com/v1 or https://api.deepseek.com
diff --git a/api/core/model_runtime/model_providers/deepseek/llm/__init__.py b/api/core/model_runtime/model_providers/deepseek/llm/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/model_runtime/model_providers/deepseek/llm/_position.yaml b/api/core/model_runtime/model_providers/deepseek/llm/_position.yaml
new file mode 100644
index 0000000000..43d03f2ee9
--- /dev/null
+++ b/api/core/model_runtime/model_providers/deepseek/llm/_position.yaml
@@ -0,0 +1,2 @@
+- deepseek-chat
+- deepseek-coder
diff --git a/api/core/model_runtime/model_providers/deepseek/llm/deepseek-chat.yaml b/api/core/model_runtime/model_providers/deepseek/llm/deepseek-chat.yaml
new file mode 100644
index 0000000000..3a5a63fa61
--- /dev/null
+++ b/api/core/model_runtime/model_providers/deepseek/llm/deepseek-chat.yaml
@@ -0,0 +1,64 @@
+model: deepseek-chat
+label:
+ zh_Hans: deepseek-chat
+ en_US: deepseek-chat
+model_type: llm
+features:
+ - agent-thought
+model_properties:
+ mode: chat
+ context_size: 32000
+parameter_rules:
+ - name: temperature
+ use_template: temperature
+ type: float
+ default: 1
+ min: 0.0
+ max: 2.0
+ help:
+ zh_Hans: 控制生成结果的多样性和随机性。数值越小,越严谨;数值越大,越发散。
+ en_US: Controls the diversity and randomness of generated results. Lower values make the output more focused and deterministic; higher values make it more divergent.
+ - name: max_tokens
+ use_template: max_tokens
+ type: int
+ default: 4096
+ min: 1
+ max: 32000
+ help:
+ zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。
+ en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter.
+ - name: top_p
+ use_template: top_p
+ type: float
+ default: 1
+ min: 0.01
+ max: 1.00
+ help:
+ zh_Hans: 控制生成结果的随机性。数值越小,随机性越弱;数值越大,随机性越强。一般而言,top_p 和 temperature 两个参数选择一个进行调整即可。
+ en_US: Controls the randomness of generated results. Lower values reduce randomness; higher values increase it. In general, adjust only one of top_p and temperature.
+ - name: logprobs
+ help:
+ zh_Hans: 是否返回所输出 token 的对数概率。如果为 true,则在 message 的 content 中返回每个输出 token 的对数概率。
+ en_US: Whether to return the log probabilities of the output tokens. If true, the log probability of each output token is returned in the content of message.
+ type: boolean
+ - name: top_logprobs
+ type: int
+ default: 0
+ min: 0
+ max: 20
+ help:
+ zh_Hans: 一个介于 0 到 20 之间的整数 N,指定每个输出位置返回输出概率 top N 的 token,且返回这些 token 的对数概率。指定此参数时,logprobs 必须为 true。
+ en_US: An integer N between 0 and 20 specifying that, at each output position, the N most likely tokens are returned along with their log probabilities. logprobs must be true when this parameter is specified.
+ - name: frequency_penalty
+ use_template: frequency_penalty
+ default: 0
+ min: -2.0
+ max: 2.0
+ help:
+ zh_Hans: 介于 -2.0 和 2.0 之间的数字。如果该值为正,那么新 token 会根据其在已有文本中的出现频率受到相应的惩罚,降低模型重复相同内容的可能性。
+ en_US: A number between -2.0 and 2.0. If the value is positive, new tokens are penalized based on their frequency of occurrence in existing text, reducing the likelihood that the model will repeat the same content.
+pricing:
+ input: '1'
+ output: '2'
+ unit: '0.000001'
+ currency: RMB
diff --git a/api/core/model_runtime/model_providers/deepseek/llm/deepseek-coder.yaml b/api/core/model_runtime/model_providers/deepseek/llm/deepseek-coder.yaml
new file mode 100644
index 0000000000..8f156be101
--- /dev/null
+++ b/api/core/model_runtime/model_providers/deepseek/llm/deepseek-coder.yaml
@@ -0,0 +1,26 @@
+model: deepseek-coder
+label:
+ zh_Hans: deepseek-coder
+ en_US: deepseek-coder
+model_type: llm
+features:
+ - agent-thought
+model_properties:
+ mode: chat
+ context_size: 16000
+parameter_rules:
+ - name: temperature
+ use_template: temperature
+ min: 0
+ max: 1
+ default: 0.5
+ - name: top_p
+ use_template: top_p
+ min: 0
+ max: 1
+ default: 1
+ - name: max_tokens
+ use_template: max_tokens
+ min: 1
+ max: 32000
+ default: 1024
diff --git a/api/core/model_runtime/model_providers/deepseek/llm/llm.py b/api/core/model_runtime/model_providers/deepseek/llm/llm.py
new file mode 100644
index 0000000000..bdb3823b60
--- /dev/null
+++ b/api/core/model_runtime/model_providers/deepseek/llm/llm.py
@@ -0,0 +1,113 @@
+from collections.abc import Generator
+from typing import Optional, Union
+from urllib.parse import urlparse
+
+import tiktoken
+
+from core.model_runtime.entities.llm_entities import LLMResult
+from core.model_runtime.entities.message_entities import (
+ PromptMessage,
+ PromptMessageTool,
+)
+from core.model_runtime.model_providers.openai.llm.llm import OpenAILargeLanguageModel
+
+
+class DeepSeekLargeLanguageModel(OpenAILargeLanguageModel):
+
+ def _invoke(self, model: str, credentials: dict,
+ prompt_messages: list[PromptMessage], model_parameters: dict,
+ tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
+ stream: bool = True, user: Optional[str] = None) \
+ -> Union[LLMResult, Generator]:
+ self._add_custom_parameters(credentials)
+
+ return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
+
+ def validate_credentials(self, model: str, credentials: dict) -> None:
+ self._add_custom_parameters(credentials)
+ super().validate_credentials(model, credentials)
+
+
+ # refactored from the OpenAI model runtime; uses cl100k_base to calculate the token count
+ def _num_tokens_from_string(self, model: str, text: str,
+ tools: Optional[list[PromptMessageTool]] = None) -> int:
+ """
+ Calculate num tokens for text completion model with tiktoken package.
+
+ :param model: model name
+ :param text: prompt text
+ :param tools: tools for tool calling
+ :return: number of tokens
+ """
+ encoding = tiktoken.get_encoding("cl100k_base")
+ num_tokens = len(encoding.encode(text))
+
+ if tools:
+ num_tokens += self._num_tokens_for_tools(encoding, tools)
+
+ return num_tokens
+
+ # refactored from the OpenAI model runtime; uses cl100k_base to calculate the token count
+ def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage],
+ tools: Optional[list[PromptMessageTool]] = None) -> int:
+ """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
+
+ Official documentation: https://github.com/openai/openai-cookbook/blob/
+ main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
+ encoding = tiktoken.get_encoding("cl100k_base")
+ tokens_per_message = 3
+ tokens_per_name = 1
+
+ num_tokens = 0
+ messages_dict = [self._convert_prompt_message_to_dict(m) for m in messages]
+ for message in messages_dict:
+ num_tokens += tokens_per_message
+ for key, value in message.items():
+ # Cast str(value) in case the message value is not a string
+ # This occurs with function messages
+ # TODO: The current token calculation method for the image type is not implemented,
+ # which would require downloading the image and reading its resolution,
+ # and would increase request latency
+ if isinstance(value, list):
+ text = ''
+ for item in value:
+ if isinstance(item, dict) and item['type'] == 'text':
+ text += item['text']
+
+ value = text
+
+ if key == "tool_calls":
+ for tool_call in value:
+ for t_key, t_value in tool_call.items():
+ num_tokens += len(encoding.encode(t_key))
+ if t_key == "function":
+ for f_key, f_value in t_value.items():
+ num_tokens += len(encoding.encode(f_key))
+ num_tokens += len(encoding.encode(f_value))
+ else:
+ num_tokens += len(encoding.encode(t_key))
+ num_tokens += len(encoding.encode(t_value))
+ else:
+ num_tokens += len(encoding.encode(str(value)))
+
+ if key == "name":
+ num_tokens += tokens_per_name
+
+ # every reply is primed with assistant
+ num_tokens += 3
+
+ if tools:
+ num_tokens += self._num_tokens_for_tools(encoding, tools)
+
+ return num_tokens
+
+ @staticmethod
+ def _add_custom_parameters(credentials: dict) -> None:
+ credentials['mode'] = 'chat'
+ credentials['openai_api_key'] = credentials['api_key']
+ if not credentials.get('endpoint_url'):
+ credentials['openai_api_base'] = 'https://api.deepseek.com'
+ else:
+ parsed_url = urlparse(credentials['endpoint_url'])
+ credentials['openai_api_base'] = f"{parsed_url.scheme}://{parsed_url.netloc}"
+
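The endpoint normalization above (hardened here to treat a missing, empty, or `None` endpoint_url uniformly) restated as a standalone sketch, so its behavior is easy to verify:

```python
from urllib.parse import urlparse

def normalize_api_base(endpoint_url: str | None) -> str:
    # Same normalization as _add_custom_parameters: default to the
    # official endpoint, otherwise keep only scheme and host.
    if not endpoint_url:
        return 'https://api.deepseek.com'
    parsed = urlparse(endpoint_url)
    return f'{parsed.scheme}://{parsed.netloc}'

assert normalize_api_base(None) == 'https://api.deepseek.com'
assert normalize_api_base('https://api.deepseek.com/v1') == 'https://api.deepseek.com'
```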
diff --git a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-latest.yaml b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-latest.yaml
new file mode 100644
index 0000000000..24b1c5af8a
--- /dev/null
+++ b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-latest.yaml
@@ -0,0 +1,39 @@
+model: gemini-1.5-flash-latest
+label:
+ en_US: Gemini 1.5 Flash
+model_type: llm
+features:
+ - agent-thought
+ - vision
+ - tool-call
+ - stream-tool-call
+model_properties:
+ mode: chat
+ context_size: 1048576
+parameter_rules:
+ - name: temperature
+ use_template: temperature
+ - name: top_p
+ use_template: top_p
+ - name: top_k
+ label:
+ zh_Hans: 取样数量
+ en_US: Top k
+ type: int
+ help:
+ zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+ en_US: Only sample from the top K options for each subsequent token.
+ required: false
+ - name: max_tokens_to_sample
+ use_template: max_tokens
+ required: true
+ default: 8192
+ min: 1
+ max: 8192
+ - name: response_format
+ use_template: response_format
+pricing:
+ input: '0.00'
+ output: '0.00'
+ unit: '0.000001'
+ currency: USD
diff --git a/api/core/model_runtime/model_providers/google/llm/llm.py b/api/core/model_runtime/model_providers/google/llm/llm.py
index 27912b13cc..5a674fdeee 100644
--- a/api/core/model_runtime/model_providers/google/llm/llm.py
+++ b/api/core/model_runtime/model_providers/google/llm/llm.py
@@ -204,6 +204,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
stream=stream,
safety_settings=safety_settings,
tools=self._convert_tools_to_glm_tool(tools) if tools else None,
+ request_options={"timeout": 600}
)
if stream:
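The one-line change above threads a 10-minute timeout through `generate_content`. Equivalent standalone usage, assuming the same google-generativeai SDK version (the API key is a placeholder):

```python
import google.generativeai as genai

genai.configure(api_key='your-api-key')  # placeholder key
model = genai.GenerativeModel('gemini-1.5-flash-latest')
# request_options is forwarded to the underlying transport, so long
# generations no longer hit the default client timeout.
response = model.generate_content('ping', request_options={'timeout': 600})
print(response.text)
```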
diff --git a/api/core/model_runtime/model_providers/jina/jina.yaml b/api/core/model_runtime/model_providers/jina/jina.yaml
index 935546234b..23e18ad75f 100644
--- a/api/core/model_runtime/model_providers/jina/jina.yaml
+++ b/api/core/model_runtime/model_providers/jina/jina.yaml
@@ -19,6 +19,7 @@ supported_model_types:
- rerank
configurate_methods:
- predefined-model
+ - customizable-model
provider_credential_schema:
credential_form_schemas:
- variable: api_key
@@ -29,3 +30,40 @@ provider_credential_schema:
placeholder:
zh_Hans: 在此输入您的 API Key
en_US: Enter your API Key
+model_credential_schema:
+ model:
+ label:
+ en_US: Model Name
+ zh_Hans: 模型名称
+ placeholder:
+ en_US: Enter your model name
+ zh_Hans: 输入模型名称
+ credential_form_schemas:
+ - variable: api_key
+ label:
+ en_US: API Key
+ type: secret-input
+ required: true
+ placeholder:
+ zh_Hans: 在此输入您的 API Key
+ en_US: Enter your API Key
+ - variable: base_url
+ label:
+ zh_Hans: 服务器 URL
+ en_US: Base URL
+ type: text-input
+ required: true
+ placeholder:
+ zh_Hans: Base URL, e.g. https://api.jina.ai/v1
+ en_US: Base URL, e.g. https://api.jina.ai/v1
+ default: 'https://api.jina.ai/v1'
+ - variable: context_size
+ label:
+ zh_Hans: 上下文大小
+ en_US: Context size
+ placeholder:
+ zh_Hans: 输入上下文大小
+ en_US: Enter context size
+ required: false
+ type: text-input
+ default: '8192'
diff --git a/api/core/model_runtime/model_providers/jina/rerank/rerank.py b/api/core/model_runtime/model_providers/jina/rerank/rerank.py
index f644ea6512..de7e038b9f 100644
--- a/api/core/model_runtime/model_providers/jina/rerank/rerank.py
+++ b/api/core/model_runtime/model_providers/jina/rerank/rerank.py
@@ -2,6 +2,8 @@ from typing import Optional
import httpx
+from core.model_runtime.entities.common_entities import I18nObject
+from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType
from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
@@ -38,9 +40,13 @@ class JinaRerankModel(RerankModel):
if len(docs) == 0:
return RerankResult(model=model, docs=[])
+ base_url = credentials.get('base_url', 'https://api.jina.ai/v1')
+ if base_url.endswith('/'):
+ base_url = base_url[:-1]
+
try:
response = httpx.post(
- "https://api.jina.ai/v1/rerank",
+ base_url + '/rerank',
json={
"model": model,
"query": query,
@@ -103,3 +109,19 @@ class JinaRerankModel(RerankModel):
InvokeAuthorizationError: [httpx.HTTPStatusError],
InvokeBadRequestError: [httpx.RequestError]
}
+
+ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
+ """
+ generate custom model entities from credentials
+ """
+ entity = AIModelEntity(
+ model=model,
+ label=I18nObject(en_US=model),
+ model_type=ModelType.RERANK,
+ fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+ model_properties={
+ ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', '8192'))
+ }
+ )
+
+ return entity
\ No newline at end of file
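Both Jina classes (rerank here, embeddings below) repeat the same default-and-strip-trailing-slash handling before appending their endpoint path. A shared helper sketch, not part of the patch:

```python
def normalize_jina_base_url(credentials: dict) -> str:
    # Shared pattern from both Jina patches: fall back to the public
    # endpoint, then strip a single trailing slash before appending
    # '/rerank' or '/embeddings'.
    base_url = credentials.get('base_url', 'https://api.jina.ai/v1')
    if base_url.endswith('/'):
        base_url = base_url[:-1]
    return base_url

assert normalize_jina_base_url({}) == 'https://api.jina.ai/v1'
assert normalize_jina_base_url({'base_url': 'https://api.jina.ai/v1/'}) == 'https://api.jina.ai/v1'
```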
diff --git a/api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py
index da922232c0..74a1aabf7a 100644
--- a/api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py
+++ b/api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py
@@ -4,7 +4,8 @@ from typing import Optional
from requests import post
-from core.model_runtime.entities.model_entities import PriceType
+from core.model_runtime.entities.common_entities import I18nObject
+from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
@@ -23,8 +24,7 @@ class JinaTextEmbeddingModel(TextEmbeddingModel):
"""
Model class for Jina text embedding model.
"""
- api_base: str = 'https://api.jina.ai/v1/embeddings'
- models: list[str] = ['jina-embeddings-v2-base-en', 'jina-embeddings-v2-small-en', 'jina-embeddings-v2-base-zh', 'jina-embeddings-v2-base-de']
+ api_base: str = 'https://api.jina.ai/v1'
def _invoke(self, model: str, credentials: dict,
texts: list[str], user: Optional[str] = None) \
@@ -39,11 +39,14 @@ class JinaTextEmbeddingModel(TextEmbeddingModel):
:return: embeddings result
"""
api_key = credentials['api_key']
- if model not in self.models:
- raise InvokeBadRequestError('Invalid model name')
if not api_key:
raise CredentialsValidateFailedError('api_key is required')
- url = self.api_base
+
+ base_url = credentials.get('base_url', self.api_base)
+ if base_url.endswith('/'):
+ base_url = base_url[:-1]
+
+ url = base_url + '/embeddings'
headers = {
'Authorization': 'Bearer ' + api_key,
'Content-Type': 'application/json'
@@ -70,7 +73,7 @@ class JinaTextEmbeddingModel(TextEmbeddingModel):
elif response.status_code == 500:
raise InvokeServerUnavailableError(msg)
else:
- raise InvokeError(msg)
+ raise InvokeBadRequestError(msg)
except JSONDecodeError as e:
raise InvokeServerUnavailableError(f"Failed to convert response to json: {e} with text: {response.text}")
@@ -118,8 +121,8 @@ class JinaTextEmbeddingModel(TextEmbeddingModel):
"""
try:
self._invoke(model=model, credentials=credentials, texts=['ping'])
- except InvokeAuthorizationError:
- raise CredentialsValidateFailedError('Invalid api key')
+ except Exception as e:
+ raise CredentialsValidateFailedError(f'Credentials validation failed: {e}')
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
@@ -137,7 +140,8 @@ class JinaTextEmbeddingModel(TextEmbeddingModel):
InvokeAuthorizationError
],
InvokeBadRequestError: [
- KeyError
+ KeyError,
+ InvokeBadRequestError
]
}
@@ -170,3 +174,19 @@ class JinaTextEmbeddingModel(TextEmbeddingModel):
)
return usage
+
+ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
+ """
+ generate custom model entities from credentials
+ """
+ entity = AIModelEntity(
+ model=model,
+ label=I18nObject(en_US=model),
+ model_type=ModelType.TEXT_EMBEDDING,
+ fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+ model_properties={
+ ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', '8192'))
+ }
+ )
+
+ return entity
diff --git a/api/core/model_runtime/model_providers/localai/localai.yaml b/api/core/model_runtime/model_providers/localai/localai.yaml
index a870914632..864dd7a30c 100644
--- a/api/core/model_runtime/model_providers/localai/localai.yaml
+++ b/api/core/model_runtime/model_providers/localai/localai.yaml
@@ -15,6 +15,8 @@ help:
supported_model_types:
- llm
- text-embedding
+ - rerank
+ - speech2text
configurate_methods:
- customizable-model
model_credential_schema:
@@ -57,6 +59,9 @@ model_credential_schema:
zh_Hans: 在此输入LocalAI的服务器地址,如 http://192.168.1.100:8080
en_US: Enter the url of your LocalAI, e.g. http://192.168.1.100:8080
- variable: context_size
+ show_on:
+ - variable: __model_type
+ value: llm
label:
zh_Hans: 上下文大小
en_US: Context size
diff --git a/api/core/model_runtime/model_providers/localai/rerank/__init__.py b/api/core/model_runtime/model_providers/localai/rerank/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/model_runtime/model_providers/localai/rerank/rerank.py b/api/core/model_runtime/model_providers/localai/rerank/rerank.py
new file mode 100644
index 0000000000..c8ba9a6c7c
--- /dev/null
+++ b/api/core/model_runtime/model_providers/localai/rerank/rerank.py
@@ -0,0 +1,136 @@
+from json import dumps
+from typing import Optional
+
+import httpx
+from requests import post
+from yarl import URL
+
+from core.model_runtime.entities.common_entities import I18nObject
+from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
+from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
+from core.model_runtime.errors.invoke import (
+ InvokeAuthorizationError,
+ InvokeBadRequestError,
+ InvokeConnectionError,
+ InvokeError,
+ InvokeRateLimitError,
+ InvokeServerUnavailableError,
+)
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.rerank_model import RerankModel
+
+
+class LocalaiRerankModel(RerankModel):
+ """
+ The LocalAI rerank API is compatible with the Jina rerank API, so this class mirrors JinaRerankModel.
+ """
+
+ def _invoke(self, model: str, credentials: dict,
+ query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None,
+ user: Optional[str] = None) -> RerankResult:
+ """
+ Invoke rerank model
+
+ :param model: model name
+ :param credentials: model credentials
+ :param query: search query
+ :param docs: docs for reranking
+ :param score_threshold: score threshold
+ :param top_n: top n documents to return
+ :param user: unique user id
+ :return: rerank result
+ """
+ if len(docs) == 0:
+ return RerankResult(model=model, docs=[])
+
+ server_url = credentials['server_url']
+ model_name = model
+
+ if not server_url:
+ raise CredentialsValidateFailedError('server_url is required')
+ if not model_name:
+ raise CredentialsValidateFailedError('model_name is required')
+
+ url = server_url
+ headers = {
+ 'Authorization': f"Bearer {credentials.get('api_key')}",
+ 'Content-Type': 'application/json'
+ }
+
+ data = {
+ "model": model_name,
+ "query": query,
+ "documents": docs,
+ "top_n": top_n
+ }
+
+ try:
+ response = post(str(URL(url) / 'rerank'), headers=headers, data=dumps(data), timeout=10)
+ response.raise_for_status()
+ results = response.json()
+
+ rerank_documents = []
+ for result in results['results']:
+ rerank_document = RerankDocument(
+ index=result['index'],
+ text=result['document']['text'],
+ score=result['relevance_score'],
+ )
+ if score_threshold is None or result['relevance_score'] >= score_threshold:
+ rerank_documents.append(rerank_document)
+
+ return RerankResult(model=model, docs=rerank_documents)
+ except httpx.HTTPStatusError as e:
+ raise InvokeServerUnavailableError(str(e))
+
+ def validate_credentials(self, model: str, credentials: dict) -> None:
+ """
+ Validate model credentials
+
+ :param model: model name
+ :param credentials: model credentials
+ :return:
+ """
+ try:
+ self._invoke(
+ model=model,
+ credentials=credentials,
+ query="What is the capital of the United States?",
+ docs=[
+ "Carson City is the capital city of the American state of Nevada. At the 2010 United States "
+ "Census, Carson City had a population of 55,274.",
+ "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
+ "are a political division controlled by the United States. Its capital is Saipan.",
+ ],
+ score_threshold=0.8
+ )
+ except Exception as ex:
+ raise CredentialsValidateFailedError(str(ex))
+
+ @property
+ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+ """
+ Map model invoke error to unified error
+ """
+ return {
+ InvokeConnectionError: [httpx.ConnectError],
+ InvokeServerUnavailableError: [httpx.RemoteProtocolError],
+ InvokeRateLimitError: [],
+ InvokeAuthorizationError: [httpx.HTTPStatusError],
+ InvokeBadRequestError: [httpx.RequestError]
+ }
+
+ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
+ """
+ generate custom model entities from credentials
+ """
+ entity = AIModelEntity(
+ model=model,
+ label=I18nObject(en_US=model),
+ model_type=ModelType.RERANK,
+ fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+ model_properties={}
+ )
+
+ return entity
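For reference, the Jina-compatible wire format that `_invoke` above is assumed to send and parse (shapes only; the model name and documents are placeholders):

```python
# Assumed request/response shapes for the LocalAI '/rerank' endpoint.
request_body = {
    'model': 'bge-reranker-base',   # hypothetical deployment name
    'query': 'What is the capital of the United States?',
    'documents': ['Carson City ...', 'Saipan ...'],
    'top_n': 2,
}
response_body = {
    'results': [
        {'index': 0, 'document': {'text': 'Carson City ...'}, 'relevance_score': 0.98},
        {'index': 1, 'document': {'text': 'Saipan ...'}, 'relevance_score': 0.21},
    ]
}
```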
diff --git a/api/core/model_runtime/model_providers/localai/speech2text/__init__.py b/api/core/model_runtime/model_providers/localai/speech2text/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/model_runtime/model_providers/localai/speech2text/speech2text.py b/api/core/model_runtime/model_providers/localai/speech2text/speech2text.py
new file mode 100644
index 0000000000..d7403aff4f
--- /dev/null
+++ b/api/core/model_runtime/model_providers/localai/speech2text/speech2text.py
@@ -0,0 +1,101 @@
+from typing import IO, Optional
+
+from requests import Request, Session
+from yarl import URL
+
+from core.model_runtime.entities.common_entities import I18nObject
+from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
+from core.model_runtime.errors.invoke import (
+ InvokeAuthorizationError,
+ InvokeBadRequestError,
+ InvokeConnectionError,
+ InvokeError,
+ InvokeRateLimitError,
+ InvokeServerUnavailableError,
+)
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
+
+
+class LocalAISpeech2text(Speech2TextModel):
+ """
+ Model class for the LocalAI speech-to-text model.
+ """
+
+ def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str:
+ """
+ Invoke speech-to-text model
+
+ :param model: model name
+ :param credentials: model credentials
+ :param file: audio file
+ :param user: unique user id
+ :return: text for given audio file
+ """
+
+ url = str(URL(credentials['server_url']) / "v1/audio/transcriptions")
+ data = {"model": model}
+ files = {"file": file}
+
+ session = Session()
+ request = Request("POST", url, data=data, files=files)
+ prepared_request = session.prepare_request(request)
+ response = session.send(prepared_request)
+
+ if 'error' in response.json():
+ raise InvokeServerUnavailableError(str(response.json()['error']))
+
+ return response.json()["text"]
+
+ def validate_credentials(self, model: str, credentials: dict) -> None:
+ """
+ Validate model credentials
+
+ :param model: model name
+ :param credentials: model credentials
+ :return:
+ """
+ try:
+ audio_file_path = self._get_demo_file_path()
+
+ with open(audio_file_path, 'rb') as audio_file:
+ self._invoke(model, credentials, audio_file)
+ except Exception as ex:
+ raise CredentialsValidateFailedError(str(ex))
+
+ @property
+ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+ return {
+ InvokeConnectionError: [
+ InvokeConnectionError
+ ],
+ InvokeServerUnavailableError: [
+ InvokeServerUnavailableError
+ ],
+ InvokeRateLimitError: [
+ InvokeRateLimitError
+ ],
+ InvokeAuthorizationError: [
+ InvokeAuthorizationError
+ ],
+ InvokeBadRequestError: [
+ InvokeBadRequestError
+ ],
+ }
+
+ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
+ """
+ used to define customizable model schema
+ """
+ entity = AIModelEntity(
+ model=model,
+ label=I18nObject(
+ en_US=model
+ ),
+ fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+ model_type=ModelType.SPEECH2TEXT,
+ model_properties={},
+ parameter_rules=[]
+ )
+
+ return entity
\ No newline at end of file
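A standalone sketch of the multipart transcription call that `_invoke` above performs; the server address and model name are placeholders for a LocalAI deployment:

```python
from requests import Request, Session

url = 'http://192.168.1.100:8080/v1/audio/transcriptions'
session = Session()
with open('sample.wav', 'rb') as audio_file:
    # Multipart POST: model name as form data, audio as the file part.
    request = Request('POST', url, data={'model': 'whisper-1'}, files={'file': audio_file})
    response = session.send(session.prepare_request(request))
print(response.json().get('text'))
```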
diff --git a/api/core/model_runtime/model_providers/minimax/llm/abab5-chat.yaml b/api/core/model_runtime/model_providers/minimax/llm/abab5-chat.yaml
index e0d81e76cc..2c1f79e2b7 100644
--- a/api/core/model_runtime/model_providers/minimax/llm/abab5-chat.yaml
+++ b/api/core/model_runtime/model_providers/minimax/llm/abab5-chat.yaml
@@ -18,6 +18,15 @@ parameter_rules:
default: 6144
min: 1
max: 6144
+ - name: mask_sensitive_info
+ type: boolean
+ default: true
+ label:
+ zh_Hans: 隐私保护
+ en_US: Privacy Protection
+ help:
+ zh_Hans: 对输出中易涉及隐私问题的文本信息进行打码,目前包括但不限于邮箱、域名、链接、证件号、家庭住址等,默认true,即开启打码
+ en_US: Mask sensitive information in the output, currently including but not limited to emails, domains, links, ID numbers, and home addresses. Defaults to true, i.e. masking enabled.
- name: presence_penalty
use_template: presence_penalty
- name: frequency_penalty
diff --git a/api/core/model_runtime/model_providers/minimax/llm/abab5.5-chat.yaml b/api/core/model_runtime/model_providers/minimax/llm/abab5.5-chat.yaml
index c0ad1e2fdf..6d29be0d0e 100644
--- a/api/core/model_runtime/model_providers/minimax/llm/abab5.5-chat.yaml
+++ b/api/core/model_runtime/model_providers/minimax/llm/abab5.5-chat.yaml
@@ -26,6 +26,15 @@ parameter_rules:
default: 6144
min: 1
max: 16384
+ - name: mask_sensitive_info
+ type: boolean
+ default: true
+ label:
+ zh_Hans: 隐私保护
+ en_US: Privacy Protection
+ help:
+ zh_Hans: 对输出中易涉及隐私问题的文本信息进行打码,目前包括但不限于邮箱、域名、链接、证件号、家庭住址等,默认true,即开启打码
+ en_US: Mask sensitive information in the output, currently including but not limited to emails, domains, links, ID numbers, and home addresses. Defaults to true, i.e. masking enabled.
- name: presence_penalty
use_template: presence_penalty
- name: frequency_penalty
diff --git a/api/core/model_runtime/model_providers/minimax/llm/abab5.5s-chat.yaml b/api/core/model_runtime/model_providers/minimax/llm/abab5.5s-chat.yaml
index 1ef7046459..aa42bb5739 100644
--- a/api/core/model_runtime/model_providers/minimax/llm/abab5.5s-chat.yaml
+++ b/api/core/model_runtime/model_providers/minimax/llm/abab5.5s-chat.yaml
@@ -24,6 +24,15 @@ parameter_rules:
default: 3072
min: 1
max: 8192
+ - name: mask_sensitive_info
+ type: boolean
+ default: true
+ label:
+ zh_Hans: 隐私保护
+ en_US: Privacy Protection
+ help:
+ zh_Hans: 对输出中易涉及隐私问题的文本信息进行打码,目前包括但不限于邮箱、域名、链接、证件号、家庭住址等,默认true,即开启打码
+ en_US: Mask sensitive information in the output, currently including but not limited to emails, domains, links, ID numbers, and home addresses. Defaults to true, i.e. masking enabled.
- name: presence_penalty
use_template: presence_penalty
- name: frequency_penalty
diff --git a/api/core/model_runtime/model_providers/minimax/llm/abab6-chat.yaml b/api/core/model_runtime/model_providers/minimax/llm/abab6-chat.yaml
index 4c487c598e..9188b6b53f 100644
--- a/api/core/model_runtime/model_providers/minimax/llm/abab6-chat.yaml
+++ b/api/core/model_runtime/model_providers/minimax/llm/abab6-chat.yaml
@@ -26,6 +26,15 @@ parameter_rules:
default: 2048
min: 1
max: 32768
+ - name: mask_sensitive_info
+ type: boolean
+ default: true
+ label:
+ zh_Hans: 隐私保护
+ en_US: Privacy Protection
+ help:
+ zh_Hans: 对输出中易涉及隐私问题的文本信息进行打码,目前包括但不限于邮箱、域名、链接、证件号、家庭住址等,默认true,即开启打码
+ en_US: Mask sensitive information in the output, currently including but not limited to emails, domains, links, ID numbers, and home addresses. Defaults to true, i.e. masking enabled.
- name: presence_penalty
use_template: presence_penalty
- name: frequency_penalty
diff --git a/api/core/model_runtime/model_providers/minimax/llm/abab6.5-chat.yaml b/api/core/model_runtime/model_providers/minimax/llm/abab6.5-chat.yaml
index ead61fb7ca..5d717d5f8c 100644
--- a/api/core/model_runtime/model_providers/minimax/llm/abab6.5-chat.yaml
+++ b/api/core/model_runtime/model_providers/minimax/llm/abab6.5-chat.yaml
@@ -26,6 +26,15 @@ parameter_rules:
default: 2048
min: 1
max: 8192
+ - name: mask_sensitive_info
+ type: boolean
+ default: true
+ label:
+ zh_Hans: 隐私保护
+ en_US: Privacy Protection
+ help:
+ zh_Hans: 对输出中易涉及隐私问题的文本信息进行打码,目前包括但不限于邮箱、域名、链接、证件号、家庭住址等,默认true,即开启打码
+ en_US: Mask sensitive information in the output, currently including but not limited to emails, domains, links, ID numbers, and home addresses. Defaults to true, i.e. masking enabled.
- name: presence_penalty
use_template: presence_penalty
- name: frequency_penalty
diff --git a/api/core/model_runtime/model_providers/minimax/llm/abab6.5s-chat.yaml b/api/core/model_runtime/model_providers/minimax/llm/abab6.5s-chat.yaml
index 2cd98abb22..4631fe67e4 100644
--- a/api/core/model_runtime/model_providers/minimax/llm/abab6.5s-chat.yaml
+++ b/api/core/model_runtime/model_providers/minimax/llm/abab6.5s-chat.yaml
@@ -26,6 +26,15 @@ parameter_rules:
default: 2048
min: 1
max: 245760
+ - name: mask_sensitive_info
+ type: boolean
+ default: true
+ label:
+ zh_Hans: 隐私保护
+ en_US: Privacy Protection
+ help:
+ zh_Hans: 对输出中易涉及隐私问题的文本信息进行打码,目前包括但不限于邮箱、域名、链接、证件号、家庭住址等,默认true,即开启打码
+ en_US: Mask sensitive information in the output, currently including but not limited to emails, domains, links, ID numbers, and home addresses. Defaults to true, i.e. masking enabled.
- name: presence_penalty
use_template: presence_penalty
- name: frequency_penalty
diff --git a/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py b/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py
index 81ea2e165e..55747057c9 100644
--- a/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py
+++ b/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py
@@ -20,16 +20,16 @@ class MinimaxChatCompletionPro:
Minimax Chat Completion Pro API, supports function calling
however, it is not implemented yet; the parameters are reserved for future use
"""
- def generate(self, model: str, api_key: str, group_id: str,
+ def generate(self, model: str, api_key: str, group_id: str,
prompt_messages: list[MinimaxMessage], model_parameters: dict,
tools: list[dict[str, Any]], stop: list[str] | None, stream: bool, user: str) \
- -> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]:
+ -> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]:
"""
generate chat completion
"""
if not api_key or not group_id:
raise InvalidAPIKeyError('Invalid API key or group ID')
-
+
url = f'https://api.minimax.chat/v1/text/chatcompletion_pro?GroupId={group_id}'
extra_kwargs = {}
@@ -42,8 +42,11 @@ class MinimaxChatCompletionPro:
if 'top_p' in model_parameters and type(model_parameters['top_p']) == float:
extra_kwargs['top_p'] = model_parameters['top_p']
+
+ if 'mask_sensitive_info' in model_parameters and type(model_parameters['mask_sensitive_info']) == bool:
+ extra_kwargs['mask_sensitive_info'] = model_parameters['mask_sensitive_info']
- if 'plugin_web_search' in model_parameters and model_parameters['plugin_web_search']:
+ if model_parameters.get('plugin_web_search'):
extra_kwargs['plugins'] = [
'plugin_web_search'
]
@@ -61,7 +64,7 @@ class MinimaxChatCompletionPro:
# check if there is a system message
if len(prompt_messages) == 0:
raise BadRequestError('At least one message is required')
-
+
if prompt_messages[0].role == MinimaxMessage.Role.SYSTEM.value:
if prompt_messages[0].content:
bot_setting['content'] = prompt_messages[0].content
@@ -70,7 +73,7 @@ class MinimaxChatCompletionPro:
# check if there is a user message
if len(prompt_messages) == 0:
raise BadRequestError('At least one user message is required')
-
+
messages = [message.to_dict() for message in prompt_messages]
headers = {
@@ -89,21 +92,21 @@ class MinimaxChatCompletionPro:
if tools:
body['functions'] = tools
- body['function_call'] = { 'type': 'auto' }
+ body['function_call'] = {'type': 'auto'}
try:
response = post(
url=url, data=dumps(body), headers=headers, stream=stream, timeout=(10, 300))
except Exception as e:
raise InternalServerError(e)
-
+
if response.status_code != 200:
raise InternalServerError(response.text)
-
+
if stream:
return self._handle_stream_chat_generate_response(response)
return self._handle_chat_generate_response(response)
-
+
def _handle_error(self, code: int, msg: str):
if code == 1000 or code == 1001 or code == 1013 or code == 1027:
raise InternalServerError(msg)
@@ -127,7 +130,7 @@ class MinimaxChatCompletionPro:
code = response['base_resp']['status_code']
msg = response['base_resp']['status_msg']
self._handle_error(code, msg)
-
+
message = MinimaxMessage(
content=response['reply'],
role=MinimaxMessage.Role.ASSISTANT.value
@@ -144,7 +147,6 @@ class MinimaxChatCompletionPro:
"""
handle stream chat generate response
"""
- function_call_storage = None
for line in response.iter_lines():
if not line:
continue
@@ -158,54 +160,41 @@ class MinimaxChatCompletionPro:
msg = data['base_resp']['status_msg']
self._handle_error(code, msg)
- if data['reply'] or 'usage' in data and data['usage']:
+ # final chunk
+ if data['reply'] or data.get('usage'):
total_tokens = data['usage']['total_tokens']
- message = MinimaxMessage(
+ minimax_message = MinimaxMessage(
role=MinimaxMessage.Role.ASSISTANT.value,
content=''
)
- message.usage = {
+ minimax_message.usage = {
'prompt_tokens': 0,
'completion_tokens': total_tokens,
'total_tokens': total_tokens
}
- message.stop_reason = data['choices'][0]['finish_reason']
+ minimax_message.stop_reason = data['choices'][0]['finish_reason']
- if function_call_storage:
- function_call_message = MinimaxMessage(content='', role=MinimaxMessage.Role.ASSISTANT.value)
- function_call_message.function_call = function_call_storage
- yield function_call_message
+ choices = data.get('choices', [])
+ if len(choices) > 0:
+ for choice in choices:
+ message = choice['messages'][0]
+ # append function_call message
+ if 'function_call' in message:
+ function_call_message = MinimaxMessage(content='', role=MinimaxMessage.Role.ASSISTANT.value)
+ function_call_message.function_call = message['function_call']
+ yield function_call_message
- yield message
+ yield minimax_message
return
+ # partial chunk
choices = data.get('choices', [])
if len(choices) == 0:
continue
for choice in choices:
message = choice['messages'][0]
-
- if 'function_call' in message:
- if not function_call_storage:
- function_call_storage = message['function_call']
- if 'arguments' not in function_call_storage or not function_call_storage['arguments']:
- function_call_storage['arguments'] = ''
- continue
- else:
- function_call_storage['arguments'] += message['function_call']['arguments']
- continue
- else:
- if function_call_storage:
- message['function_call'] = function_call_storage
- function_call_storage = None
-
- minimax_message = MinimaxMessage(content='', role=MinimaxMessage.Role.ASSISTANT.value)
-
- if 'function_call' in message:
- minimax_message.function_call = message['function_call']
-
+ # append text message
if 'text' in message:
- minimax_message.content = message['text']
-
- yield minimax_message
\ No newline at end of file
+ minimax_message = MinimaxMessage(content=message['text'], role=MinimaxMessage.Role.ASSISTANT.value)
+ yield minimax_message
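The refactor above drops the function_call_storage buffering in favor of routing each parsed chunk directly. A condensed sketch of the new control flow, with simplified chunk shapes assumed and a generic emit() standing in for yield:

```python
def handle_stream_chunk(data: dict, emit) -> bool:
    """Route one parsed MiniMax stream chunk; returns True on the final chunk."""
    if data['reply'] or data.get('usage'):
        # final chunk: flush any function_call messages first, then the usage summary
        for choice in data.get('choices', []):
            message = choice['messages'][0]
            if 'function_call' in message:
                emit({'role': 'assistant', 'function_call': message['function_call']})
        emit({
            'role': 'assistant',
            'usage': {'total_tokens': data['usage']['total_tokens']},
            'stop_reason': data['choices'][0]['finish_reason'],
        })
        return True
    # partial chunk: emit each text delta as it arrives
    for choice in data.get('choices', []):
        message = choice['messages'][0]
        if 'text' in message:
            emit({'role': 'assistant', 'content': message['text']})
    return False
```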
diff --git a/api/core/model_runtime/model_providers/minimax/llm/llm.py b/api/core/model_runtime/model_providers/minimax/llm/llm.py
index cc88d15736..1fab20ebbc 100644
--- a/api/core/model_runtime/model_providers/minimax/llm/llm.py
+++ b/api/core/model_runtime/model_providers/minimax/llm/llm.py
@@ -34,6 +34,8 @@ from core.model_runtime.model_providers.minimax.llm.types import MinimaxMessage
class MinimaxLargeLanguageModel(LargeLanguageModel):
model_apis = {
+ 'abab6.5s-chat': MinimaxChatCompletionPro,
+ 'abab6.5-chat': MinimaxChatCompletionPro,
'abab6-chat': MinimaxChatCompletionPro,
'abab5.5s-chat': MinimaxChatCompletionPro,
'abab5.5-chat': MinimaxChatCompletionPro,
diff --git a/api/core/model_runtime/model_providers/model_provider_factory.py b/api/core/model_runtime/model_providers/model_provider_factory.py
index 44a1cf2e84..26c4199d16 100644
--- a/api/core/model_runtime/model_providers/model_provider_factory.py
+++ b/api/core/model_runtime/model_providers/model_provider_factory.py
@@ -4,13 +4,13 @@ from typing import Optional
from pydantic import BaseModel
+from core.helper.module_import_helper import load_single_subclass_from_source
+from core.helper.position_helper import get_position_map, sort_to_dict_by_position_map
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
from core.model_runtime.schema_validators.model_credential_schema_validator import ModelCredentialSchemaValidator
from core.model_runtime.schema_validators.provider_credential_schema_validator import ProviderCredentialSchemaValidator
-from core.utils.module_import_helper import load_single_subclass_from_source
-from core.utils.position_helper import get_position_map, sort_to_dict_by_position_map
logger = logging.getLogger(__name__)
diff --git a/api/core/model_runtime/model_providers/nvidia/llm/_position.yaml b/api/core/model_runtime/model_providers/nvidia/llm/_position.yaml
index fc69862722..2401f2a890 100644
--- a/api/core/model_runtime/model_providers/nvidia/llm/_position.yaml
+++ b/api/core/model_runtime/model_providers/nvidia/llm/_position.yaml
@@ -1,7 +1,11 @@
- google/gemma-7b
- google/codegemma-7b
+- google/recurrentgemma-2b
- meta/llama2-70b
- meta/llama3-8b-instruct
- meta/llama3-70b-instruct
+- mistralai/mistral-large
- mistralai/mixtral-8x7b-instruct-v0.1
+- mistralai/mixtral-8x22b-instruct-v0.1
- fuyu-8b
+- snowflake/arctic
diff --git a/api/core/model_runtime/model_providers/nvidia/llm/arctic.yaml b/api/core/model_runtime/model_providers/nvidia/llm/arctic.yaml
new file mode 100644
index 0000000000..7f53ae58e6
--- /dev/null
+++ b/api/core/model_runtime/model_providers/nvidia/llm/arctic.yaml
@@ -0,0 +1,36 @@
+model: snowflake/arctic
+label:
+ zh_Hans: snowflake/arctic
+ en_US: snowflake/arctic
+model_type: llm
+features:
+ - agent-thought
+model_properties:
+ mode: chat
+ context_size: 4000
+parameter_rules:
+ - name: temperature
+ use_template: temperature
+ min: 0
+ max: 1
+ default: 0.5
+ - name: top_p
+ use_template: top_p
+ min: 0
+ max: 1
+ default: 1
+ - name: max_tokens
+ use_template: max_tokens
+ min: 1
+ max: 1024
+ default: 1024
+ - name: frequency_penalty
+ use_template: frequency_penalty
+ min: -2
+ max: 2
+ default: 0
+ - name: presence_penalty
+ use_template: presence_penalty
+ min: -2
+ max: 2
+ default: 0
diff --git a/api/core/model_runtime/model_providers/nvidia/llm/llm.py b/api/core/model_runtime/model_providers/nvidia/llm/llm.py
index 402ffb2cf2..047bbeda63 100644
--- a/api/core/model_runtime/model_providers/nvidia/llm/llm.py
+++ b/api/core/model_runtime/model_providers/nvidia/llm/llm.py
@@ -22,12 +22,16 @@ from core.model_runtime.utils import helper
class NVIDIALargeLanguageModel(OAIAPICompatLargeLanguageModel):
MODEL_SUFFIX_MAP = {
'fuyu-8b': 'vlm/adept/fuyu-8b',
+ 'mistralai/mistral-large': '',
'mistralai/mixtral-8x7b-instruct-v0.1': '',
+ 'mistralai/mixtral-8x22b-instruct-v0.1': '',
'google/gemma-7b': '',
'google/codegemma-7b': '',
+        'snowflake/arctic': '',
'meta/llama2-70b': '',
'meta/llama3-8b-instruct': '',
- 'meta/llama3-70b-instruct': ''
+ 'meta/llama3-70b-instruct': '',
+ 'google/recurrentgemma-2b': ''
}
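A plausible reading of MODEL_SUFFIX_MAP (an assumption; the lookup code is not part of this hunk): a non-empty value overrides the request path, while an empty string means the model id is used as-is.

```python
MODEL_SUFFIX_MAP = {
    'fuyu-8b': 'vlm/adept/fuyu-8b',  # vision model served under a dedicated path
    'snowflake/arctic': '',          # empty suffix: use the model id directly
}

def resolve_model_path(model: str) -> str:
    # Hypothetical helper for illustration only.
    return MODEL_SUFFIX_MAP.get(model) or model
```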
diff --git a/api/core/model_runtime/model_providers/nvidia/llm/mistral-large.yaml b/api/core/model_runtime/model_providers/nvidia/llm/mistral-large.yaml
new file mode 100644
index 0000000000..3e14d22141
--- /dev/null
+++ b/api/core/model_runtime/model_providers/nvidia/llm/mistral-large.yaml
@@ -0,0 +1,36 @@
+model: mistralai/mistral-large
+label:
+ zh_Hans: mistralai/mistral-large
+ en_US: mistralai/mistral-large
+model_type: llm
+features:
+ - agent-thought
+model_properties:
+ mode: chat
+ context_size: 32000
+parameter_rules:
+ - name: temperature
+ use_template: temperature
+ min: 0
+ max: 1
+ default: 0.5
+ - name: top_p
+ use_template: top_p
+ min: 0
+ max: 1
+ default: 1
+ - name: max_tokens
+ use_template: max_tokens
+ min: 1
+ max: 1024
+ default: 1024
+ - name: frequency_penalty
+ use_template: frequency_penalty
+ min: -2
+ max: 2
+ default: 0
+ - name: presence_penalty
+ use_template: presence_penalty
+ min: -2
+ max: 2
+ default: 0
diff --git a/api/core/model_runtime/model_providers/nvidia/llm/mixtral-8x22b-instruct-v0.1.yaml b/api/core/model_runtime/model_providers/nvidia/llm/mixtral-8x22b-instruct-v0.1.yaml
new file mode 100644
index 0000000000..05500c0336
--- /dev/null
+++ b/api/core/model_runtime/model_providers/nvidia/llm/mixtral-8x22b-instruct-v0.1.yaml
@@ -0,0 +1,36 @@
+model: mistralai/mixtral-8x22b-instruct-v0.1
+label:
+ zh_Hans: mistralai/mixtral-8x22b-instruct-v0.1
+ en_US: mistralai/mixtral-8x22b-instruct-v0.1
+model_type: llm
+features:
+ - agent-thought
+model_properties:
+ mode: chat
+ context_size: 64000
+parameter_rules:
+ - name: temperature
+ use_template: temperature
+ min: 0
+ max: 1
+ default: 0.5
+ - name: top_p
+ use_template: top_p
+ min: 0
+ max: 1
+ default: 1
+ - name: max_tokens
+ use_template: max_tokens
+ min: 1
+ max: 1024
+ default: 1024
+ - name: frequency_penalty
+ use_template: frequency_penalty
+ min: -2
+ max: 2
+ default: 0
+ - name: presence_penalty
+ use_template: presence_penalty
+ min: -2
+ max: 2
+ default: 0
diff --git a/api/core/model_runtime/model_providers/nvidia/llm/recurrentgemma-2b.yaml b/api/core/model_runtime/model_providers/nvidia/llm/recurrentgemma-2b.yaml
new file mode 100644
index 0000000000..73fcce3930
--- /dev/null
+++ b/api/core/model_runtime/model_providers/nvidia/llm/recurrentgemma-2b.yaml
@@ -0,0 +1,37 @@
+model: google/recurrentgemma-2b
+label:
+ zh_Hans: google/recurrentgemma-2b
+ en_US: google/recurrentgemma-2b
+model_type: llm
+features:
+ - agent-thought
+model_properties:
+ mode: chat
+ context_size: 2048
+parameter_rules:
+ - name: temperature
+ use_template: temperature
+ min: 0
+ max: 1
+ default: 0.2
+ - name: top_p
+ use_template: top_p
+ min: 0
+ max: 1
+ default: 0.7
+ - name: max_tokens
+ use_template: max_tokens
+ min: 1
+ max: 1024
+ default: 1024
+ - name: random_seed
+ type: int
+ help:
+ en_US: The seed to use for random sampling. If set, different calls will generate deterministic results.
+ zh_Hans: 当开启随机数种子以后,你可以通过指定一个固定的种子来使得回答结果更加稳定
+ label:
+ en_US: Seed
+ zh_Hans: 种子
+ default: 0
+ min: 0
+ max: 2147483647
diff --git a/api/core/model_runtime/model_providers/nvidia_nim/__init__.py b/api/core/model_runtime/model_providers/nvidia_nim/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/model_runtime/model_providers/nvidia_nim/_assets/icon_l_en.png b/api/core/model_runtime/model_providers/nvidia_nim/_assets/icon_l_en.png
new file mode 100644
index 0000000000..5a7f42e617
Binary files /dev/null and b/api/core/model_runtime/model_providers/nvidia_nim/_assets/icon_l_en.png differ
diff --git a/api/core/model_runtime/model_providers/nvidia_nim/_assets/icon_s_en.svg b/api/core/model_runtime/model_providers/nvidia_nim/_assets/icon_s_en.svg
new file mode 100644
index 0000000000..9fc02f9164
--- /dev/null
+++ b/api/core/model_runtime/model_providers/nvidia_nim/_assets/icon_s_en.svg
@@ -0,0 +1,3 @@
+
diff --git a/api/core/model_runtime/model_providers/nvidia_nim/llm/__init__.py b/api/core/model_runtime/model_providers/nvidia_nim/llm/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/model_runtime/model_providers/nvidia_nim/llm/llm.py b/api/core/model_runtime/model_providers/nvidia_nim/llm/llm.py
new file mode 100644
index 0000000000..f7b849fbe2
--- /dev/null
+++ b/api/core/model_runtime/model_providers/nvidia_nim/llm/llm.py
@@ -0,0 +1,12 @@
+import logging
+
+from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel
+
+logger = logging.getLogger(__name__)
+
+
+class NVIDIANIMProvider(OAIAPICompatLargeLanguageModel):
+ """
+ Model class for NVIDIA NIM large language model.
+ """
+ pass
diff --git a/api/core/model_runtime/model_providers/nvidia_nim/nvidia_nim.py b/api/core/model_runtime/model_providers/nvidia_nim/nvidia_nim.py
new file mode 100644
index 0000000000..25ab3e8e20
--- /dev/null
+++ b/api/core/model_runtime/model_providers/nvidia_nim/nvidia_nim.py
@@ -0,0 +1,11 @@
+import logging
+
+from core.model_runtime.model_providers.__base.model_provider import ModelProvider
+
+logger = logging.getLogger(__name__)
+
+
+class NVIDIANIMProvider(ModelProvider):
+
+ def validate_provider_credentials(self, credentials: dict) -> None:
+ pass
diff --git a/api/core/model_runtime/model_providers/nvidia_nim/nvidia_nim.yaml b/api/core/model_runtime/model_providers/nvidia_nim/nvidia_nim.yaml
new file mode 100644
index 0000000000..0e892665d7
--- /dev/null
+++ b/api/core/model_runtime/model_providers/nvidia_nim/nvidia_nim.yaml
@@ -0,0 +1,79 @@
+provider: nvidia_nim
+label:
+ en_US: NVIDIA NIM
+description:
+ en_US: NVIDIA NIM, a set of easy-to-use inference microservices.
+ zh_Hans: NVIDIA NIM,一组易于使用的模型推理微服务。
+icon_small:
+ en_US: icon_s_en.svg
+icon_large:
+ en_US: icon_l_en.png
+background: "#EFFDFD"
+help:
+ title:
+ en_US: Learn more about NVIDIA NIM
+ zh_Hans: 了解 NVIDIA NIM 更多信息
+ url:
+ en_US: https://www.nvidia.com/en-us/ai/
+supported_model_types:
+ - llm
+configurate_methods:
+ - customizable-model
+model_credential_schema:
+ model:
+ label:
+ en_US: Model Name
+ zh_Hans: 模型名称
+ placeholder:
+ en_US: Enter full model name
+ zh_Hans: 输入模型全称
+ credential_form_schemas:
+ - variable: endpoint_url
+ label:
+ zh_Hans: API endpoint URL
+ en_US: API endpoint URL
+ type: text-input
+ required: true
+ placeholder:
+ zh_Hans: Base URL, e.g. http://192.168.1.100:8000/v1
+ en_US: Base URL, e.g. http://192.168.1.100:8000/v1
+ - variable: mode
+ show_on:
+ - variable: __model_type
+ value: llm
+ label:
+ en_US: Completion mode
+ type: select
+ required: false
+ default: chat
+ placeholder:
+ zh_Hans: 选择对话类型
+ en_US: Select completion mode
+ options:
+ - value: completion
+ label:
+ en_US: Completion
+ zh_Hans: 补全
+ - value: chat
+ label:
+ en_US: Chat
+ zh_Hans: 对话
+ - variable: context_size
+ label:
+ zh_Hans: 模型上下文长度
+ en_US: Model context size
+ required: true
+ type: text-input
+ default: '4096'
+ placeholder:
+ zh_Hans: 在此输入您的模型上下文长度
+      en_US: Enter your model context size
+ - variable: max_tokens_to_sample
+ label:
+ zh_Hans: 最大 token 上限
+ en_US: Upper bound for max tokens
+ show_on:
+ - variable: __model_type
+ value: llm
+ default: '4096'
+ type: text-input
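Put together, a customizable NVIDIA NIM model would be registered with credentials shaped like the schema above; all values here are examples:

```python
# Example credentials matching the nvidia_nim.yaml schema (values are illustrative).
credentials = {
    "endpoint_url": "http://192.168.1.100:8000/v1",  # OpenAI-compatible base URL
    "mode": "chat",                                  # or "completion"
    "context_size": "4096",
    "max_tokens_to_sample": "4096",
}
```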
diff --git a/api/core/model_runtime/model_providers/ollama/llm/llm.py b/api/core/model_runtime/model_providers/ollama/llm/llm.py
index fcb94084a5..dd58a563ab 100644
--- a/api/core/model_runtime/model_providers/ollama/llm/llm.py
+++ b/api/core/model_runtime/model_providers/ollama/llm/llm.py
@@ -8,7 +8,12 @@ from urllib.parse import urljoin
import requests
-from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
+from core.model_runtime.entities.llm_entities import (
+ LLMMode,
+ LLMResult,
+ LLMResultChunk,
+ LLMResultChunkDelta,
+)
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
ImagePromptMessageContent,
@@ -40,7 +45,9 @@ from core.model_runtime.errors.invoke import (
InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
-from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from core.model_runtime.model_providers.__base.large_language_model import (
+ LargeLanguageModel,
+)
logger = logging.getLogger(__name__)
@@ -50,11 +57,17 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
Model class for Ollama large language model.
"""
- def _invoke(self, model: str, credentials: dict,
- prompt_messages: list[PromptMessage], model_parameters: dict,
- tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
- stream: bool = True, user: Optional[str] = None) \
- -> Union[LLMResult, Generator]:
+ def _invoke(
+ self,
+ model: str,
+ credentials: dict,
+ prompt_messages: list[PromptMessage],
+ model_parameters: dict,
+ tools: Optional[list[PromptMessageTool]] = None,
+ stop: Optional[list[str]] = None,
+ stream: bool = True,
+ user: Optional[str] = None,
+ ) -> Union[LLMResult, Generator]:
"""
Invoke large language model
@@ -75,11 +88,16 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
model_parameters=model_parameters,
stop=stop,
stream=stream,
- user=user
+ user=user,
)
- def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
- tools: Optional[list[PromptMessageTool]] = None) -> int:
+ def get_num_tokens(
+ self,
+ model: str,
+ credentials: dict,
+ prompt_messages: list[PromptMessage],
+ tools: Optional[list[PromptMessageTool]] = None,
+ ) -> int:
"""
Get number of tokens for given prompt messages
@@ -100,10 +118,12 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
if isinstance(first_prompt_message.content, str):
text = first_prompt_message.content
else:
- text = ''
+ text = ""
for message_content in first_prompt_message.content:
if message_content.type == PromptMessageContentType.TEXT:
- message_content = cast(TextPromptMessageContent, message_content)
+ message_content = cast(
+ TextPromptMessageContent, message_content
+ )
text = message_content.data
break
return self._get_num_tokens_by_gpt2(text)
@@ -121,19 +141,28 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
model=model,
credentials=credentials,
prompt_messages=[UserPromptMessage(content="ping")],
- model_parameters={
- 'num_predict': 5
- },
- stream=False
+ model_parameters={"num_predict": 5},
+ stream=False,
)
except InvokeError as ex:
- raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {ex.description}')
+ raise CredentialsValidateFailedError(
+ f"An error occurred during credentials validation: {ex.description}"
+ )
except Exception as ex:
- raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {str(ex)}')
+ raise CredentialsValidateFailedError(
+ f"An error occurred during credentials validation: {str(ex)}"
+ )
- def _generate(self, model: str, credentials: dict,
- prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None,
- stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
+ def _generate(
+ self,
+ model: str,
+ credentials: dict,
+ prompt_messages: list[PromptMessage],
+ model_parameters: dict,
+ stop: Optional[list[str]] = None,
+ stream: bool = True,
+ user: Optional[str] = None,
+ ) -> Union[LLMResult, Generator]:
"""
Invoke llm completion model
@@ -146,76 +175,93 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
:param user: unique user id
:return: full response or stream response chunk generator result
"""
- headers = {
- 'Content-Type': 'application/json'
- }
+ headers = {"Content-Type": "application/json"}
- endpoint_url = credentials['base_url']
- if not endpoint_url.endswith('/'):
- endpoint_url += '/'
+ endpoint_url = credentials["base_url"]
+ if not endpoint_url.endswith("/"):
+ endpoint_url += "/"
# prepare the payload for a simple ping to the model
- data = {
- 'model': model,
- 'stream': stream
- }
+ data = {"model": model, "stream": stream}
- if 'format' in model_parameters:
- data['format'] = model_parameters['format']
- del model_parameters['format']
+ if "format" in model_parameters:
+ data["format"] = model_parameters["format"]
+ del model_parameters["format"]
- data['options'] = model_parameters or {}
+ if "keep_alive" in model_parameters:
+ data["keep_alive"] = model_parameters["keep_alive"]
+ del model_parameters["keep_alive"]
+
+ data["options"] = model_parameters or {}
if stop:
- data['stop'] = "\n".join(stop)
+ data["stop"] = "\n".join(stop)
- completion_type = LLMMode.value_of(credentials['mode'])
+ completion_type = LLMMode.value_of(credentials["mode"])
if completion_type is LLMMode.CHAT:
- endpoint_url = urljoin(endpoint_url, 'api/chat')
- data['messages'] = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]
+ endpoint_url = urljoin(endpoint_url, "api/chat")
+ data["messages"] = [
+ self._convert_prompt_message_to_dict(m) for m in prompt_messages
+ ]
else:
- endpoint_url = urljoin(endpoint_url, 'api/generate')
+ endpoint_url = urljoin(endpoint_url, "api/generate")
first_prompt_message = prompt_messages[0]
if isinstance(first_prompt_message, UserPromptMessage):
first_prompt_message = cast(UserPromptMessage, first_prompt_message)
if isinstance(first_prompt_message.content, str):
- data['prompt'] = first_prompt_message.content
+ data["prompt"] = first_prompt_message.content
else:
- text = ''
+ text = ""
images = []
for message_content in first_prompt_message.content:
if message_content.type == PromptMessageContentType.TEXT:
- message_content = cast(TextPromptMessageContent, message_content)
+ message_content = cast(
+ TextPromptMessageContent, message_content
+ )
text = message_content.data
elif message_content.type == PromptMessageContentType.IMAGE:
- message_content = cast(ImagePromptMessageContent, message_content)
- image_data = re.sub(r'^data:image\/[a-zA-Z]+;base64,', '', message_content.data)
+ message_content = cast(
+ ImagePromptMessageContent, message_content
+ )
+ image_data = re.sub(
+ r"^data:image\/[a-zA-Z]+;base64,",
+ "",
+ message_content.data,
+ )
images.append(image_data)
- data['prompt'] = text
- data['images'] = images
+ data["prompt"] = text
+ data["images"] = images
# send a post request to validate the credentials
response = requests.post(
- endpoint_url,
- headers=headers,
- json=data,
- timeout=(10, 300),
- stream=stream
+ endpoint_url, headers=headers, json=data, timeout=(10, 300), stream=stream
)
response.encoding = "utf-8"
if response.status_code != 200:
- raise InvokeError(f"API request failed with status code {response.status_code}: {response.text}")
+ raise InvokeError(
+ f"API request failed with status code {response.status_code}: {response.text}"
+ )
if stream:
- return self._handle_generate_stream_response(model, credentials, completion_type, response, prompt_messages)
+ return self._handle_generate_stream_response(
+ model, credentials, completion_type, response, prompt_messages
+ )
- return self._handle_generate_response(model, credentials, completion_type, response, prompt_messages)
+ return self._handle_generate_response(
+ model, credentials, completion_type, response, prompt_messages
+ )
- def _handle_generate_response(self, model: str, credentials: dict, completion_type: LLMMode,
- response: requests.Response, prompt_messages: list[PromptMessage]) -> LLMResult:
+ def _handle_generate_response(
+ self,
+ model: str,
+ credentials: dict,
+ completion_type: LLMMode,
+ response: requests.Response,
+ prompt_messages: list[PromptMessage],
+ ) -> LLMResult:
"""
Handle llm completion response
@@ -229,14 +275,14 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
response_json = response.json()
if completion_type is LLMMode.CHAT:
- message = response_json.get('message', {})
- response_content = message.get('content', '')
+ message = response_json.get("message", {})
+ response_content = message.get("content", "")
else:
- response_content = response_json['response']
+ response_content = response_json["response"]
assistant_message = AssistantPromptMessage(content=response_content)
- if 'prompt_eval_count' in response_json and 'eval_count' in response_json:
+ if "prompt_eval_count" in response_json and "eval_count" in response_json:
# transform usage
prompt_tokens = response_json["prompt_eval_count"]
completion_tokens = response_json["eval_count"]
@@ -246,7 +292,9 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
completion_tokens = self._get_num_tokens_by_gpt2(assistant_message.content)
# transform usage
- usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+ usage = self._calc_response_usage(
+ model, credentials, prompt_tokens, completion_tokens
+ )
# transform response
result = LLMResult(
@@ -258,8 +306,14 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
return result
- def _handle_generate_stream_response(self, model: str, credentials: dict, completion_type: LLMMode,
- response: requests.Response, prompt_messages: list[PromptMessage]) -> Generator:
+ def _handle_generate_stream_response(
+ self,
+ model: str,
+ credentials: dict,
+ completion_type: LLMMode,
+ response: requests.Response,
+ prompt_messages: list[PromptMessage],
+ ) -> Generator:
"""
Handle llm completion stream response
@@ -270,17 +324,20 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
:param prompt_messages: prompt messages
:return: llm response chunk generator result
"""
- full_text = ''
+ full_text = ""
chunk_index = 0
- def create_final_llm_result_chunk(index: int, message: AssistantPromptMessage, finish_reason: str) \
- -> LLMResultChunk:
+ def create_final_llm_result_chunk(
+ index: int, message: AssistantPromptMessage, finish_reason: str
+ ) -> LLMResultChunk:
# calculate num tokens
prompt_tokens = self._get_num_tokens_by_gpt2(prompt_messages[0].content)
completion_tokens = self._get_num_tokens_by_gpt2(full_text)
# transform usage
- usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+ usage = self._calc_response_usage(
+ model, credentials, prompt_tokens, completion_tokens
+ )
return LLMResultChunk(
model=model,
@@ -289,11 +346,11 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
index=index,
message=message,
finish_reason=finish_reason,
- usage=usage
- )
+ usage=usage,
+ ),
)
- for chunk in response.iter_lines(decode_unicode=True, delimiter='\n'):
+ for chunk in response.iter_lines(decode_unicode=True, delimiter="\n"):
if not chunk:
continue
@@ -304,7 +361,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
yield create_final_llm_result_chunk(
index=chunk_index,
message=AssistantPromptMessage(content=""),
- finish_reason="Non-JSON encountered."
+ finish_reason="Non-JSON encountered.",
)
chunk_index += 1
@@ -314,55 +371,57 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
if not chunk_json:
continue
- if 'message' not in chunk_json:
- text = ''
+ if "message" not in chunk_json:
+ text = ""
else:
- text = chunk_json.get('message').get('content', '')
+ text = chunk_json.get("message").get("content", "")
else:
if not chunk_json:
continue
# transform assistant message to prompt message
- text = chunk_json['response']
+ text = chunk_json["response"]
- assistant_prompt_message = AssistantPromptMessage(
- content=text
- )
+ assistant_prompt_message = AssistantPromptMessage(content=text)
full_text += text
- if chunk_json['done']:
+ if chunk_json["done"]:
# calculate num tokens
- if 'prompt_eval_count' in chunk_json and 'eval_count' in chunk_json:
+ if "prompt_eval_count" in chunk_json and "eval_count" in chunk_json:
# transform usage
prompt_tokens = chunk_json["prompt_eval_count"]
completion_tokens = chunk_json["eval_count"]
else:
# calculate num tokens
- prompt_tokens = self._get_num_tokens_by_gpt2(prompt_messages[0].content)
+ prompt_tokens = self._get_num_tokens_by_gpt2(
+ prompt_messages[0].content
+ )
completion_tokens = self._get_num_tokens_by_gpt2(full_text)
# transform usage
- usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+ usage = self._calc_response_usage(
+ model, credentials, prompt_tokens, completion_tokens
+ )
yield LLMResultChunk(
- model=chunk_json['model'],
+ model=chunk_json["model"],
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=chunk_index,
message=assistant_prompt_message,
- finish_reason='stop',
- usage=usage
- )
+ finish_reason="stop",
+ usage=usage,
+ ),
)
else:
yield LLMResultChunk(
- model=chunk_json['model'],
+ model=chunk_json["model"],
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=chunk_index,
message=assistant_prompt_message,
- )
+ ),
)
chunk_index += 1
@@ -376,15 +435,21 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
if isinstance(message.content, str):
message_dict = {"role": "user", "content": message.content}
else:
- text = ''
+ text = ""
images = []
for message_content in message.content:
if message_content.type == PromptMessageContentType.TEXT:
- message_content = cast(TextPromptMessageContent, message_content)
+ message_content = cast(
+ TextPromptMessageContent, message_content
+ )
text = message_content.data
elif message_content.type == PromptMessageContentType.IMAGE:
- message_content = cast(ImagePromptMessageContent, message_content)
- image_data = re.sub(r'^data:image\/[a-zA-Z]+;base64,', '', message_content.data)
+ message_content = cast(
+ ImagePromptMessageContent, message_content
+ )
+ image_data = re.sub(
+ r"^data:image\/[a-zA-Z]+;base64,", "", message_content.data
+ )
images.append(image_data)
message_dict = {"role": "user", "content": text, "images": images}
@@ -414,7 +479,9 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
return num_tokens
- def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
+ def get_customizable_model_schema(
+ self, model: str, credentials: dict
+ ) -> AIModelEntity:
"""
Get customizable model schema.
@@ -425,20 +492,19 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
"""
extras = {}
- if 'vision_support' in credentials and credentials['vision_support'] == 'true':
- extras['features'] = [ModelFeature.VISION]
+ if "vision_support" in credentials and credentials["vision_support"] == "true":
+ extras["features"] = [ModelFeature.VISION]
entity = AIModelEntity(
model=model,
- label=I18nObject(
- zh_Hans=model,
- en_US=model
- ),
+ label=I18nObject(zh_Hans=model, en_US=model),
model_type=ModelType.LLM,
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={
- ModelPropertyKey.MODE: credentials.get('mode'),
- ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', 4096)),
+ ModelPropertyKey.MODE: credentials.get("mode"),
+ ModelPropertyKey.CONTEXT_SIZE: int(
+ credentials.get("context_size", 4096)
+ ),
},
parameter_rules=[
ParameterRule(
@@ -446,152 +512,195 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
use_template=DefaultParameterName.TEMPERATURE.value,
label=I18nObject(en_US="Temperature"),
type=ParameterType.FLOAT,
- help=I18nObject(en_US="The temperature of the model. "
- "Increasing the temperature will make the model answer "
- "more creatively. (Default: 0.8)"),
+ help=I18nObject(
+ en_US="The temperature of the model. "
+ "Increasing the temperature will make the model answer "
+ "more creatively. (Default: 0.8)"
+ ),
default=0.1,
min=0,
- max=2
+ max=1,
),
ParameterRule(
name=DefaultParameterName.TOP_P.value,
use_template=DefaultParameterName.TOP_P.value,
label=I18nObject(en_US="Top P"),
type=ParameterType.FLOAT,
- help=I18nObject(en_US="Works together with top-k. A higher value (e.g., 0.95) will lead to "
- "more diverse text, while a lower value (e.g., 0.5) will generate more "
- "focused and conservative text. (Default: 0.9)"),
+ help=I18nObject(
+ en_US="Works together with top-k. A higher value (e.g., 0.95) will lead to "
+ "more diverse text, while a lower value (e.g., 0.5) will generate more "
+ "focused and conservative text. (Default: 0.9)"
+ ),
default=0.9,
min=0,
- max=1
+ max=1,
),
ParameterRule(
name="top_k",
label=I18nObject(en_US="Top K"),
type=ParameterType.INT,
- help=I18nObject(en_US="Reduces the probability of generating nonsense. "
- "A higher value (e.g. 100) will give more diverse answers, "
- "while a lower value (e.g. 10) will be more conservative. (Default: 40)"),
+ help=I18nObject(
+ en_US="Reduces the probability of generating nonsense. "
+ "A higher value (e.g. 100) will give more diverse answers, "
+ "while a lower value (e.g. 10) will be more conservative. (Default: 40)"
+ ),
min=1,
- max=100
+ max=100,
),
ParameterRule(
- name='repeat_penalty',
+ name="repeat_penalty",
label=I18nObject(en_US="Repeat Penalty"),
type=ParameterType.FLOAT,
- help=I18nObject(en_US="Sets how strongly to penalize repetitions. "
- "A higher value (e.g., 1.5) will penalize repetitions more strongly, "
- "while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1)"),
+ help=I18nObject(
+ en_US="Sets how strongly to penalize repetitions. "
+ "A higher value (e.g., 1.5) will penalize repetitions more strongly, "
+ "while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1)"
+ ),
min=-2,
- max=2
+ max=2,
),
ParameterRule(
- name='num_predict',
- use_template='max_tokens',
+ name="num_predict",
+ use_template="max_tokens",
label=I18nObject(en_US="Num Predict"),
type=ParameterType.INT,
- help=I18nObject(en_US="Maximum number of tokens to predict when generating text. "
- "(Default: 128, -1 = infinite generation, -2 = fill context)"),
- default=512 if int(credentials.get('max_tokens', 4096)) >= 768 else 128,
+ help=I18nObject(
+ en_US="Maximum number of tokens to predict when generating text. "
+ "(Default: 128, -1 = infinite generation, -2 = fill context)"
+ ),
+ default=(
+ 512 if int(credentials.get("max_tokens", 4096)) >= 768 else 128
+ ),
min=-2,
- max=int(credentials.get('max_tokens', 4096)),
+ max=int(credentials.get("max_tokens", 4096)),
),
ParameterRule(
- name='mirostat',
+ name="mirostat",
label=I18nObject(en_US="Mirostat sampling"),
type=ParameterType.INT,
- help=I18nObject(en_US="Enable Mirostat sampling for controlling perplexity. "
- "(default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)"),
+ help=I18nObject(
+ en_US="Enable Mirostat sampling for controlling perplexity. "
+ "(default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)"
+ ),
min=0,
- max=2
+ max=2,
),
ParameterRule(
- name='mirostat_eta',
+ name="mirostat_eta",
label=I18nObject(en_US="Mirostat Eta"),
type=ParameterType.FLOAT,
- help=I18nObject(en_US="Influences how quickly the algorithm responds to feedback from "
- "the generated text. A lower learning rate will result in slower adjustments, "
- "while a higher learning rate will make the algorithm more responsive. "
- "(Default: 0.1)"),
- precision=1
+ help=I18nObject(
+ en_US="Influences how quickly the algorithm responds to feedback from "
+ "the generated text. A lower learning rate will result in slower adjustments, "
+ "while a higher learning rate will make the algorithm more responsive. "
+ "(Default: 0.1)"
+ ),
+ precision=1,
),
ParameterRule(
- name='mirostat_tau',
+ name="mirostat_tau",
label=I18nObject(en_US="Mirostat Tau"),
type=ParameterType.FLOAT,
- help=I18nObject(en_US="Controls the balance between coherence and diversity of the output. "
- "A lower value will result in more focused and coherent text. (Default: 5.0)"),
- precision=1
+ help=I18nObject(
+ en_US="Controls the balance between coherence and diversity of the output. "
+ "A lower value will result in more focused and coherent text. (Default: 5.0)"
+ ),
+ precision=1,
),
ParameterRule(
- name='num_ctx',
+ name="num_ctx",
label=I18nObject(en_US="Size of context window"),
type=ParameterType.INT,
- help=I18nObject(en_US="Sets the size of the context window used to generate the next token. "
- "(Default: 2048)"),
+ help=I18nObject(
+ en_US="Sets the size of the context window used to generate the next token. "
+ "(Default: 2048)"
+ ),
default=2048,
- min=1
+ min=1,
),
ParameterRule(
name='num_gpu',
- label=I18nObject(en_US="Num GPU"),
+ label=I18nObject(en_US="GPU Layers"),
type=ParameterType.INT,
- help=I18nObject(en_US="The number of layers to send to the GPU(s). "
- "On macOS it defaults to 1 to enable metal support, 0 to disable."),
- min=0,
- max=1
+                help=I18nObject(en_US="The number of layers to offload to the GPU(s). "
+                                      "On macOS it defaults to 1 to enable Metal support, 0 to disable. "
+                                      "As long as the model fits on a single GPU it stays there; "
+                                      "this does not set the number of GPUs used."),
+ min=-1,
+ default=1
),
ParameterRule(
- name='num_thread',
+ name="num_thread",
label=I18nObject(en_US="Num Thread"),
type=ParameterType.INT,
- help=I18nObject(en_US="Sets the number of threads to use during computation. "
- "By default, Ollama will detect this for optimal performance. "
- "It is recommended to set this value to the number of physical CPU cores "
- "your system has (as opposed to the logical number of cores)."),
+ help=I18nObject(
+ en_US="Sets the number of threads to use during computation. "
+ "By default, Ollama will detect this for optimal performance. "
+ "It is recommended to set this value to the number of physical CPU cores "
+ "your system has (as opposed to the logical number of cores)."
+ ),
min=1,
),
ParameterRule(
- name='repeat_last_n',
+ name="repeat_last_n",
label=I18nObject(en_US="Repeat last N"),
type=ParameterType.INT,
- help=I18nObject(en_US="Sets how far back for the model to look back to prevent repetition. "
- "(Default: 64, 0 = disabled, -1 = num_ctx)"),
- min=-1
+ help=I18nObject(
+ en_US="Sets how far back for the model to look back to prevent repetition. "
+ "(Default: 64, 0 = disabled, -1 = num_ctx)"
+ ),
+ min=-1,
),
ParameterRule(
- name='tfs_z',
+ name="tfs_z",
label=I18nObject(en_US="TFS Z"),
type=ParameterType.FLOAT,
- help=I18nObject(en_US="Tail free sampling is used to reduce the impact of less probable tokens "
- "from the output. A higher value (e.g., 2.0) will reduce the impact more, "
- "while a value of 1.0 disables this setting. (default: 1)"),
- precision=1
+ help=I18nObject(
+ en_US="Tail free sampling is used to reduce the impact of less probable tokens "
+ "from the output. A higher value (e.g., 2.0) will reduce the impact more, "
+ "while a value of 1.0 disables this setting. (default: 1)"
+ ),
+ precision=1,
),
ParameterRule(
- name='seed',
+ name="seed",
label=I18nObject(en_US="Seed"),
type=ParameterType.INT,
- help=I18nObject(en_US="Sets the random number seed to use for generation. Setting this to "
- "a specific number will make the model generate the same text for "
- "the same prompt. (Default: 0)"),
+ help=I18nObject(
+ en_US="Sets the random number seed to use for generation. Setting this to "
+ "a specific number will make the model generate the same text for "
+ "the same prompt. (Default: 0)"
+ ),
),
ParameterRule(
- name='format',
+ name="keep_alive",
+ label=I18nObject(en_US="Keep Alive"),
+ type=ParameterType.STRING,
+ help=I18nObject(
+ en_US="Sets how long the model is kept in memory after generating a response. "
+ "This must be a duration string with a unit (e.g., '10m' for 10 minutes or '24h' for 24 hours). "
+ "A negative number keeps the model loaded indefinitely, and '0' unloads the model immediately after generating a response. "
+ "Valid time units are 's','m','h'. (Default: 5m)"
+ ),
+ ),
+ ParameterRule(
+ name="format",
label=I18nObject(en_US="Format"),
type=ParameterType.STRING,
- help=I18nObject(en_US="the format to return a response in."
- " Currently the only accepted value is json."),
- options=['json'],
- )
+ help=I18nObject(
+ en_US="the format to return a response in."
+ " Currently the only accepted value is json."
+ ),
+ options=["json"],
+ ),
],
pricing=PriceConfig(
- input=Decimal(credentials.get('input_price', 0)),
- output=Decimal(credentials.get('output_price', 0)),
- unit=Decimal(credentials.get('unit', 0)),
- currency=credentials.get('currency', "USD")
+ input=Decimal(credentials.get("input_price", 0)),
+ output=Decimal(credentials.get("output_price", 0)),
+ unit=Decimal(credentials.get("unit", 0)),
+ currency=credentials.get("currency", "USD"),
),
- **extras
+ **extras,
)
return entity
@@ -619,10 +728,10 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
],
InvokeServerUnavailableError: [
requests.exceptions.ConnectionError, # Engine Overloaded
- requests.exceptions.HTTPError # Server Error
+ requests.exceptions.HTTPError, # Server Error
],
InvokeConnectionError: [
requests.exceptions.ConnectTimeout, # Timeout
- requests.exceptions.ReadTimeout # Timeout
- ]
+ requests.exceptions.ReadTimeout, # Timeout
+ ],
}
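The keep_alive handling above follows the same pattern as format: both are popped out of model_parameters and sent as top-level request fields, with the remainder going into options. A condensed sketch of that split:

```python
def build_ollama_payload(model: str, model_parameters: dict, stream: bool) -> dict:
    """Split Ollama parameters into top-level fields and sampling options."""
    data = {"model": model, "stream": stream}
    # 'format' and the new 'keep_alive' are API-level fields, not sampling options
    for top_level in ("format", "keep_alive"):
        if top_level in model_parameters:
            data[top_level] = model_parameters.pop(top_level)
    data["options"] = model_parameters or {}
    return data
```

With keep_alive='10m' the model stays loaded for ten minutes after a response; '0' unloads it immediately, and a negative value keeps it loaded indefinitely.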
diff --git a/api/core/model_runtime/model_providers/openai/_common.py b/api/core/model_runtime/model_providers/openai/_common.py
index 436461c11e..5772f325e1 100644
--- a/api/core/model_runtime/model_providers/openai/_common.py
+++ b/api/core/model_runtime/model_providers/openai/_common.py
@@ -25,7 +25,7 @@ class _CommonOpenAI:
"max_retries": 1,
}
- if 'openai_api_base' in credentials and credentials['openai_api_base']:
+ if credentials.get('openai_api_base'):
credentials['openai_api_base'] = credentials['openai_api_base'].rstrip('/')
credentials_kwargs['base_url'] = credentials['openai_api_base'] + '/v1'
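The tidied check keeps the existing rstrip, so a trailing slash in the configured base never doubles up; a quick illustration:

```python
# Both forms normalize to the same base URL (illustration of the hunk above):
for base in ('https://api.openai.com', 'https://api.openai.com/'):
    print(base.rstrip('/') + '/v1')  # https://api.openai.com/v1
```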
diff --git a/api/core/model_runtime/model_providers/openai/llm/_position.yaml b/api/core/model_runtime/model_providers/openai/llm/_position.yaml
index 3808d670c3..566055e3f7 100644
--- a/api/core/model_runtime/model_providers/openai/llm/_position.yaml
+++ b/api/core/model_runtime/model_providers/openai/llm/_position.yaml
@@ -1,4 +1,6 @@
- gpt-4
+- gpt-4o
+- gpt-4o-2024-05-13
- gpt-4-turbo
- gpt-4-turbo-2024-04-09
- gpt-4-turbo-preview
diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-4o-2024-05-13.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-4o-2024-05-13.yaml
new file mode 100644
index 0000000000..f0d835cba2
--- /dev/null
+++ b/api/core/model_runtime/model_providers/openai/llm/gpt-4o-2024-05-13.yaml
@@ -0,0 +1,44 @@
+model: gpt-4o-2024-05-13
+label:
+ zh_Hans: gpt-4o-2024-05-13
+ en_US: gpt-4o-2024-05-13
+model_type: llm
+features:
+ - multi-tool-call
+ - agent-thought
+ - stream-tool-call
+ - vision
+model_properties:
+ mode: chat
+ context_size: 128000
+parameter_rules:
+ - name: temperature
+ use_template: temperature
+ - name: top_p
+ use_template: top_p
+ - name: presence_penalty
+ use_template: presence_penalty
+ - name: frequency_penalty
+ use_template: frequency_penalty
+ - name: max_tokens
+ use_template: max_tokens
+ default: 512
+ min: 1
+ max: 4096
+ - name: response_format
+ label:
+ zh_Hans: 回复格式
+ en_US: response_format
+ type: string
+ help:
+ zh_Hans: 指定模型必须输出的格式
+ en_US: specifying the format that the model must output
+ required: false
+ options:
+ - text
+ - json_object
+pricing:
+ input: '5.00'
+ output: '15.00'
+ unit: '0.000001'
+ currency: USD
diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-4o.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-4o.yaml
new file mode 100644
index 0000000000..4f141f772f
--- /dev/null
+++ b/api/core/model_runtime/model_providers/openai/llm/gpt-4o.yaml
@@ -0,0 +1,44 @@
+model: gpt-4o
+label:
+ zh_Hans: gpt-4o
+ en_US: gpt-4o
+model_type: llm
+features:
+ - multi-tool-call
+ - agent-thought
+ - stream-tool-call
+ - vision
+model_properties:
+ mode: chat
+ context_size: 128000
+parameter_rules:
+ - name: temperature
+ use_template: temperature
+ - name: top_p
+ use_template: top_p
+ - name: presence_penalty
+ use_template: presence_penalty
+ - name: frequency_penalty
+ use_template: frequency_penalty
+ - name: max_tokens
+ use_template: max_tokens
+ default: 512
+ min: 1
+ max: 4096
+ - name: response_format
+ label:
+ zh_Hans: 回复格式
+ en_US: response_format
+ type: string
+ help:
+ zh_Hans: 指定模型必须输出的格式
+ en_US: specifying the format that the model must output
+ required: false
+ options:
+ - text
+ - json_object
+pricing:
+ input: '5.00'
+ output: '15.00'
+ unit: '0.000001'
+ currency: USD
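Assuming unit is the per-token multiplier, which is how the other pricing blocks in this directory read, the gpt-4o entry works out to $5 per million input tokens and $15 per million output tokens:

```python
# Worked example of the pricing block above: cost = tokens * unit * rate.
input_tokens, output_tokens = 1_000_000, 200_000
unit = 0.000001
cost = input_tokens * unit * 5.00 + output_tokens * unit * 15.00
print(f"${cost:.2f}")  # $8.00
```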
diff --git a/api/core/model_runtime/model_providers/openai/llm/llm.py b/api/core/model_runtime/model_providers/openai/llm/llm.py
index b7db39376c..69afabadb3 100644
--- a/api/core/model_runtime/model_providers/openai/llm/llm.py
+++ b/api/core/model_runtime/model_providers/openai/llm/llm.py
@@ -378,6 +378,11 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
if user:
extra_model_kwargs['user'] = user
+ if stream:
+ extra_model_kwargs['stream_options'] = {
+ "include_usage": True
+ }
+
# text completion model
response = client.completions.create(
prompt=prompt_messages[0].content,
@@ -446,8 +451,24 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
:return: llm response chunk generator result
"""
full_text = ''
+ prompt_tokens = 0
+ completion_tokens = 0
+
+ final_chunk = LLMResultChunk(
+ model=model,
+ prompt_messages=prompt_messages,
+ delta=LLMResultChunkDelta(
+ index=0,
+ message=AssistantPromptMessage(content=''),
+ )
+ )
+
for chunk in response:
if len(chunk.choices) == 0:
+ if chunk.usage:
+ # calculate num tokens
+ prompt_tokens = chunk.usage.prompt_tokens
+ completion_tokens = chunk.usage.completion_tokens
continue
delta = chunk.choices[0]
@@ -464,20 +485,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
full_text += text
if delta.finish_reason is not None:
- # calculate num tokens
- if chunk.usage:
- # transform usage
- prompt_tokens = chunk.usage.prompt_tokens
- completion_tokens = chunk.usage.completion_tokens
- else:
- # calculate num tokens
- prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content)
- completion_tokens = self._num_tokens_from_string(model, full_text)
-
- # transform usage
- usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
-
- yield LLMResultChunk(
+ final_chunk = LLMResultChunk(
model=chunk.model,
prompt_messages=prompt_messages,
system_fingerprint=chunk.system_fingerprint,
@@ -485,7 +493,6 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
index=delta.index,
message=assistant_prompt_message,
finish_reason=delta.finish_reason,
- usage=usage
)
)
else:
@@ -499,6 +506,19 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
)
)
+ if not prompt_tokens:
+ prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content)
+
+ if not completion_tokens:
+ completion_tokens = self._num_tokens_from_string(model, full_text)
+
+ # transform usage
+ usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+
+ final_chunk.delta.usage = usage
+
+ yield final_chunk
+
def _chat_generate(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
@@ -531,6 +551,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
model_parameters["response_format"] = response_format
+
extra_model_kwargs = {}
if tools:
@@ -547,6 +568,11 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
if user:
extra_model_kwargs['user'] = user
+ if stream:
+ extra_model_kwargs['stream_options'] = {
+ 'include_usage': True
+ }
+
# clear illegal prompt messages
prompt_messages = self._clear_illegal_prompt_messages(model, prompt_messages)
@@ -630,8 +656,24 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
"""
full_assistant_content = ''
delta_assistant_message_function_call_storage: ChoiceDeltaFunctionCall = None
+ prompt_tokens = 0
+ completion_tokens = 0
+ final_tool_calls = []
+ final_chunk = LLMResultChunk(
+ model=model,
+ prompt_messages=prompt_messages,
+ delta=LLMResultChunkDelta(
+ index=0,
+ message=AssistantPromptMessage(content=''),
+ )
+ )
+
for chunk in response:
if len(chunk.choices) == 0:
+ if chunk.usage:
+ # calculate num tokens
+ prompt_tokens = chunk.usage.prompt_tokens
+ completion_tokens = chunk.usage.completion_tokens
continue
delta = chunk.choices[0]
@@ -667,6 +709,8 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
# tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
function_call = self._extract_response_function_call(assistant_message_function_call)
tool_calls = [function_call] if function_call else []
+ if tool_calls:
+ final_tool_calls.extend(tool_calls)
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
@@ -677,19 +721,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
full_assistant_content += delta.delta.content if delta.delta.content else ''
if has_finish_reason:
- # calculate num tokens
- prompt_tokens = self._num_tokens_from_messages(model, prompt_messages, tools)
-
- full_assistant_prompt_message = AssistantPromptMessage(
- content=full_assistant_content,
- tool_calls=tool_calls
- )
- completion_tokens = self._num_tokens_from_messages(model, [full_assistant_prompt_message])
-
- # transform usage
- usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
-
- yield LLMResultChunk(
+ final_chunk = LLMResultChunk(
model=chunk.model,
prompt_messages=prompt_messages,
system_fingerprint=chunk.system_fingerprint,
@@ -697,7 +729,6 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
index=delta.index,
message=assistant_prompt_message,
finish_reason=delta.finish_reason,
- usage=usage
)
)
else:
@@ -711,6 +742,22 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
)
)
+ if not prompt_tokens:
+ prompt_tokens = self._num_tokens_from_messages(model, prompt_messages, tools)
+
+ if not completion_tokens:
+ full_assistant_prompt_message = AssistantPromptMessage(
+ content=full_assistant_content,
+ tool_calls=final_tool_calls
+ )
+ completion_tokens = self._num_tokens_from_messages(model, [full_assistant_prompt_message])
+
+ # transform usage
+ usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+ final_chunk.delta.usage = usage
+
+ yield final_chunk
+
def _extract_response_tool_calls(self,
response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]) \
-> list[AssistantPromptMessage.ToolCall]:
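Both stream handlers above now share the same shape: request stream_options with include_usage, capture usage from the trailing empty-choices chunk, hold back the finish-reason chunk, and yield it last with usage attached. A self-contained sketch of the pattern, with simplified chunk types assumed:

```python
from dataclasses import dataclass
from typing import Iterable, Iterator, Optional

@dataclass
class Usage:
    prompt_tokens: int
    completion_tokens: int

@dataclass
class Chunk:
    text: str = ''
    finish_reason: Optional[str] = None
    usage: Optional[Usage] = None

def stream_with_final_usage(chunks: Iterable[Chunk]) -> Iterator[Chunk]:
    prompt_tokens = completion_tokens = 0
    final_chunk: Optional[Chunk] = None
    for chunk in chunks:
        if chunk.usage:                      # trailing usage-only chunk
            prompt_tokens = chunk.usage.prompt_tokens
            completion_tokens = chunk.usage.completion_tokens
            continue
        if chunk.finish_reason is not None:  # buffer until usage is known
            final_chunk = chunk
            continue
        yield chunk                          # ordinary deltas pass straight through
    if final_chunk is not None:
        final_chunk.usage = Usage(prompt_tokens, completion_tokens)
        yield final_chunk
```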
diff --git a/api/core/model_runtime/model_providers/openai/openai.yaml b/api/core/model_runtime/model_providers/openai/openai.yaml
index 3af99e107e..b4dc8fd4f2 100644
--- a/api/core/model_runtime/model_providers/openai/openai.yaml
+++ b/api/core/model_runtime/model_providers/openai/openai.yaml
@@ -85,5 +85,5 @@ provider_credential_schema:
type: text-input
required: false
placeholder:
- zh_Hans: 在此输入您的 API Base
- en_US: Enter your API Base
+ zh_Hans: 在此输入您的 API Base, 如:https://api.openai.com
+ en_US: Enter your API Base, e.g. https://api.openai.com
diff --git a/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py b/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py
index 43258d1e5e..1c3f084207 100644
--- a/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py
+++ b/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py
@@ -180,7 +180,7 @@ class OpenLLMGenerate:
completion_usage += len(token_ids)
message = OpenLLMGenerateMessage(content=text, role=OpenLLMGenerateMessage.Role.ASSISTANT.value)
- if 'finish_reason' in choice and choice['finish_reason']:
+ if choice.get('finish_reason'):
finish_reason = choice['finish_reason']
prompt_token_ids = data['prompt_token_ids']
message.stop_reason = finish_reason
diff --git a/api/core/model_runtime/model_providers/vertex_ai/__init__.py b/api/core/model_runtime/model_providers/vertex_ai/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/model_runtime/model_providers/vertex_ai/_assets/icon_l_en.png b/api/core/model_runtime/model_providers/vertex_ai/_assets/icon_l_en.png
new file mode 100644
index 0000000000..9f8f05231a
Binary files /dev/null and b/api/core/model_runtime/model_providers/vertex_ai/_assets/icon_l_en.png differ
diff --git a/api/core/model_runtime/model_providers/vertex_ai/_assets/icon_s_en.svg b/api/core/model_runtime/model_providers/vertex_ai/_assets/icon_s_en.svg
new file mode 100644
index 0000000000..efc3589c07
--- /dev/null
+++ b/api/core/model_runtime/model_providers/vertex_ai/_assets/icon_s_en.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/api/core/model_runtime/model_providers/vertex_ai/_common.py b/api/core/model_runtime/model_providers/vertex_ai/_common.py
new file mode 100644
index 0000000000..8f7c859e38
--- /dev/null
+++ b/api/core/model_runtime/model_providers/vertex_ai/_common.py
@@ -0,0 +1,15 @@
+from core.model_runtime.errors.invoke import InvokeError
+
+
+class _CommonVertexAi:
+ @property
+ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+ """
+ Map model invoke error to unified error
+ The key is the error type thrown to the caller
+ The value is the error type thrown by the model,
+ which needs to be converted into a unified error type for the caller.
+
+ :return: Invoke error mapping
+ """
+ pass
diff --git a/api/core/model_runtime/model_providers/vertex_ai/llm/__init__.py b/api/core/model_runtime/model_providers/vertex_ai/llm/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/model_runtime/model_providers/vertex_ai/llm/anthropic.claude-3-haiku.yaml b/api/core/model_runtime/model_providers/vertex_ai/llm/anthropic.claude-3-haiku.yaml
new file mode 100644
index 0000000000..5613348695
--- /dev/null
+++ b/api/core/model_runtime/model_providers/vertex_ai/llm/anthropic.claude-3-haiku.yaml
@@ -0,0 +1,56 @@
+model: claude-3-haiku@20240307
+label:
+ en_US: Claude 3 Haiku
+model_type: llm
+features:
+ - agent-thought
+ - vision
+model_properties:
+ mode: chat
+ context_size: 200000
+parameter_rules:
+ - name: max_tokens
+ use_template: max_tokens
+ required: true
+ type: int
+ default: 4096
+ min: 1
+ max: 4096
+ help:
+ zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
+ en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
+ # docs: https://docs.anthropic.com/claude/docs/system-prompts
+ - name: temperature
+ use_template: temperature
+ required: false
+ type: float
+ default: 1
+ min: 0.0
+ max: 1.0
+ help:
+ zh_Hans: 生成内容的随机性。
+ en_US: The amount of randomness injected into the response.
+ - name: top_p
+ required: false
+ type: float
+ default: 0.999
+ min: 0.000
+ max: 1.000
+ help:
+ zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。
+ en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
+ - name: top_k
+ required: false
+ type: int
+ default: 0
+ min: 0
+ # tip docs from aws has error, max value is 500
+ max: 500
+ help:
+ zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
+ en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
+pricing:
+ input: '0.00025'
+ output: '0.00125'
+ unit: '0.001'
+ currency: USD
diff --git a/api/core/model_runtime/model_providers/vertex_ai/llm/anthropic.claude-3-opus.yaml b/api/core/model_runtime/model_providers/vertex_ai/llm/anthropic.claude-3-opus.yaml
new file mode 100644
index 0000000000..ab084636b5
--- /dev/null
+++ b/api/core/model_runtime/model_providers/vertex_ai/llm/anthropic.claude-3-opus.yaml
@@ -0,0 +1,56 @@
+model: claude-3-opus@20240229
+label:
+ en_US: Claude 3 Opus
+model_type: llm
+features:
+ - agent-thought
+ - vision
+model_properties:
+ mode: chat
+ context_size: 200000
+parameter_rules:
+ - name: max_tokens
+ use_template: max_tokens
+ required: true
+ type: int
+ default: 4096
+ min: 1
+ max: 4096
+ help:
+ zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
+ en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
+ # docs: https://docs.anthropic.com/claude/docs/system-prompts
+ - name: temperature
+ use_template: temperature
+ required: false
+ type: float
+ default: 1
+ min: 0.0
+ max: 1.0
+ help:
+ zh_Hans: 生成内容的随机性。
+ en_US: The amount of randomness injected into the response.
+ - name: top_p
+ required: false
+ type: float
+ default: 0.999
+ min: 0.000
+ max: 1.000
+ help:
+ zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。
+ en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
+ - name: top_k
+ required: false
+ type: int
+ default: 0
+ min: 0
+    # note: the AWS docs are incorrect here; the actual maximum value is 500
+ max: 500
+ help:
+ zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
+ en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
+pricing:
+ input: '0.015'
+ output: '0.075'
+ unit: '0.001'
+ currency: USD
diff --git a/api/core/model_runtime/model_providers/vertex_ai/llm/anthropic.claude-3-sonnet.yaml b/api/core/model_runtime/model_providers/vertex_ai/llm/anthropic.claude-3-sonnet.yaml
new file mode 100644
index 0000000000..0be0113ffd
--- /dev/null
+++ b/api/core/model_runtime/model_providers/vertex_ai/llm/anthropic.claude-3-sonnet.yaml
@@ -0,0 +1,55 @@
+model: claude-3-sonnet@20240229
+label:
+ en_US: Claude 3 Sonnet
+model_type: llm
+features:
+ - agent-thought
+ - vision
+model_properties:
+ mode: chat
+ context_size: 200000
+parameter_rules:
+ - name: max_tokens
+ use_template: max_tokens
+ required: true
+ type: int
+ default: 4096
+ min: 1
+ max: 4096
+ help:
+ zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
+ en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
+ - name: temperature
+ use_template: temperature
+ required: false
+ type: float
+ default: 1
+ min: 0.0
+ max: 1.0
+ help:
+ zh_Hans: 生成内容的随机性。
+ en_US: The amount of randomness injected into the response.
+ - name: top_p
+ required: false
+ type: float
+ default: 0.999
+ min: 0.000
+ max: 1.000
+ help:
+ zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。
+ en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
+ - name: top_k
+ required: false
+ type: int
+ default: 0
+ min: 0
+    # note: the AWS docs are incorrect here; the actual maximum value is 500
+ max: 500
+ help:
+ zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
+ en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
+pricing:
+ input: '0.003'
+ output: '0.015'
+ unit: '0.001'
+ currency: USD
diff --git a/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.0-pro-vision.yaml b/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.0-pro-vision.yaml
new file mode 100644
index 0000000000..da3bc8a64a
--- /dev/null
+++ b/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.0-pro-vision.yaml
@@ -0,0 +1,38 @@
+model: gemini-1.0-pro-vision-001
+label:
+ en_US: Gemini 1.0 Pro Vision
+model_type: llm
+features:
+ - vision
+ - tool-call
+ - stream-tool-call
+model_properties:
+ mode: chat
+ context_size: 16384
+parameter_rules:
+ - name: temperature
+ use_template: temperature
+ - name: top_p
+ use_template: top_p
+ - name: top_k
+ label:
+ en_US: Top k
+ type: int
+ help:
+ en_US: Only sample from the top K options for each subsequent token.
+ required: false
+ - name: presence_penalty
+ use_template: presence_penalty
+ - name: frequency_penalty
+ use_template: frequency_penalty
+ - name: max_output_tokens
+ use_template: max_tokens
+ required: true
+ default: 2048
+ min: 1
+ max: 2048
+pricing:
+ input: '0.00'
+ output: '0.00'
+ unit: '0.000001'
+ currency: USD
diff --git a/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.0-pro.yaml b/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.0-pro.yaml
new file mode 100644
index 0000000000..029fab718c
--- /dev/null
+++ b/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.0-pro.yaml
@@ -0,0 +1,38 @@
+model: gemini-1.0-pro-002
+label:
+ en_US: Gemini 1.0 Pro
+model_type: llm
+features:
+ - agent-thought
+ - tool-call
+ - stream-tool-call
+model_properties:
+ mode: chat
+ context_size: 32760
+parameter_rules:
+ - name: temperature
+ use_template: temperature
+ - name: top_p
+ use_template: top_p
+ - name: top_k
+ label:
+ en_US: Top k
+ type: int
+ help:
+ en_US: Only sample from the top K options for each subsequent token.
+ required: false
+ - name: presence_penalty
+ use_template: presence_penalty
+ - name: frequency_penalty
+ use_template: frequency_penalty
+ - name: max_output_tokens
+ use_template: max_tokens
+ required: true
+ default: 8192
+ min: 1
+ max: 8192
+pricing:
+ input: '0.00'
+ output: '0.00'
+ unit: '0.000001'
+ currency: USD
diff --git a/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.5-flash.yaml b/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.5-flash.yaml
new file mode 100644
index 0000000000..72b8410aa1
--- /dev/null
+++ b/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.5-flash.yaml
@@ -0,0 +1,38 @@
+model: gemini-1.5-flash-preview-0514
+label:
+ en_US: Gemini 1.5 Flash
+model_type: llm
+features:
+ - vision
+ - tool-call
+ - stream-tool-call
+model_properties:
+ mode: chat
+ context_size: 1048576
+parameter_rules:
+ - name: temperature
+ use_template: temperature
+ - name: top_p
+ use_template: top_p
+ - name: top_k
+ label:
+ en_US: Top k
+ type: int
+ help:
+ en_US: Only sample from the top K options for each subsequent token.
+ required: false
+ - name: presence_penalty
+ use_template: presence_penalty
+ - name: frequency_penalty
+ use_template: frequency_penalty
+ - name: max_output_tokens
+ use_template: max_tokens
+ required: true
+ default: 8192
+ min: 1
+ max: 8192
+pricing:
+ input: '0.00'
+ output: '0.00'
+ unit: '0.000001'
+ currency: USD
diff --git a/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.5-pro.yaml b/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.5-pro.yaml
new file mode 100644
index 0000000000..141f61aad6
--- /dev/null
+++ b/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.5-pro.yaml
@@ -0,0 +1,39 @@
+model: gemini-1.5-pro-preview-0514
+label:
+ en_US: Gemini 1.5 Pro
+model_type: llm
+features:
+ - agent-thought
+ - vision
+ - tool-call
+ - stream-tool-call
+model_properties:
+ mode: chat
+ context_size: 1048576
+parameter_rules:
+ - name: temperature
+ use_template: temperature
+ - name: top_p
+ use_template: top_p
+ - name: top_k
+ label:
+ en_US: Top k
+ type: int
+ help:
+ en_US: Only sample from the top K options for each subsequent token.
+ required: false
+ - name: presence_penalty
+ use_template: presence_penalty
+ - name: frequency_penalty
+ use_template: frequency_penalty
+ - name: max_output_tokens
+ use_template: max_tokens
+ required: true
+ default: 8192
+ min: 1
+ max: 8192
+pricing:
+ input: '0.00'
+ output: '0.00'
+ unit: '0.000001'
+ currency: USD
diff --git a/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py b/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py
new file mode 100644
index 0000000000..0d6dd8d982
--- /dev/null
+++ b/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py
@@ -0,0 +1,728 @@
+import base64
+import json
+import logging
+import mimetypes
+import time
+from collections.abc import Generator
+from typing import Optional, Union, cast
+
+import google.api_core.exceptions as exceptions
+import requests
+import vertexai.generative_models as glm
+from anthropic import AnthropicVertex, Stream
+from anthropic.types import (
+ ContentBlockDeltaEvent,
+ Message,
+ MessageDeltaEvent,
+ MessageStartEvent,
+ MessageStopEvent,
+ MessageStreamEvent,
+)
+from google.cloud import aiplatform
+from google.oauth2 import service_account
+from vertexai.generative_models import HarmBlockThreshold, HarmCategory
+
+from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
+from core.model_runtime.entities.message_entities import (
+ AssistantPromptMessage,
+ ImagePromptMessageContent,
+ PromptMessage,
+ PromptMessageContentType,
+ PromptMessageTool,
+ SystemPromptMessage,
+ TextPromptMessageContent,
+ ToolPromptMessage,
+ UserPromptMessage,
+)
+from core.model_runtime.entities.model_entities import PriceType
+from core.model_runtime.errors.invoke import (
+ InvokeAuthorizationError,
+ InvokeBadRequestError,
+ InvokeConnectionError,
+ InvokeError,
+ InvokeRateLimitError,
+ InvokeServerUnavailableError,
+)
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+
+logger = logging.getLogger(__name__)
+
+GEMINI_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
+The structure of the {{block}} object can be found in the instructions; use {"answer": "$your_answer"} as the default structure
+if you are not sure about the structure.
+
+
+{{instructions}}
+
+"""
+
+
+class VertexAiLargeLanguageModel(LargeLanguageModel):
+
+ def _invoke(self, model: str, credentials: dict,
+ prompt_messages: list[PromptMessage], model_parameters: dict,
+ tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
+ stream: bool = True, user: Optional[str] = None) \
+ -> Union[LLMResult, Generator]:
+ """
+ Invoke large language model
+
+ :param model: model name
+ :param credentials: model credentials
+ :param prompt_messages: prompt messages
+ :param model_parameters: model parameters
+ :param tools: tools for tool calling
+ :param stop: stop words
+ :param stream: is stream response
+ :param user: unique user id
+ :return: full response or stream response chunk generator result
+ """
+ # invoke anthropic models via anthropic official SDK
+ if "claude" in model:
+ return self._generate_anthropic(model, credentials, prompt_messages, model_parameters, stop, stream, user)
+ # invoke Gemini model
+ return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
+
+ def _generate_anthropic(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
+ stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
+ """
+ Invoke Anthropic large language model
+
+ :param model: model name
+ :param credentials: model credentials
+ :param prompt_messages: prompt messages
+ :param model_parameters: model parameters
+ :param stop: stop words
+ :param stream: is stream response
+ :return: full response or stream response chunk generator result
+ """
+ # use Anthropic official SDK references
+ # - https://github.com/anthropics/anthropic-sdk-python
+ project_id = credentials["vertex_project_id"]
+
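+        # Anthropic models on Vertex AI are region-bound; at the time of writing, Claude 3 Opus is only served from us-east5, while the other Claude 3 models use us-central1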
+ if 'opus' in model:
+ location = 'us-east5'
+ else:
+ location = 'us-central1'
+
+ client = AnthropicVertex(
+ region=location,
+ project_id=project_id
+ )
+
+ extra_model_kwargs = {}
+ if stop:
+ extra_model_kwargs['stop_sequences'] = stop
+
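+        # the Anthropic Messages API takes the system prompt as a separate top-level parameter, so it is split out of the message list here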
+ system, prompt_message_dicts = self._convert_claude_prompt_messages(prompt_messages)
+
+ if system:
+ extra_model_kwargs['system'] = system
+
+ response = client.messages.create(
+ model=model,
+ messages=prompt_message_dicts,
+ stream=stream,
+ **model_parameters,
+ **extra_model_kwargs
+ )
+
+ if stream:
+ return self._handle_claude_stream_response(model, credentials, response, prompt_messages)
+
+ return self._handle_claude_response(model, credentials, response, prompt_messages)
+
+ def _handle_claude_response(self, model: str, credentials: dict, response: Message,
+ prompt_messages: list[PromptMessage]) -> LLMResult:
+ """
+ Handle llm chat response
+
+ :param model: model name
+ :param credentials: credentials
+ :param response: response
+ :param prompt_messages: prompt messages
+ :return: full response chunk generator result
+ """
+
+ # transform assistant message to prompt message
+ assistant_prompt_message = AssistantPromptMessage(
+ content=response.content[0].text
+ )
+
+ # calculate num tokens
+ if response.usage:
+ # transform usage
+ prompt_tokens = response.usage.input_tokens
+ completion_tokens = response.usage.output_tokens
+ else:
+ # calculate num tokens
+ prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
+ completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
+
+ # transform usage
+ usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+
+ # transform response
+ response = LLMResult(
+ model=response.model,
+ prompt_messages=prompt_messages,
+ message=assistant_prompt_message,
+ usage=usage
+ )
+
+ return response
+
+ def _handle_claude_stream_response(self, model: str, credentials: dict, response: Stream[MessageStreamEvent],
+                                         prompt_messages: list[PromptMessage]) -> Generator:
+ """
+ Handle llm chat stream response
+
+ :param model: model name
+ :param credentials: credentials
+ :param response: response
+ :param prompt_messages: prompt messages
+ :return: full response or stream response chunk generator result
+ """
+
+ try:
+ full_assistant_content = ''
+ return_model = None
+ input_tokens = 0
+ output_tokens = 0
+ finish_reason = None
+ index = 0
+
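+            # Anthropic streams a fixed event sequence: MessageStartEvent carries the model name and input-token
+            # usage, ContentBlockDeltaEvent carries generated text, MessageDeltaEvent carries output-token usage
+            # and the stop reason, and MessageStopEvent closes the stream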
+ for chunk in response:
+ if isinstance(chunk, MessageStartEvent):
+ return_model = chunk.message.model
+ input_tokens = chunk.message.usage.input_tokens
+ elif isinstance(chunk, MessageDeltaEvent):
+ output_tokens = chunk.usage.output_tokens
+ finish_reason = chunk.delta.stop_reason
+ elif isinstance(chunk, MessageStopEvent):
+ usage = self._calc_response_usage(model, credentials, input_tokens, output_tokens)
+ yield LLMResultChunk(
+ model=return_model,
+ prompt_messages=prompt_messages,
+ delta=LLMResultChunkDelta(
+ index=index + 1,
+ message=AssistantPromptMessage(
+ content=''
+ ),
+ finish_reason=finish_reason,
+ usage=usage
+ )
+ )
+ elif isinstance(chunk, ContentBlockDeltaEvent):
+ chunk_text = chunk.delta.text if chunk.delta.text else ''
+ full_assistant_content += chunk_text
+ assistant_prompt_message = AssistantPromptMessage(
+ content=chunk_text if chunk_text else '',
+ )
+ index = chunk.index
+ yield LLMResultChunk(
+ model=model,
+ prompt_messages=prompt_messages,
+ delta=LLMResultChunkDelta(
+ index=index,
+ message=assistant_prompt_message,
+ )
+ )
+ except Exception as ex:
+ raise InvokeError(str(ex))
+
+ def _calc_claude_response_usage(self, model: str, credentials: dict, prompt_tokens: int, completion_tokens: int) -> LLMUsage:
+ """
+ Calculate response usage
+
+ :param model: model name
+ :param credentials: model credentials
+ :param prompt_tokens: prompt tokens
+ :param completion_tokens: completion tokens
+ :return: usage
+ """
+ # get prompt price info
+ prompt_price_info = self.get_price(
+ model=model,
+ credentials=credentials,
+ price_type=PriceType.INPUT,
+ tokens=prompt_tokens,
+ )
+
+ # get completion price info
+ completion_price_info = self.get_price(
+ model=model,
+ credentials=credentials,
+ price_type=PriceType.OUTPUT,
+ tokens=completion_tokens
+ )
+
+ # transform usage
+ usage = LLMUsage(
+ prompt_tokens=prompt_tokens,
+ prompt_unit_price=prompt_price_info.unit_price,
+ prompt_price_unit=prompt_price_info.unit,
+ prompt_price=prompt_price_info.total_amount,
+ completion_tokens=completion_tokens,
+ completion_unit_price=completion_price_info.unit_price,
+ completion_price_unit=completion_price_info.unit,
+ completion_price=completion_price_info.total_amount,
+ total_tokens=prompt_tokens + completion_tokens,
+ total_price=prompt_price_info.total_amount + completion_price_info.total_amount,
+ currency=prompt_price_info.currency,
+ latency=time.perf_counter() - self.started_at
+ )
+
+ return usage
+
+ def _convert_claude_prompt_messages(self, prompt_messages: list[PromptMessage]) -> tuple[str, list[dict]]:
+ """
+ Convert prompt messages to dict list and system
+ """
+
+ system = ""
+ first_loop = True
+ for message in prompt_messages:
+ if isinstance(message, SystemPromptMessage):
+                message.content = message.content.strip()
+                if first_loop:
+                    system = message.content
+                    first_loop = False
+                else:
+                    system += "\n"
+                    system += message.content
+
+ prompt_message_dicts = []
+ for message in prompt_messages:
+ if not isinstance(message, SystemPromptMessage):
+ prompt_message_dicts.append(self._convert_claude_prompt_message_to_dict(message))
+
+ return system, prompt_message_dicts
+
+ def _convert_claude_prompt_message_to_dict(self, message: PromptMessage) -> dict:
+ """
+ Convert PromptMessage to dict
+ """
+ if isinstance(message, UserPromptMessage):
+ message = cast(UserPromptMessage, message)
+ if isinstance(message.content, str):
+ message_dict = {"role": "user", "content": message.content}
+ else:
+ sub_messages = []
+ for message_content in message.content:
+ if message_content.type == PromptMessageContentType.TEXT:
+ message_content = cast(TextPromptMessageContent, message_content)
+ sub_message_dict = {
+ "type": "text",
+ "text": message_content.data
+ }
+ sub_messages.append(sub_message_dict)
+ elif message_content.type == PromptMessageContentType.IMAGE:
+ message_content = cast(ImagePromptMessageContent, message_content)
+ if not message_content.data.startswith("data:"):
+ # fetch image data from url
+ try:
+ image_content = requests.get(message_content.data).content
+ mime_type, _ = mimetypes.guess_type(message_content.data)
+ base64_data = base64.b64encode(image_content).decode('utf-8')
+ except Exception as ex:
+ raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
+ else:
+ data_split = message_content.data.split(";base64,")
+ mime_type = data_split[0].replace("data:", "")
+ base64_data = data_split[1]
+
+ if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]:
+ raise ValueError(f"Unsupported image type {mime_type}, "
+ f"only support image/jpeg, image/png, image/gif, and image/webp")
+
+ sub_message_dict = {
+ "type": "image",
+ "source": {
+ "type": "base64",
+ "media_type": mime_type,
+ "data": base64_data
+ }
+ }
+ sub_messages.append(sub_message_dict)
+
+ message_dict = {"role": "user", "content": sub_messages}
+ elif isinstance(message, AssistantPromptMessage):
+ message = cast(AssistantPromptMessage, message)
+ message_dict = {"role": "assistant", "content": message.content}
+ elif isinstance(message, SystemPromptMessage):
+ message = cast(SystemPromptMessage, message)
+ message_dict = {"role": "system", "content": message.content}
+ else:
+ raise ValueError(f"Got unknown type {message}")
+
+ return message_dict
+
+ def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+ tools: Optional[list[PromptMessageTool]] = None) -> int:
+ """
+ Get number of tokens for given prompt messages
+
+ :param model: model name
+ :param credentials: model credentials
+ :param prompt_messages: prompt messages
+ :param tools: tools for tool calling
+        :return: number of tokens
+ """
+ prompt = self._convert_messages_to_prompt(prompt_messages)
+
+ return self._get_num_tokens_by_gpt2(prompt)
+
+ def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str:
+ """
+ Format a list of messages into a full prompt for the Google model
+
+ :param messages: List of PromptMessage to combine.
+ :return: Combined string with necessary human_prompt and ai_prompt tags.
+ """
+ messages = messages.copy() # don't mutate the original list
+
+ text = "".join(
+ self._convert_one_message_to_text(message)
+ for message in messages
+ )
+
+ return text.rstrip()
+
+ def _convert_tools_to_glm_tool(self, tools: list[PromptMessageTool]) -> glm.Tool:
+ """
+ Convert tool messages to glm tools
+
+ :param tools: tool messages
+ :return: glm tools
+ """
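+        # translate the JSON-schema style tool parameters into the glm.FunctionDeclaration / glm.Schema structures expected by Vertex AI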
+ return glm.Tool(
+ function_declarations=[
+ glm.FunctionDeclaration(
+ name=tool.name,
+ parameters=glm.Schema(
+ type=glm.Type.OBJECT,
+ properties={
+ key: {
+ 'type_': value.get('type', 'string').upper(),
+ 'description': value.get('description', ''),
+ 'enum': value.get('enum', [])
+ } for key, value in tool.parameters.get('properties', {}).items()
+ },
+ required=tool.parameters.get('required', [])
+ ),
+ ) for tool in tools
+ ]
+ )
+
+ def validate_credentials(self, model: str, credentials: dict) -> None:
+ """
+ Validate model credentials
+
+ :param model: model name
+ :param credentials: model credentials
+ :return:
+ """
+
+ try:
+ ping_message = SystemPromptMessage(content="ping")
+ self._generate(model, credentials, [ping_message], {"max_tokens_to_sample": 5})
+
+ except Exception as ex:
+ raise CredentialsValidateFailedError(str(ex))
+
+ def _generate(self, model: str, credentials: dict,
+ prompt_messages: list[PromptMessage], model_parameters: dict,
+ tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
+ stream: bool = True, user: Optional[str] = None
+ ) -> Union[LLMResult, Generator]:
+ """
+ Invoke large language model
+
+ :param model: model name
+ :param credentials: credentials kwargs
+ :param prompt_messages: prompt messages
+ :param model_parameters: model parameters
+ :param stop: stop words
+ :param stream: is stream response
+ :param user: unique user id
+ :return: full response or stream response chunk generator result
+ """
+ config_kwargs = model_parameters.copy()
+ config_kwargs['max_output_tokens'] = config_kwargs.pop('max_tokens_to_sample', None)
+
+ if stop:
+ config_kwargs["stop_sequences"] = stop
+
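+        # the credential is the service-account JSON key encoded as base64; when it is left blank, fall back to
+        # Application Default Credentials (see vertex_ai.yaml)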
+ service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"]))
+ project_id = credentials["vertex_project_id"]
+ location = credentials["vertex_location"]
+ if service_account_info:
+ service_accountSA = service_account.Credentials.from_service_account_info(service_account_info)
+ aiplatform.init(credentials=service_accountSA, project=project_id, location=location)
+ else:
+ aiplatform.init(project=project_id, location=location)
+
+ history = []
+ system_instruction = GEMINI_BLOCK_MODE_PROMPT
+ # hack for gemini-pro-vision, which currently does not support multi-turn chat
+ if model == "gemini-1.0-pro-vision-001":
+ last_msg = prompt_messages[-1]
+ content = self._format_message_to_glm_content(last_msg)
+ history.append(content)
+ else:
+ for msg in prompt_messages:
+ if isinstance(msg, SystemPromptMessage):
+ system_instruction = msg.content
+ else:
+ content = self._format_message_to_glm_content(msg)
+ if history and history[-1].role == content.role:
+ history[-1].parts.extend(content.parts)
+ else:
+ history.append(content)
+
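+        # disable Vertex AI's built-in safety filtering so that content moderation is left to the calling application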
+        safety_settings = {
+ HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
+ HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
+ HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
+ HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
+ }
+
+ google_model = glm.GenerativeModel(
+ model_name=model,
+ system_instruction=system_instruction
+ )
+
+ response = google_model.generate_content(
+ contents=history,
+ generation_config=glm.GenerationConfig(
+ **config_kwargs
+ ),
+ stream=stream,
+ safety_settings=safety_settings,
+ tools=self._convert_tools_to_glm_tool(tools) if tools else None
+ )
+
+ if stream:
+ return self._handle_generate_stream_response(model, credentials, response, prompt_messages)
+
+ return self._handle_generate_response(model, credentials, response, prompt_messages)
+
+ def _handle_generate_response(self, model: str, credentials: dict, response: glm.GenerationResponse,
+ prompt_messages: list[PromptMessage]) -> LLMResult:
+ """
+ Handle llm response
+
+ :param model: model name
+ :param credentials: credentials
+ :param response: response
+ :param prompt_messages: prompt messages
+ :return: llm response
+ """
+ # transform assistant message to prompt message
+ assistant_prompt_message = AssistantPromptMessage(
+ content=response.candidates[0].content.parts[0].text
+ )
+
+ # calculate num tokens
+ prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
+ completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
+
+ # transform usage
+ usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+
+ # transform response
+ result = LLMResult(
+ model=model,
+ prompt_messages=prompt_messages,
+ message=assistant_prompt_message,
+ usage=usage,
+ )
+
+ return result
+
+ def _handle_generate_stream_response(self, model: str, credentials: dict, response: glm.GenerationResponse,
+ prompt_messages: list[PromptMessage]) -> Generator:
+ """
+ Handle llm stream response
+
+ :param model: model name
+ :param credentials: credentials
+ :param response: response
+ :param prompt_messages: prompt messages
+ :return: llm response chunk generator result
+ """
+ index = -1
+ for chunk in response:
+ for part in chunk.candidates[0].content.parts:
+ assistant_prompt_message = AssistantPromptMessage(
+ content=''
+ )
+
+ if part.text:
+ assistant_prompt_message.content += part.text
+
+ if part.function_call:
+ assistant_prompt_message.tool_calls = [
+ AssistantPromptMessage.ToolCall(
+ id=part.function_call.name,
+ type='function',
+ function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+ name=part.function_call.name,
+ arguments=json.dumps({
+ key: value
+ for key, value in part.function_call.args.items()
+ })
+ )
+ )
+ ]
+
+ index += 1
+
+ if not hasattr(chunk, 'finish_reason') or not chunk.finish_reason:
+ # transform assistant message to prompt message
+ yield LLMResultChunk(
+ model=model,
+ prompt_messages=prompt_messages,
+ delta=LLMResultChunkDelta(
+ index=index,
+ message=assistant_prompt_message
+ )
+ )
+ else:
+
+ # calculate num tokens
+ prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
+ completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
+
+ # transform usage
+ usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+
+ yield LLMResultChunk(
+ model=model,
+ prompt_messages=prompt_messages,
+ delta=LLMResultChunkDelta(
+ index=index,
+ message=assistant_prompt_message,
+ finish_reason=chunk.candidates[0].finish_reason,
+ usage=usage
+ )
+ )
+
+ def _convert_one_message_to_text(self, message: PromptMessage) -> str:
+ """
+ Convert a single message to a string.
+
+ :param message: PromptMessage to convert.
+ :return: String representation of the message.
+ """
+ human_prompt = "\n\nuser:"
+ ai_prompt = "\n\nmodel:"
+
+ content = message.content
+ if isinstance(content, list):
+ content = "".join(
+ c.data for c in content if c.type != PromptMessageContentType.IMAGE
+ )
+
+ if isinstance(message, UserPromptMessage):
+ message_text = f"{human_prompt} {content}"
+ elif isinstance(message, AssistantPromptMessage):
+ message_text = f"{ai_prompt} {content}"
+ elif isinstance(message, SystemPromptMessage):
+ message_text = f"{human_prompt} {content}"
+ elif isinstance(message, ToolPromptMessage):
+ message_text = f"{human_prompt} {content}"
+ else:
+ raise ValueError(f"Got unknown type {message}")
+
+ return message_text
+
+ def _format_message_to_glm_content(self, message: PromptMessage) -> glm.Content:
+ """
+ Format a single message into glm.Content for Google API
+
+ :param message: one PromptMessage
+ :return: glm Content representation of message
+ """
+ if isinstance(message, UserPromptMessage):
+ glm_content = glm.Content(role="user", parts=[])
+
+        if isinstance(message.content, str):
+ glm_content = glm.Content(role="user", parts=[glm.Part.from_text(message.content)])
+ else:
+ parts = []
+ for c in message.content:
+ if c.type == PromptMessageContentType.TEXT:
+ parts.append(glm.Part.from_text(c.data))
+ else:
+ metadata, data = c.data.split(',', 1)
+ mime_type = metadata.split(';', 1)[0].split(':')[1]
+ parts.append(glm.Part.from_data(mime_type=mime_type, data=data))
+ glm_content = glm.Content(role="user", parts=parts)
+ return glm_content
+ elif isinstance(message, AssistantPromptMessage):
+ if message.content:
+ glm_content = glm.Content(role="model", parts=[glm.Part.from_text(message.content)])
+ if message.tool_calls:
+ glm_content = glm.Content(role="model", parts=[glm.Part.from_function_response(glm.FunctionCall(
+ name=message.tool_calls[0].function.name,
+ args=json.loads(message.tool_calls[0].function.arguments),
+ ))])
+ return glm_content
+ elif isinstance(message, ToolPromptMessage):
+ glm_content = glm.Content(role="function", parts=[glm.Part(function_response=glm.FunctionResponse(
+ name=message.name,
+ response={
+ "response": message.content
+ }
+ ))])
+ return glm_content
+ else:
+ raise ValueError(f"Got unknown type {message}")
+
+ @property
+ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+ """
+ Map model invoke error to unified error
+        The key is the error type thrown to the caller
+        The value is the error type thrown by the model,
+        which needs to be converted into a unified error type for the caller.
+
+        :return: Invoke error mapping
+ """
+ return {
+ InvokeConnectionError: [
+ exceptions.RetryError
+ ],
+ InvokeServerUnavailableError: [
+ exceptions.ServiceUnavailable,
+ exceptions.InternalServerError,
+ exceptions.BadGateway,
+ exceptions.GatewayTimeout,
+ exceptions.DeadlineExceeded
+ ],
+ InvokeRateLimitError: [
+ exceptions.ResourceExhausted,
+ exceptions.TooManyRequests
+ ],
+ InvokeAuthorizationError: [
+                exceptions.Unauthenticated,
+                exceptions.PermissionDenied,
+                exceptions.Forbidden
+ ],
+ InvokeBadRequestError: [
+ exceptions.BadRequest,
+ exceptions.InvalidArgument,
+ exceptions.FailedPrecondition,
+ exceptions.OutOfRange,
+ exceptions.NotFound,
+ exceptions.MethodNotAllowed,
+ exceptions.Conflict,
+ exceptions.AlreadyExists,
+ exceptions.Aborted,
+ exceptions.LengthRequired,
+ exceptions.PreconditionFailed,
+ exceptions.RequestRangeNotSatisfiable,
+ exceptions.Cancelled,
+ ]
+ }
diff --git a/api/core/model_runtime/model_providers/vertex_ai/text_embedding/__init__.py b/api/core/model_runtime/model_providers/vertex_ai/text_embedding/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text-embedding-004.yaml b/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text-embedding-004.yaml
new file mode 100644
index 0000000000..32db6faf89
--- /dev/null
+++ b/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text-embedding-004.yaml
@@ -0,0 +1,8 @@
+model: text-embedding-004
+model_type: text-embedding
+model_properties:
+ context_size: 2048
+pricing:
+ input: '0.00013'
+ unit: '0.001'
+ currency: USD
diff --git a/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text-multilingual-embedding-002.yaml b/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text-multilingual-embedding-002.yaml
new file mode 100644
index 0000000000..2ec0eea9f2
--- /dev/null
+++ b/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text-multilingual-embedding-002.yaml
@@ -0,0 +1,8 @@
+model: text-multilingual-embedding-002
+model_type: text-embedding
+model_properties:
+ context_size: 2048
+pricing:
+ input: '0.00013'
+ unit: '0.001'
+ currency: USD
diff --git a/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py
new file mode 100644
index 0000000000..2404ba5894
--- /dev/null
+++ b/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py
@@ -0,0 +1,197 @@
+import base64
+import json
+import time
+from decimal import Decimal
+from typing import Optional
+
+import tiktoken
+from google.cloud import aiplatform
+from google.oauth2 import service_account
+from vertexai.language_models import TextEmbeddingModel as VertexTextEmbeddingModel
+
+from core.model_runtime.entities.common_entities import I18nObject
+from core.model_runtime.entities.model_entities import (
+ AIModelEntity,
+ FetchFrom,
+ ModelPropertyKey,
+ ModelType,
+ PriceConfig,
+ PriceType,
+)
+from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
+from core.model_runtime.model_providers.vertex_ai._common import _CommonVertexAi
+
+
+class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel):
+ """
+ Model class for Vertex AI text embedding model.
+ """
+
+ def _invoke(self, model: str, credentials: dict,
+ texts: list[str], user: Optional[str] = None) \
+ -> TextEmbeddingResult:
+ """
+ Invoke text embedding model
+
+ :param model: model name
+ :param credentials: model credentials
+ :param texts: texts to embed
+ :return: embeddings result
+ """
+ service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"]))
+ project_id = credentials["vertex_project_id"]
+ location = credentials["vertex_location"]
+ if service_account_info:
+ service_accountSA = service_account.Credentials.from_service_account_info(service_account_info)
+ aiplatform.init(credentials=service_accountSA, project=project_id, location=location)
+ else:
+ aiplatform.init(project=project_id, location=location)
+
+ client = VertexTextEmbeddingModel.from_pretrained(model)
+
+ embeddings_batch, embedding_used_tokens = self._embedding_invoke(
+ client=client,
+ texts=texts
+ )
+
+ # calc usage
+ usage = self._calc_response_usage(
+ model=model,
+ credentials=credentials,
+ tokens=embedding_used_tokens
+ )
+
+ return TextEmbeddingResult(
+ embeddings=embeddings_batch,
+ usage=usage,
+ model=model
+ )
+
+ def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
+ """
+ Get number of tokens for given prompt messages
+
+ :param model: model name
+ :param credentials: model credentials
+ :param texts: texts to embed
+ :return:
+ """
+ if len(texts) == 0:
+ return 0
+
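+        # tiktoken does not know Vertex AI model names, so this falls back to cl100k_base and yields an approximate count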
+ try:
+ enc = tiktoken.encoding_for_model(model)
+ except KeyError:
+ enc = tiktoken.get_encoding("cl100k_base")
+
+ total_num_tokens = 0
+ for text in texts:
+ # calculate the number of tokens in the encoded text
+ tokenized_text = enc.encode(text)
+ total_num_tokens += len(tokenized_text)
+
+ return total_num_tokens
+
+ def validate_credentials(self, model: str, credentials: dict) -> None:
+ """
+ Validate model credentials
+
+ :param model: model name
+ :param credentials: model credentials
+ :return:
+ """
+ try:
+ service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"]))
+ project_id = credentials["vertex_project_id"]
+ location = credentials["vertex_location"]
+ if service_account_info:
+ service_accountSA = service_account.Credentials.from_service_account_info(service_account_info)
+ aiplatform.init(credentials=service_accountSA, project=project_id, location=location)
+ else:
+ aiplatform.init(project=project_id, location=location)
+
+ client = VertexTextEmbeddingModel.from_pretrained(model)
+
+ # call embedding model
+ self._embedding_invoke(
+ client=client,
+ texts=['ping']
+ )
+ except Exception as ex:
+ raise CredentialsValidateFailedError(str(ex))
+
+    def _embedding_invoke(self, client: VertexTextEmbeddingModel, texts: list[str]) -> tuple[list[list[float]], int]:
+ """
+ Invoke embedding model
+
+ :param client: model client
+ :param texts: texts to embed
+ :return: embeddings and used tokens
+ """
+ response = client.get_embeddings(texts)
+
+ embeddings = []
+ token_usage = 0
+
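+        # each embedding in the response exposes its vector via .values and the billed token count via .statistics.token_count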
+ for i in range(len(response)):
+ embeddings.append(response[i].values)
+ token_usage += int(response[i].statistics.token_count)
+
+ return embeddings, token_usage
+
+ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
+ """
+ Calculate response usage
+
+ :param model: model name
+ :param credentials: model credentials
+ :param tokens: input tokens
+ :return: usage
+ """
+ # get input price info
+ input_price_info = self.get_price(
+ model=model,
+ credentials=credentials,
+ price_type=PriceType.INPUT,
+ tokens=tokens
+ )
+
+ # transform usage
+ usage = EmbeddingUsage(
+ tokens=tokens,
+ total_tokens=tokens,
+ unit_price=input_price_info.unit_price,
+ price_unit=input_price_info.unit,
+ total_price=input_price_info.total_amount,
+ currency=input_price_info.currency,
+ latency=time.perf_counter() - self.started_at
+ )
+
+ return usage
+
+ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
+ """
+        Generate a customizable model entity from the given credentials.
+ """
+ entity = AIModelEntity(
+ model=model,
+ label=I18nObject(en_US=model),
+ model_type=ModelType.TEXT_EMBEDDING,
+ fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+ model_properties={
+ ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size')),
+ ModelPropertyKey.MAX_CHUNKS: 1,
+ },
+ parameter_rules=[],
+ pricing=PriceConfig(
+ input=Decimal(credentials.get('input_price', 0)),
+ unit=Decimal(credentials.get('unit', 0)),
+ currency=credentials.get('currency', "USD")
+ )
+ )
+
+ return entity
diff --git a/api/core/model_runtime/model_providers/vertex_ai/vertex_ai.py b/api/core/model_runtime/model_providers/vertex_ai/vertex_ai.py
new file mode 100644
index 0000000000..3cbfb088d1
--- /dev/null
+++ b/api/core/model_runtime/model_providers/vertex_ai/vertex_ai.py
@@ -0,0 +1,31 @@
+import logging
+
+from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.model_provider import ModelProvider
+
+logger = logging.getLogger(__name__)
+
+
+class VertexAiProvider(ModelProvider):
+ def validate_provider_credentials(self, credentials: dict) -> None:
+ """
+ Validate provider credentials
+
+ if validate failed, raise exception
+
+ :param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
+ """
+ try:
+ model_instance = self.get_model_instance(ModelType.LLM)
+
+            # use the `gemini-1.0-pro-002` model for validation
+ model_instance.validate_credentials(
+ model='gemini-1.0-pro-002',
+ credentials=credentials
+ )
+ except CredentialsValidateFailedError as ex:
+ raise ex
+ except Exception as ex:
+ logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
+ raise ex
diff --git a/api/core/model_runtime/model_providers/vertex_ai/vertex_ai.yaml b/api/core/model_runtime/model_providers/vertex_ai/vertex_ai.yaml
new file mode 100644
index 0000000000..27a4d03fe2
--- /dev/null
+++ b/api/core/model_runtime/model_providers/vertex_ai/vertex_ai.yaml
@@ -0,0 +1,43 @@
+provider: vertex_ai
+label:
+ en_US: Vertex AI | Google Cloud Platform
+description:
+ en_US: Vertex AI in Google Cloud Platform.
+icon_small:
+ en_US: icon_s_en.svg
+icon_large:
+ en_US: icon_l_en.png
+background: "#FCFDFF"
+help:
+ title:
+ en_US: Get your Access Details from Google
+ url:
+ en_US: https://cloud.google.com/vertex-ai/
+supported_model_types:
+ - llm
+ - text-embedding
+configurate_methods:
+ - predefined-model
+provider_credential_schema:
+ credential_form_schemas:
+ - variable: vertex_project_id
+ label:
+ en_US: Project ID
+ type: text-input
+ required: true
+ placeholder:
+ en_US: Enter your Google Cloud Project ID
+ - variable: vertex_location
+ label:
+ en_US: Location
+ type: text-input
+ required: true
+ placeholder:
+ en_US: Enter your Google Cloud Location
+ - variable: vertex_service_account_key
+ label:
+ en_US: Service Account Key (Leave blank if you use Application Default Credentials)
+ type: secret-input
+ required: false
+ placeholder:
+ en_US: Enter your Google Cloud Service Account Key in base64 format
diff --git a/api/core/model_runtime/model_providers/volcengine_maas/__init__.py b/api/core/model_runtime/model_providers/volcengine_maas/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/model_runtime/model_providers/volcengine_maas/_assets/icon_l_en.svg b/api/core/model_runtime/model_providers/volcengine_maas/_assets/icon_l_en.svg
new file mode 100644
index 0000000000..616e90916b
--- /dev/null
+++ b/api/core/model_runtime/model_providers/volcengine_maas/_assets/icon_l_en.svg
@@ -0,0 +1,23 @@
+
+
\ No newline at end of file
diff --git a/api/core/model_runtime/model_providers/volcengine_maas/_assets/icon_l_zh.svg b/api/core/model_runtime/model_providers/volcengine_maas/_assets/icon_l_zh.svg
new file mode 100644
index 0000000000..24b92195bd
--- /dev/null
+++ b/api/core/model_runtime/model_providers/volcengine_maas/_assets/icon_l_zh.svg
@@ -0,0 +1,39 @@
+
+
\ No newline at end of file
diff --git a/api/core/model_runtime/model_providers/volcengine_maas/_assets/icon_s_en.svg b/api/core/model_runtime/model_providers/volcengine_maas/_assets/icon_s_en.svg
new file mode 100644
index 0000000000..e6454a89b7
--- /dev/null
+++ b/api/core/model_runtime/model_providers/volcengine_maas/_assets/icon_s_en.svg
@@ -0,0 +1,8 @@
+
+
\ No newline at end of file
diff --git a/api/core/model_runtime/model_providers/volcengine_maas/client.py b/api/core/model_runtime/model_providers/volcengine_maas/client.py
new file mode 100644
index 0000000000..c7bf4fde8c
--- /dev/null
+++ b/api/core/model_runtime/model_providers/volcengine_maas/client.py
@@ -0,0 +1,108 @@
+import re
+from collections.abc import Callable, Generator
+from typing import cast
+
+from core.model_runtime.entities.message_entities import (
+ AssistantPromptMessage,
+ ImagePromptMessageContent,
+ PromptMessage,
+ PromptMessageContentType,
+ SystemPromptMessage,
+ UserPromptMessage,
+)
+from core.model_runtime.model_providers.volcengine_maas.errors import wrap_error
+from core.model_runtime.model_providers.volcengine_maas.volc_sdk import ChatRole, MaasException, MaasService
+
+
+class MaaSClient(MaasService):
+ def __init__(self, host: str, region: str):
+ self.endpoint_id = None
+ super().__init__(host, region)
+
+ def set_endpoint_id(self, endpoint_id: str):
+ self.endpoint_id = endpoint_id
+
+ @classmethod
+ def from_credential(cls, credentials: dict) -> 'MaaSClient':
+ host = credentials['api_endpoint_host']
+ region = credentials['volc_region']
+ ak = credentials['volc_access_key_id']
+ sk = credentials['volc_secret_access_key']
+ endpoint_id = credentials['endpoint_id']
+
+ client = cls(host, region)
+ client.set_endpoint_id(endpoint_id)
+ client.set_ak(ak)
+ client.set_sk(sk)
+ return client
+
+ def chat(self, params: dict, messages: list[PromptMessage], stream=False) -> Generator | dict:
+ req = {
+ 'parameters': params,
+ 'messages': [self.convert_prompt_message_to_maas_message(prompt) for prompt in messages]
+ }
+ if not stream:
+ return super().chat(
+ self.endpoint_id,
+ req,
+ )
+ return super().stream_chat(
+ self.endpoint_id,
+ req,
+ )
+
+ def embeddings(self, texts: list[str]) -> dict:
+ req = {
+ 'input': texts
+ }
+ return super().embeddings(self.endpoint_id, req)
+
+ @staticmethod
+ def convert_prompt_message_to_maas_message(message: PromptMessage) -> dict:
+ if isinstance(message, UserPromptMessage):
+ message = cast(UserPromptMessage, message)
+ if isinstance(message.content, str):
+ message_dict = {"role": ChatRole.USER,
+ "content": message.content}
+ else:
+ content = []
+ for message_content in message.content:
+ if message_content.type == PromptMessageContentType.TEXT:
+                        raise ValueError(
+                            'Content object type only supports image_url')
+ elif message_content.type == PromptMessageContentType.IMAGE:
+ message_content = cast(
+ ImagePromptMessageContent, message_content)
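+                    # strip any data-URL prefix; the MaaS API expects the raw base64 payload in image_bytes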
+ image_data = re.sub(
+ r'^data:image\/[a-zA-Z]+;base64,', '', message_content.data)
+ content.append({
+ 'type': 'image_url',
+ 'image_url': {
+ 'url': '',
+ 'image_bytes': image_data,
+ 'detail': message_content.detail,
+ }
+ })
+
+ message_dict = {'role': ChatRole.USER, 'content': content}
+ elif isinstance(message, AssistantPromptMessage):
+ message = cast(AssistantPromptMessage, message)
+ message_dict = {'role': ChatRole.ASSISTANT,
+ 'content': message.content}
+ elif isinstance(message, SystemPromptMessage):
+ message = cast(SystemPromptMessage, message)
+ message_dict = {'role': ChatRole.SYSTEM,
+ 'content': message.content}
+ else:
+ raise ValueError(f"Got unknown PromptMessage type {message}")
+
+ return message_dict
+
+ @staticmethod
+ def wrap_exception(fn: Callable[[], dict | Generator]) -> dict | Generator:
+ try:
+ resp = fn()
+ except MaasException as e:
+ raise wrap_error(e)
+
+ return resp
diff --git a/api/core/model_runtime/model_providers/volcengine_maas/errors.py b/api/core/model_runtime/model_providers/volcengine_maas/errors.py
new file mode 100644
index 0000000000..63397a456e
--- /dev/null
+++ b/api/core/model_runtime/model_providers/volcengine_maas/errors.py
@@ -0,0 +1,156 @@
+from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasException
+
+
+class ClientSDKRequestError(MaasException):
+ pass
+
+
+class SignatureDoesNotMatch(MaasException):
+ pass
+
+
+class RequestTimeout(MaasException):
+ pass
+
+
+class ServiceConnectionTimeout(MaasException):
+ pass
+
+
+class MissingAuthenticationHeader(MaasException):
+ pass
+
+
+class AuthenticationHeaderIsInvalid(MaasException):
+ pass
+
+
+class InternalServiceError(MaasException):
+ pass
+
+
+class MissingParameter(MaasException):
+ pass
+
+
+class InvalidParameter(MaasException):
+ pass
+
+
+class AuthenticationExpire(MaasException):
+ pass
+
+
+class EndpointIsInvalid(MaasException):
+ pass
+
+
+class EndpointIsNotEnable(MaasException):
+ pass
+
+
+class ModelNotSupportStreamMode(MaasException):
+ pass
+
+
+class ReqTextExistRisk(MaasException):
+ pass
+
+
+class RespTextExistRisk(MaasException):
+ pass
+
+
+class EndpointRateLimitExceeded(MaasException):
+ pass
+
+
+class ServiceConnectionRefused(MaasException):
+ pass
+
+
+class ServiceConnectionClosed(MaasException):
+ pass
+
+
+class UnauthorizedUserForEndpoint(MaasException):
+ pass
+
+
+class InvalidEndpointWithNoURL(MaasException):
+ pass
+
+
+class EndpointAccountRpmRateLimitExceeded(MaasException):
+ pass
+
+
+class EndpointAccountTpmRateLimitExceeded(MaasException):
+ pass
+
+
+class ServiceResourceWaitQueueFull(MaasException):
+ pass
+
+
+class EndpointIsPending(MaasException):
+ pass
+
+
+class ServiceNotOpen(MaasException):
+ pass
+
+
+AuthErrors = {
+ 'SignatureDoesNotMatch': SignatureDoesNotMatch,
+ 'MissingAuthenticationHeader': MissingAuthenticationHeader,
+ 'AuthenticationHeaderIsInvalid': AuthenticationHeaderIsInvalid,
+ 'AuthenticationExpire': AuthenticationExpire,
+ 'UnauthorizedUserForEndpoint': UnauthorizedUserForEndpoint,
+}
+
+BadRequestErrors = {
+ 'MissingParameter': MissingParameter,
+ 'InvalidParameter': InvalidParameter,
+ 'EndpointIsInvalid': EndpointIsInvalid,
+ 'EndpointIsNotEnable': EndpointIsNotEnable,
+ 'ModelNotSupportStreamMode': ModelNotSupportStreamMode,
+ 'ReqTextExistRisk': ReqTextExistRisk,
+ 'RespTextExistRisk': RespTextExistRisk,
+ 'InvalidEndpointWithNoURL': InvalidEndpointWithNoURL,
+ 'ServiceNotOpen': ServiceNotOpen,
+}
+
+RateLimitErrors = {
+ 'EndpointRateLimitExceeded': EndpointRateLimitExceeded,
+ 'EndpointAccountRpmRateLimitExceeded': EndpointAccountRpmRateLimitExceeded,
+ 'EndpointAccountTpmRateLimitExceeded': EndpointAccountTpmRateLimitExceeded,
+}
+
+ServerUnavailableErrors = {
+ 'InternalServiceError': InternalServiceError,
+ 'EndpointIsPending': EndpointIsPending,
+ 'ServiceResourceWaitQueueFull': ServiceResourceWaitQueueFull,
+}
+
+ConnectionErrors = {
+ 'ClientSDKRequestError': ClientSDKRequestError,
+ 'RequestTimeout': RequestTimeout,
+ 'ServiceConnectionTimeout': ServiceConnectionTimeout,
+ 'ServiceConnectionRefused': ServiceConnectionRefused,
+ 'ServiceConnectionClosed': ServiceConnectionClosed,
+}
+
+ErrorCodeMap = {
+ **AuthErrors,
+ **BadRequestErrors,
+ **RateLimitErrors,
+ **ServerUnavailableErrors,
+ **ConnectionErrors,
+}
+
+
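+# map a generic MaasException to the specific subclass for its error code, so it can be routed
+# through the provider's invoke error mapping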
+def wrap_error(e: MaasException) -> Exception:
+ if ErrorCodeMap.get(e.code):
+ return ErrorCodeMap.get(e.code)(e.code_n, e.code, e.message, e.req_id)
+ return e
diff --git a/api/core/model_runtime/model_providers/volcengine_maas/llm/__init__.py b/api/core/model_runtime/model_providers/volcengine_maas/llm/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py b/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py
new file mode 100644
index 0000000000..7a36d019e2
--- /dev/null
+++ b/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py
@@ -0,0 +1,284 @@
+import logging
+from collections.abc import Generator
+
+from core.model_runtime.entities.common_entities import I18nObject
+from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
+from core.model_runtime.entities.message_entities import (
+ AssistantPromptMessage,
+ PromptMessage,
+ PromptMessageTool,
+ UserPromptMessage,
+)
+from core.model_runtime.entities.model_entities import (
+ AIModelEntity,
+ FetchFrom,
+ ModelPropertyKey,
+ ModelType,
+ ParameterRule,
+ ParameterType,
+)
+from core.model_runtime.errors.invoke import (
+ InvokeAuthorizationError,
+ InvokeBadRequestError,
+ InvokeConnectionError,
+ InvokeError,
+ InvokeRateLimitError,
+ InvokeServerUnavailableError,
+)
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from core.model_runtime.model_providers.volcengine_maas.client import MaaSClient
+from core.model_runtime.model_providers.volcengine_maas.errors import (
+ AuthErrors,
+ BadRequestErrors,
+ ConnectionErrors,
+ RateLimitErrors,
+ ServerUnavailableErrors,
+)
+from core.model_runtime.model_providers.volcengine_maas.llm.models import ModelConfigs
+from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasException
+
+logger = logging.getLogger(__name__)
+
+
+class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
+ def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+ model_parameters: dict, tools: list[PromptMessageTool] | None = None,
+ stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
+ -> LLMResult | Generator:
+ return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
+
+ def validate_credentials(self, model: str, credentials: dict) -> None:
+ """
+ Validate credentials
+ """
+ # ping
+ client = MaaSClient.from_credential(credentials)
+ try:
+ client.chat(
+ {
+ 'max_new_tokens': 16,
+ 'temperature': 0.7,
+ 'top_p': 0.9,
+ 'top_k': 15,
+ },
+ [UserPromptMessage(content='ping\nAnswer: ')],
+ )
+ except MaasException as e:
+ raise CredentialsValidateFailedError(e.message)
+
+ def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+ tools: list[PromptMessageTool] | None = None) -> int:
+ if len(prompt_messages) == 0:
+ return 0
+ return self._num_tokens_from_messages(prompt_messages)
+
+ def _num_tokens_from_messages(self, messages: list[PromptMessage]) -> int:
+ """
+ Calculate num tokens.
+
+ :param messages: messages
+ """
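+        # approximation: token counts are computed with the shared GPT-2 tokenizer rather than a provider-specific one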
+ num_tokens = 0
+ messages_dict = [
+ MaaSClient.convert_prompt_message_to_maas_message(m) for m in messages]
+ for message in messages_dict:
+ for key, value in message.items():
+ num_tokens += self._get_num_tokens_by_gpt2(str(key))
+ num_tokens += self._get_num_tokens_by_gpt2(str(value))
+
+ return num_tokens
+
+ def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+ model_parameters: dict, tools: list[PromptMessageTool] | None = None,
+ stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
+ -> LLMResult | Generator:
+
+ client = MaaSClient.from_credential(credentials)
+
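+        # build request parameters: start from the base model's defaults, then let credentials and
+        # per-call model_parameters override them, in that order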
+ req_params = ModelConfigs.get(
+ credentials['base_model_name'], {}).get('req_params', {}).copy()
+ if credentials.get('context_size'):
+ req_params['max_prompt_tokens'] = credentials.get('context_size')
+ if credentials.get('max_tokens'):
+ req_params['max_new_tokens'] = credentials.get('max_tokens')
+ if model_parameters.get('max_tokens'):
+ req_params['max_new_tokens'] = model_parameters.get('max_tokens')
+ if model_parameters.get('temperature'):
+ req_params['temperature'] = model_parameters.get('temperature')
+ if model_parameters.get('top_p'):
+ req_params['top_p'] = model_parameters.get('top_p')
+ if model_parameters.get('top_k'):
+ req_params['top_k'] = model_parameters.get('top_k')
+ if model_parameters.get('presence_penalty'):
+ req_params['presence_penalty'] = model_parameters.get(
+ 'presence_penalty')
+ if model_parameters.get('frequency_penalty'):
+ req_params['frequency_penalty'] = model_parameters.get(
+ 'frequency_penalty')
+ if stop:
+ req_params['stop'] = stop
+
+ resp = MaaSClient.wrap_exception(
+ lambda: client.chat(req_params, prompt_messages, stream))
+ if not stream:
+ return self._handle_chat_response(model, credentials, prompt_messages, resp)
+ return self._handle_stream_chat_response(model, credentials, prompt_messages, resp)
+
+ def _handle_stream_chat_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], resp: Generator) -> Generator:
+ for index, r in enumerate(resp):
+ choices = r['choices']
+ if not choices:
+ continue
+ choice = choices[0]
+ message = choice['message']
+ usage = None
+ if r.get('usage'):
+ usage = self._calc_usage(model, credentials, r['usage'])
+ yield LLMResultChunk(
+ model=model,
+ prompt_messages=prompt_messages,
+ delta=LLMResultChunkDelta(
+ index=index,
+ message=AssistantPromptMessage(
+ content=message['content'] if message['content'] else '',
+ tool_calls=[]
+ ),
+ usage=usage,
+ finish_reason=choice.get('finish_reason'),
+ ),
+ )
+
+ def _handle_chat_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], resp: dict) -> LLMResult:
+ choices = resp['choices']
+ if not choices:
+ return
+ choice = choices[0]
+ message = choice['message']
+
+ return LLMResult(
+ model=model,
+ prompt_messages=prompt_messages,
+ message=AssistantPromptMessage(
+ content=message['content'] if message['content'] else '',
+ tool_calls=[],
+ ),
+ usage=self._calc_usage(model, credentials, resp['usage']),
+ )
+
+ def _calc_usage(self, model: str, credentials: dict, usage: dict) -> LLMUsage:
+ return self._calc_response_usage(model=model, credentials=credentials,
+ prompt_tokens=usage['prompt_tokens'],
+ completion_tokens=usage['completion_tokens']
+ )
+
+ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
+ """
+        used to define the customizable model schema
+ """
+ max_tokens = ModelConfigs.get(
+ credentials['base_model_name'], {}).get('req_params', {}).get('max_new_tokens')
+ if credentials.get('max_tokens'):
+ max_tokens = int(credentials.get('max_tokens'))
+ rules = [
+ ParameterRule(
+ name='temperature',
+ type=ParameterType.FLOAT,
+ use_template='temperature',
+ label=I18nObject(
+ zh_Hans='温度',
+ en_US='Temperature'
+ )
+ ),
+ ParameterRule(
+ name='top_p',
+ type=ParameterType.FLOAT,
+ use_template='top_p',
+ label=I18nObject(
+ zh_Hans='Top P',
+ en_US='Top P'
+ )
+ ),
+ ParameterRule(
+ name='top_k',
+ type=ParameterType.INT,
+ min=1,
+ default=1,
+ label=I18nObject(
+ zh_Hans='Top K',
+ en_US='Top K'
+ )
+ ),
+ ParameterRule(
+ name='presence_penalty',
+ type=ParameterType.FLOAT,
+ use_template='presence_penalty',
+            label=I18nObject(
+                en_US='Presence Penalty',
+                zh_Hans='存在惩罚',
+            ),
+ min=-2.0,
+ max=2.0,
+ ),
+ ParameterRule(
+ name='frequency_penalty',
+ type=ParameterType.FLOAT,
+ use_template='frequency_penalty',
+            label=I18nObject(
+                en_US='Frequency Penalty',
+                zh_Hans='频率惩罚',
+            ),
+ min=-2.0,
+ max=2.0,
+ ),
+ ParameterRule(
+ name='max_tokens',
+ type=ParameterType.INT,
+ use_template='max_tokens',
+ min=1,
+ max=max_tokens,
+ default=512,
+ label=I18nObject(
+ zh_Hans='最大生成长度',
+ en_US='Max Tokens'
+ )
+ ),
+ ]
+
+ model_properties = ModelConfigs.get(
+ credentials['base_model_name'], {}).get('model_properties', {}).copy()
+ if credentials.get('mode'):
+ model_properties[ModelPropertyKey.MODE] = credentials.get('mode')
+ if credentials.get('context_size'):
+ model_properties[ModelPropertyKey.CONTEXT_SIZE] = int(
+ credentials.get('context_size', 4096))
+ entity = AIModelEntity(
+ model=model,
+ label=I18nObject(
+ en_US=model
+ ),
+ fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+ model_type=ModelType.LLM,
+ model_properties=model_properties,
+ parameter_rules=rules
+ )
+
+ return entity
+
+ @property
+ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+ """
+ Map model invoke error to unified error
+ The key is the error type thrown to the caller
+ The value is the error type thrown by the model,
+ which needs to be converted into a unified error type for the caller.
+
+ :return: Invoke error mapping
+ """
+ return {
+ InvokeConnectionError: ConnectionErrors.values(),
+ InvokeServerUnavailableError: ServerUnavailableErrors.values(),
+ InvokeRateLimitError: RateLimitErrors.values(),
+ InvokeAuthorizationError: AuthErrors.values(),
+ InvokeBadRequestError: BadRequestErrors.values(),
+ }
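
The mapping above is consumed by the runtime's base model class to translate SDK exceptions into the unified `InvokeError` hierarchy. A minimal sketch of how such a mapping can be applied; `to_unified_error` and `call_with_mapping` are illustrative names, not part of this diff or the Dify runtime:

```python
def to_unified_error(error_mapping: dict, exc: Exception) -> Exception:
    # Find the first unified error type whose SDK error list matches.
    for unified_type, sdk_types in error_mapping.items():
        if isinstance(exc, tuple(sdk_types)):
            return unified_type(str(exc))  # wrap in the unified error type
    return exc  # unknown errors pass through unchanged

def call_with_mapping(error_mapping: dict, fn):
    try:
        return fn()
    except Exception as e:
        raise to_unified_error(error_mapping, e) from e
```
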
diff --git a/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py b/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py
new file mode 100644
index 0000000000..2e8ff314fc
--- /dev/null
+++ b/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py
@@ -0,0 +1,72 @@
+ModelConfigs = {
+ 'Doubao-pro-4k': {
+ 'req_params': {
+ 'max_prompt_tokens': 4096,
+ 'max_new_tokens': 4096,
+ },
+ 'model_properties': {
+ 'context_size': 4096,
+ 'mode': 'chat',
+ }
+ },
+ 'Doubao-lite-4k': {
+ 'req_params': {
+ 'max_prompt_tokens': 4096,
+ 'max_new_tokens': 4096,
+ },
+ 'model_properties': {
+ 'context_size': 4096,
+ 'mode': 'chat',
+ }
+ },
+ 'Doubao-pro-32k': {
+ 'req_params': {
+ 'max_prompt_tokens': 32768,
+ 'max_new_tokens': 32768,
+ },
+ 'model_properties': {
+ 'context_size': 32768,
+ 'mode': 'chat',
+ }
+ },
+ 'Doubao-lite-32k': {
+ 'req_params': {
+ 'max_prompt_tokens': 32768,
+ 'max_new_tokens': 32768,
+ },
+ 'model_properties': {
+ 'context_size': 32768,
+ 'mode': 'chat',
+ }
+ },
+ 'Doubao-pro-128k': {
+ 'req_params': {
+ 'max_prompt_tokens': 131072,
+ 'max_new_tokens': 131072,
+ },
+ 'model_properties': {
+ 'context_size': 131072,
+ 'mode': 'chat',
+ }
+ },
+ 'Doubao-lite-128k': {
+ 'req_params': {
+ 'max_prompt_tokens': 131072,
+ 'max_new_tokens': 131072,
+ },
+ 'model_properties': {
+ 'context_size': 131072,
+ 'mode': 'chat',
+ }
+ },
+ 'Skylark2-pro-4k': {
+ 'req_params': {
+ 'max_prompt_tokens': 4096,
+ 'max_new_tokens': 4000,
+ },
+ 'model_properties': {
+ 'context_size': 4096,
+ 'mode': 'chat',
+ }
+ },
+}
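
Each entry above pairs a base model with request-level defaults (`req_params`) and runtime properties (`model_properties`). The LLM class reads it with a `.get(..., {})` plus `.copy()` chain, so unknown base models degrade to empty defaults and the shared dict is never mutated. A small self-contained sketch of that pattern; the `build_req_params` helper is illustrative:

```python
ModelConfigs = {
    'Doubao-pro-32k': {
        'req_params': {'max_prompt_tokens': 32768, 'max_new_tokens': 32768},
        'model_properties': {'context_size': 32768, 'mode': 'chat'},
    },
}

def build_req_params(base_model_name: str, max_tokens: int | None = None) -> dict:
    # .get(..., {}) tolerates unknown models; .copy() protects the shared dict
    req_params = ModelConfigs.get(base_model_name, {}).get('req_params', {}).copy()
    if max_tokens:
        req_params['max_new_tokens'] = max_tokens  # per-call override
    return req_params

print(build_req_params('Doubao-pro-32k', max_tokens=1024))
# {'max_prompt_tokens': 32768, 'max_new_tokens': 1024}
```
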
diff --git a/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/__init__.py b/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/models.py b/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/models.py
new file mode 100644
index 0000000000..569f89e975
--- /dev/null
+++ b/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/models.py
@@ -0,0 +1,9 @@
+ModelConfigs = {
+ 'Doubao-embedding': {
+ 'req_params': {},
+ 'model_properties': {
+ 'context_size': 4096,
+ 'max_chunks': 1,
+ }
+ },
+}
diff --git a/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py
new file mode 100644
index 0000000000..10b01c0d0d
--- /dev/null
+++ b/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py
@@ -0,0 +1,170 @@
+import time
+from decimal import Decimal
+from typing import Optional
+
+from core.model_runtime.entities.common_entities import I18nObject
+from core.model_runtime.entities.model_entities import (
+ AIModelEntity,
+ FetchFrom,
+ ModelPropertyKey,
+ ModelType,
+ PriceConfig,
+ PriceType,
+)
+from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
+from core.model_runtime.errors.invoke import (
+ InvokeAuthorizationError,
+ InvokeBadRequestError,
+ InvokeConnectionError,
+ InvokeError,
+ InvokeRateLimitError,
+ InvokeServerUnavailableError,
+)
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
+from core.model_runtime.model_providers.volcengine_maas.client import MaaSClient
+from core.model_runtime.model_providers.volcengine_maas.errors import (
+ AuthErrors,
+ BadRequestErrors,
+ ConnectionErrors,
+ RateLimitErrors,
+ ServerUnavailableErrors,
+)
+from core.model_runtime.model_providers.volcengine_maas.text_embedding.models import ModelConfigs
+from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasException
+
+
+class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel):
+ """
+ Model class for VolcengineMaaS text embedding model.
+ """
+
+ def _invoke(self, model: str, credentials: dict,
+ texts: list[str], user: Optional[str] = None) \
+ -> TextEmbeddingResult:
+ """
+ Invoke text embedding model
+
+ :param model: model name
+ :param credentials: model credentials
+ :param texts: texts to embed
+ :param user: unique user id
+ :return: embeddings result
+ """
+ client = MaaSClient.from_credential(credentials)
+ resp = MaaSClient.wrap_exception(lambda: client.embeddings(texts))
+
+ usage = self._calc_response_usage(
+ model=model, credentials=credentials, tokens=resp['usage']['total_tokens'])
+
+ result = TextEmbeddingResult(
+ model=model,
+ embeddings=[v['embedding'] for v in resp['data']],
+ usage=usage
+ )
+
+ return result
+
+ def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
+ """
+ Get number of tokens for given prompt messages
+
+ :param model: model name
+ :param credentials: model credentials
+ :param texts: texts to embed
+ :return:
+ """
+ num_tokens = 0
+ for text in texts:
+ # use GPT2Tokenizer to get num tokens
+ num_tokens += self._get_num_tokens_by_gpt2(text)
+ return num_tokens
+
+ def validate_credentials(self, model: str, credentials: dict) -> None:
+ """
+ Validate model credentials
+
+ :param model: model name
+ :param credentials: model credentials
+ :return:
+ """
+ try:
+ self._invoke(model=model, credentials=credentials, texts=['ping'])
+ except MaasException as e:
+ raise CredentialsValidateFailedError(e.message)
+
+ @property
+ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+ """
+ Map model invoke error to unified error
+ The key is the error type thrown to the caller
+ The value is the error type thrown by the model,
+ which needs to be converted into a unified error type for the caller.
+
+ :return: Invoke error mapping
+ """
+ return {
+ InvokeConnectionError: ConnectionErrors.values(),
+ InvokeServerUnavailableError: ServerUnavailableErrors.values(),
+ InvokeRateLimitError: RateLimitErrors.values(),
+ InvokeAuthorizationError: AuthErrors.values(),
+ InvokeBadRequestError: BadRequestErrors.values(),
+ }
+
+ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
+ """
+        Generate a custom model entity from credentials.
+ """
+ model_properties = ModelConfigs.get(
+ credentials['base_model_name'], {}).get('model_properties', {}).copy()
+ if credentials.get('context_size'):
+ model_properties[ModelPropertyKey.CONTEXT_SIZE] = int(
+ credentials.get('context_size', 4096))
+ if credentials.get('max_chunks'):
+ model_properties[ModelPropertyKey.MAX_CHUNKS] = int(
+ credentials.get('max_chunks', 4096))
+ entity = AIModelEntity(
+ model=model,
+ label=I18nObject(en_US=model),
+ model_type=ModelType.TEXT_EMBEDDING,
+ fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+ model_properties=model_properties,
+ parameter_rules=[],
+ pricing=PriceConfig(
+ input=Decimal(credentials.get('input_price', 0)),
+ unit=Decimal(credentials.get('unit', 0)),
+ currency=credentials.get('currency', "USD")
+ )
+ )
+
+ return entity
+
+ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
+ """
+ Calculate response usage
+
+ :param model: model name
+ :param credentials: model credentials
+ :param tokens: input tokens
+ :return: usage
+ """
+ # get input price info
+ input_price_info = self.get_price(
+ model=model,
+ credentials=credentials,
+ price_type=PriceType.INPUT,
+ tokens=tokens
+ )
+
+ # transform usage
+ usage = EmbeddingUsage(
+ tokens=tokens,
+ total_tokens=tokens,
+ unit_price=input_price_info.unit_price,
+ price_unit=input_price_info.unit,
+ total_price=input_price_info.total_amount,
+ currency=input_price_info.currency,
+ latency=time.perf_counter() - self.started_at
+ )
+
+ return usage
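
`_invoke` above reads exactly two things from the SDK response: one vector per input text under `data[*].embedding`, and the billed token count under `usage.total_tokens`. An assumed response shape, with placeholder values:

```python
# Placeholder response: vector values and the token count are made up.
resp = {
    'data': [
        {'embedding': [0.01, -0.02, 0.03]},
        {'embedding': [0.00, 0.12, -0.07]},
    ],
    'usage': {'total_tokens': 9},
}

embeddings = [v['embedding'] for v in resp['data']]  # one vector per input text
assert len(embeddings) == 2
assert resp['usage']['total_tokens'] == 9
```
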
diff --git a/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/__init__.py b/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/__init__.py
new file mode 100644
index 0000000000..64f342f16e
--- /dev/null
+++ b/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/__init__.py
@@ -0,0 +1,4 @@
+from .common import ChatRole
+from .maas import MaasException, MaasService
+
+__all__ = ['MaasService', 'ChatRole', 'MaasException']
diff --git a/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/__init__.py b/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/__init__.py
new file mode 100644
index 0000000000..8b13789179
--- /dev/null
+++ b/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/__init__.py
@@ -0,0 +1 @@
+
diff --git a/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/auth.py b/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/auth.py
new file mode 100644
index 0000000000..48110f16d7
--- /dev/null
+++ b/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/auth.py
@@ -0,0 +1,144 @@
+# coding: utf-8
+import datetime
+
+import pytz
+
+from .util import Util
+
+
+class MetaData:
+ def __init__(self):
+ self.algorithm = ''
+ self.credential_scope = ''
+ self.signed_headers = ''
+ self.date = ''
+ self.region = ''
+ self.service = ''
+
+ def set_date(self, date):
+ self.date = date
+
+ def set_service(self, service):
+ self.service = service
+
+ def set_region(self, region):
+ self.region = region
+
+ def set_algorithm(self, algorithm):
+ self.algorithm = algorithm
+
+ def set_credential_scope(self, credential_scope):
+ self.credential_scope = credential_scope
+
+ def set_signed_headers(self, signed_headers):
+ self.signed_headers = signed_headers
+
+
+class SignResult:
+ def __init__(self):
+ self.xdate = ''
+ self.xCredential = ''
+ self.xAlgorithm = ''
+ self.xSignedHeaders = ''
+ self.xSignedQueries = ''
+ self.xSignature = ''
+ self.xContextSha256 = ''
+ self.xSecurityToken = ''
+
+ self.authorization = ''
+
+ def __str__(self):
+ return '\n'.join(['{}:{}'.format(*item) for item in self.__dict__.items()])
+
+
+class Credentials:
+ def __init__(self, ak, sk, service, region, session_token=''):
+ self.ak = ak
+ self.sk = sk
+ self.service = service
+ self.region = region
+ self.session_token = session_token
+
+ def set_ak(self, ak):
+ self.ak = ak
+
+ def set_sk(self, sk):
+ self.sk = sk
+
+ def set_session_token(self, session_token):
+ self.session_token = session_token
+
+
+class Signer:
+ @staticmethod
+ def sign(request, credentials):
+ if request.path == '':
+ request.path = '/'
+        if request.method != 'GET' and 'Content-Type' not in request.headers:
+ request.headers['Content-Type'] = 'application/x-www-form-urlencoded; charset=utf-8'
+
+ format_date = Signer.get_current_format_date()
+ request.headers['X-Date'] = format_date
+ if credentials.session_token != '':
+ request.headers['X-Security-Token'] = credentials.session_token
+
+ md = MetaData()
+ md.set_algorithm('HMAC-SHA256')
+ md.set_service(credentials.service)
+ md.set_region(credentials.region)
+ md.set_date(format_date[:8])
+
+ hashed_canon_req = Signer.hashed_canonical_request_v4(request, md)
+ md.set_credential_scope('/'.join([md.date, md.region, md.service, 'request']))
+
+ signing_str = '\n'.join([md.algorithm, format_date, md.credential_scope, hashed_canon_req])
+ signing_key = Signer.get_signing_secret_key_v4(credentials.sk, md.date, md.region, md.service)
+ sign = Util.to_hex(Util.hmac_sha256(signing_key, signing_str))
+ request.headers['Authorization'] = Signer.build_auth_header_v4(sign, md, credentials)
+ return
+
+ @staticmethod
+ def hashed_canonical_request_v4(request, meta):
+ body_hash = Util.sha256(request.body)
+ request.headers['X-Content-Sha256'] = body_hash
+
+ signed_headers = dict()
+ for key in request.headers:
+ if key in ['Content-Type', 'Content-Md5', 'Host'] or key.startswith('X-'):
+ signed_headers[key.lower()] = request.headers[key]
+
+ if 'host' in signed_headers:
+ v = signed_headers['host']
+ if v.find(':') != -1:
+ split = v.split(':')
+ port = split[1]
+ if str(port) == '80' or str(port) == '443':
+ signed_headers['host'] = split[0]
+
+ signed_str = ''
+ for key in sorted(signed_headers.keys()):
+ signed_str += key + ':' + signed_headers[key] + '\n'
+
+ meta.set_signed_headers(';'.join(sorted(signed_headers.keys())))
+
+ canonical_request = '\n'.join(
+ [request.method, Util.norm_uri(request.path), Util.norm_query(request.query), signed_str,
+ meta.signed_headers, body_hash])
+
+ return Util.sha256(canonical_request)
+
+ @staticmethod
+ def get_signing_secret_key_v4(sk, date, region, service):
+ date = Util.hmac_sha256(bytes(sk, encoding='utf-8'), date)
+ region = Util.hmac_sha256(date, region)
+ service = Util.hmac_sha256(region, service)
+ return Util.hmac_sha256(service, 'request')
+
+ @staticmethod
+ def build_auth_header_v4(signature, meta, credentials):
+ credential = credentials.ak + '/' + meta.credential_scope
+ return meta.algorithm + ' Credential=' + credential + ', SignedHeaders=' + meta.signed_headers + ', Signature=' + signature
+
+ @staticmethod
+ def get_current_format_date():
+ return datetime.datetime.now(tz=pytz.timezone('UTC')).strftime("%Y%m%dT%H%M%SZ")
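
`get_signing_secret_key_v4` derives the per-request signing key by chaining HMAC-SHA256 over the short date, region, service name, and the literal string `'request'`. A standalone sketch of that chain, with placeholder inputs:

```python
import hashlib
import hmac

def hmac_sha256(key: bytes, content: str) -> bytes:
    return hmac.new(key, content.encode('utf-8'), hashlib.sha256).digest()

def signing_key(sk: str, date: str, region: str, service: str) -> bytes:
    k_date = hmac_sha256(sk.encode('utf-8'), date)  # step 1: secret key + date
    k_region = hmac_sha256(k_date, region)          # step 2: + region
    k_service = hmac_sha256(k_region, service)      # step 3: + service
    return hmac_sha256(k_service, 'request')        # step 4: fixed suffix

# Placeholder values; 'ml_maas' matches the service name used by MaasService.
print(signing_key('example-sk', '20240601', 'cn-beijing', 'ml_maas').hex())
```
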
diff --git a/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/service.py b/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/service.py
new file mode 100644
index 0000000000..03734ec54f
--- /dev/null
+++ b/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/service.py
@@ -0,0 +1,207 @@
+import json
+from collections import OrderedDict
+from urllib.parse import urlencode
+
+import requests
+
+from .auth import Signer
+
+VERSION = 'v1.0.137'
+
+
+class Service:
+ def __init__(self, service_info, api_info):
+ self.service_info = service_info
+ self.api_info = api_info
+ self.session = requests.session()
+
+ def set_ak(self, ak):
+ self.service_info.credentials.set_ak(ak)
+
+ def set_sk(self, sk):
+ self.service_info.credentials.set_sk(sk)
+
+ def set_session_token(self, session_token):
+ self.service_info.credentials.set_session_token(session_token)
+
+ def set_host(self, host):
+ self.service_info.host = host
+
+ def set_scheme(self, scheme):
+ self.service_info.scheme = scheme
+
+ def get(self, api, params, doseq=0):
+        if api not in self.api_info:
+ raise Exception("no such api")
+ api_info = self.api_info[api]
+
+ r = self.prepare_request(api_info, params, doseq)
+
+ Signer.sign(r, self.service_info.credentials)
+
+ url = r.build(doseq)
+ resp = self.session.get(url, headers=r.headers,
+ timeout=(self.service_info.connection_timeout, self.service_info.socket_timeout))
+ if resp.status_code == 200:
+ return resp.text
+ else:
+ raise Exception(resp.text)
+
+ def post(self, api, params, form):
+        if api not in self.api_info:
+ raise Exception("no such api")
+ api_info = self.api_info[api]
+ r = self.prepare_request(api_info, params)
+ r.headers['Content-Type'] = 'application/x-www-form-urlencoded'
+ r.form = self.merge(api_info.form, form)
+ r.body = urlencode(r.form, True)
+ Signer.sign(r, self.service_info.credentials)
+
+ url = r.build()
+
+ resp = self.session.post(url, headers=r.headers, data=r.form,
+ timeout=(self.service_info.connection_timeout, self.service_info.socket_timeout))
+ if resp.status_code == 200:
+ return resp.text
+ else:
+ raise Exception(resp.text)
+
+ def json(self, api, params, body):
+        if api not in self.api_info:
+ raise Exception("no such api")
+ api_info = self.api_info[api]
+ r = self.prepare_request(api_info, params)
+ r.headers['Content-Type'] = 'application/json'
+ r.body = body
+
+ Signer.sign(r, self.service_info.credentials)
+
+ url = r.build()
+ resp = self.session.post(url, headers=r.headers, data=r.body,
+ timeout=(self.service_info.connection_timeout, self.service_info.socket_timeout))
+ if resp.status_code == 200:
+ return json.dumps(resp.json())
+ else:
+ raise Exception(resp.text.encode("utf-8"))
+
+ def put(self, url, file_path, headers):
+ with open(file_path, 'rb') as f:
+ resp = self.session.put(url, headers=headers, data=f)
+ if resp.status_code == 200:
+ return True, resp.text.encode("utf-8")
+ else:
+ return False, resp.text.encode("utf-8")
+
+ def put_data(self, url, data, headers):
+ resp = self.session.put(url, headers=headers, data=data)
+ if resp.status_code == 200:
+ return True, resp.text.encode("utf-8")
+ else:
+ return False, resp.text.encode("utf-8")
+
+ def prepare_request(self, api_info, params, doseq=0):
+ for key in params:
+            if isinstance(params[key], (int, float, bool)):
+                params[key] = str(params[key])
+            elif isinstance(params[key], list):
+ if not doseq:
+ params[key] = ','.join(params[key])
+
+ connection_timeout = self.service_info.connection_timeout
+ socket_timeout = self.service_info.socket_timeout
+
+ r = Request()
+ r.set_schema(self.service_info.scheme)
+ r.set_method(api_info.method)
+ r.set_connection_timeout(connection_timeout)
+ r.set_socket_timeout(socket_timeout)
+
+ headers = self.merge(api_info.header, self.service_info.header)
+ headers['Host'] = self.service_info.host
+ headers['User-Agent'] = 'volc-sdk-python/' + VERSION
+ r.set_headers(headers)
+
+ query = self.merge(api_info.query, params)
+ r.set_query(query)
+
+ r.set_host(self.service_info.host)
+ r.set_path(api_info.path)
+
+ return r
+
+ @staticmethod
+ def merge(param1, param2):
+ od = OrderedDict()
+ for key in param1:
+ od[key] = param1[key]
+
+ for key in param2:
+ od[key] = param2[key]
+
+ return od
+
+
+class Request:
+ def __init__(self):
+ self.schema = ''
+ self.method = ''
+ self.host = ''
+ self.path = ''
+ self.headers = OrderedDict()
+ self.query = OrderedDict()
+ self.body = ''
+ self.form = dict()
+ self.connection_timeout = 0
+ self.socket_timeout = 0
+
+ def set_schema(self, schema):
+ self.schema = schema
+
+ def set_method(self, method):
+ self.method = method
+
+ def set_host(self, host):
+ self.host = host
+
+ def set_path(self, path):
+ self.path = path
+
+ def set_headers(self, headers):
+ self.headers = headers
+
+ def set_query(self, query):
+ self.query = query
+
+ def set_body(self, body):
+ self.body = body
+
+ def set_connection_timeout(self, connection_timeout):
+ self.connection_timeout = connection_timeout
+
+ def set_socket_timeout(self, socket_timeout):
+ self.socket_timeout = socket_timeout
+
+ def build(self, doseq=0):
+ return self.schema + '://' + self.host + self.path + '?' + urlencode(self.query, doseq)
+
+
+class ServiceInfo:
+ def __init__(self, host, header, credentials, connection_timeout, socket_timeout, scheme='http'):
+ self.host = host
+ self.header = header
+ self.credentials = credentials
+ self.connection_timeout = connection_timeout
+ self.socket_timeout = socket_timeout
+ self.scheme = scheme
+
+
+class ApiInfo:
+ def __init__(self, method, path, query, form, header):
+ self.method = method
+ self.path = path
+ self.query = query
+ self.form = form
+ self.header = header
+
+ def __str__(self):
+ return 'method: ' + self.method + ', path: ' + self.path
diff --git a/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/util.py b/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/util.py
new file mode 100644
index 0000000000..7eb5fdfa91
--- /dev/null
+++ b/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/util.py
@@ -0,0 +1,43 @@
+import hashlib
+import hmac
+from functools import reduce
+from urllib.parse import quote
+
+
+class Util:
+ @staticmethod
+ def norm_uri(path):
+ return quote(path).replace('%2F', '/').replace('+', '%20')
+
+ @staticmethod
+ def norm_query(params):
+ query = ''
+ for key in sorted(params.keys()):
+            if isinstance(params[key], list):
+ for k in params[key]:
+ query = query + quote(key, safe='-_.~') + '=' + quote(k, safe='-_.~') + '&'
+ else:
+ query = query + quote(key, safe='-_.~') + '=' + quote(params[key], safe='-_.~') + '&'
+ query = query[:-1]
+ return query.replace('+', '%20')
+
+ @staticmethod
+ def hmac_sha256(key, content):
+ return hmac.new(key, bytes(content, encoding='utf-8'), hashlib.sha256).digest()
+
+ @staticmethod
+ def sha256(content):
+        if isinstance(content, str):
+ return hashlib.sha256(content.encode('utf-8')).hexdigest()
+ else:
+ return hashlib.sha256(content).hexdigest()
+
+ @staticmethod
+ def to_hex(content):
+ lst = []
+ for ch in content:
+ hv = hex(ch).replace('0x', '')
+ if len(hv) == 1:
+ hv = '0' + hv
+ lst.append(hv)
+ return reduce(lambda x, y: x + y, lst)
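
`Util.to_hex` reproduces `bytes.hex()` one byte at a time, zero-padding each value to two digits. A quick equivalence check; the import path assumes the package layout introduced by this diff:

```python
import hashlib

from core.model_runtime.model_providers.volcengine_maas.volc_sdk.base.util import Util

digest = hashlib.sha256(b'abc').digest()
assert Util.to_hex(digest) == digest.hex()  # identical hex encodings
```
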
diff --git a/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/common.py b/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/common.py
new file mode 100644
index 0000000000..8b14d026d9
--- /dev/null
+++ b/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/common.py
@@ -0,0 +1,79 @@
+import json
+import random
+from datetime import datetime
+
+
+class ChatRole:
+ USER = "user"
+ ASSISTANT = "assistant"
+ SYSTEM = "system"
+ FUNCTION = "function"
+
+
+class _Dict(dict):
+ __setattr__ = dict.__setitem__
+ __getattr__ = dict.__getitem__
+
+ def __missing__(self, key):
+ return None
+
+
+def dict_to_object(dict_obj):
+    # supports nested types
+ if isinstance(dict_obj, list):
+ insts = []
+ for i in dict_obj:
+ insts.append(dict_to_object(i))
+ return insts
+
+ if isinstance(dict_obj, dict):
+ inst = _Dict()
+ for k, v in dict_obj.items():
+ inst[k] = dict_to_object(v)
+ return inst
+
+ return dict_obj
+
+
+def json_to_object(json_str, req_id=None):
+ obj = dict_to_object(json.loads(json_str))
+ if obj and isinstance(obj, dict) and req_id:
+ obj["req_id"] = req_id
+ return obj
+
+
+def gen_req_id():
+ return datetime.now().strftime("%Y%m%d%H%M%S") + format(
+ random.randint(0, 2 ** 64 - 1), "020X"
+ )
+
+
+class SSEDecoder:
+ def __init__(self, source):
+ self.source = source
+
+ def _read(self):
+ data = b''
+ for chunk in self.source:
+ for line in chunk.splitlines(True):
+ data += line
+ if data.endswith((b'\r\r', b'\n\n', b'\r\n\r\n')):
+ yield data
+ data = b''
+ if data:
+ yield data
+
+ def next(self):
+ for chunk in self._read():
+ for line in chunk.splitlines():
+ # skip comment
+ if line.startswith(b':'):
+ continue
+
+ if b':' in line:
+ field, value = line.split(b':', 1)
+ else:
+ field, value = line, b''
+
+ if field == b'data' and len(value) > 0:
+ yield value
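
`SSEDecoder._read` buffers raw bytes until a blank-line event boundary (`\r\r`, `\n\n`, or `\r\n\r\n`), and `next` then yields the payload of each non-empty `data:` field while skipping `:`-prefixed comment lines. A quick check against an in-memory stream standing in for an HTTP response body:

```python
from core.model_runtime.model_providers.volcengine_maas.volc_sdk.common import SSEDecoder

stream = [
    b'data:{"choice": 1}\n\n',                   # a normal event
    b':keep-alive comment\n\ndata:[DONE]\n\n',   # comment event, then terminator
]
for payload in SSEDecoder(stream).next():
    print(payload)
# b'{"choice": 1}'
# b'[DONE]'
```
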
diff --git a/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/maas.py b/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/maas.py
new file mode 100644
index 0000000000..3cbe9d9f09
--- /dev/null
+++ b/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/maas.py
@@ -0,0 +1,213 @@
+import copy
+import json
+from collections.abc import Iterator
+
+from .base.auth import Credentials, Signer
+from .base.service import ApiInfo, Service, ServiceInfo
+from .common import SSEDecoder, dict_to_object, gen_req_id, json_to_object
+
+
+class MaasService(Service):
+ def __init__(self, host, region, connection_timeout=60, socket_timeout=60):
+ service_info = self.get_service_info(
+ host, region, connection_timeout, socket_timeout
+ )
+ self._apikey = None
+ api_info = self.get_api_info()
+ super().__init__(service_info, api_info)
+
+ def set_apikey(self, apikey):
+ self._apikey = apikey
+
+ @staticmethod
+ def get_service_info(host, region, connection_timeout, socket_timeout):
+ service_info = ServiceInfo(
+ host,
+ {"Accept": "application/json"},
+ Credentials("", "", "ml_maas", region),
+ connection_timeout,
+ socket_timeout,
+ "https",
+ )
+ return service_info
+
+ @staticmethod
+ def get_api_info():
+ api_info = {
+ "chat": ApiInfo("POST", "/api/v2/endpoint/{endpoint_id}/chat", {}, {}, {}),
+ "embeddings": ApiInfo(
+ "POST", "/api/v2/endpoint/{endpoint_id}/embeddings", {}, {}, {}
+ ),
+ }
+ return api_info
+
+ def chat(self, endpoint_id, req):
+ req["stream"] = False
+ return self._request(endpoint_id, "chat", req)
+
+ def stream_chat(self, endpoint_id, req):
+ req_id = gen_req_id()
+ self._validate("chat", req_id)
+ apikey = self._apikey
+
+ try:
+ req["stream"] = True
+ res = self._call(
+ endpoint_id, "chat", req_id, {}, json.dumps(req).encode("utf-8"), apikey, stream=True
+ )
+
+ decoder = SSEDecoder(res)
+
+ def iter_fn():
+ for data in decoder.next():
+ if data == b"[DONE]":
+ return
+
+                    res = json_to_object(
+                        str(data, encoding="utf-8"), req_id=req_id)
+
+ if res.error is not None and res.error.code_n != 0:
+ raise MaasException(
+ res.error.code_n,
+ res.error.code,
+ res.error.message,
+ req_id,
+ )
+ yield res
+
+ return iter_fn()
+ except MaasException:
+ raise
+ except Exception as e:
+ raise new_client_sdk_request_error(str(e))
+
+ def embeddings(self, endpoint_id, req):
+ return self._request(endpoint_id, "embeddings", req)
+
+    def _request(self, endpoint_id, api, req, params=None):
+        params = params if params is not None else {}
+        req_id = gen_req_id()
+
+        self._validate(api, req_id)
+
+ apikey = self._apikey
+
+ try:
+ res = self._call(endpoint_id, api, req_id, params,
+ json.dumps(req).encode("utf-8"), apikey)
+ resp = dict_to_object(res.json())
+ if resp and isinstance(resp, dict):
+ resp["req_id"] = req_id
+ return resp
+
+        except MaasException:
+            raise
+ except Exception as e:
+ raise new_client_sdk_request_error(str(e), req_id)
+
+ def _validate(self, api, req_id):
+ credentials_exist = (
+ self.service_info.credentials is not None and
+ self.service_info.credentials.sk is not None and
+ self.service_info.credentials.ak is not None
+ )
+
+ if not self._apikey and not credentials_exist:
+ raise new_client_sdk_request_error("no valid credential", req_id)
+
+        if api not in self.api_info:
+ raise new_client_sdk_request_error("no such api", req_id)
+
+ def _call(self, endpoint_id, api, req_id, params, body, apikey=None, stream=False):
+ api_info = copy.deepcopy(self.api_info[api])
+ api_info.path = api_info.path.format(endpoint_id=endpoint_id)
+
+ r = self.prepare_request(api_info, params)
+ r.headers["x-tt-logid"] = req_id
+ r.headers["Content-Type"] = "application/json"
+ r.body = body
+
+        if apikey is None:
+            Signer.sign(r, self.service_info.credentials)
+        else:
+            r.headers["Authorization"] = "Bearer " + apikey
+
+ url = r.build()
+ res = self.session.post(
+ url,
+ headers=r.headers,
+ data=r.body,
+ timeout=(
+ self.service_info.connection_timeout,
+ self.service_info.socket_timeout,
+ ),
+ stream=stream,
+ )
+
+ if res.status_code != 200:
+ raw = res.text.encode()
+ res.close()
+ try:
+ resp = json_to_object(
+ str(raw, encoding="utf-8"), req_id=req_id)
+ except Exception:
+ raise new_client_sdk_request_error(raw, req_id)
+
+ if resp.error:
+ raise MaasException(
+ resp.error.code_n, resp.error.code, resp.error.message, req_id
+ )
+ else:
+ raise new_client_sdk_request_error(resp, req_id)
+
+ return res
+
+
+class MaasException(Exception):
+ def __init__(self, code_n, code, message, req_id):
+ self.code_n = code_n
+ self.code = code
+ self.message = message
+ self.req_id = req_id
+
+ def __str__(self):
+ return ("Detailed exception information is listed below.\n" +
+ "req_id: {}\n" +
+ "code_n: {}\n" +
+ "code: {}\n" +
+ "message: {}").format(self.req_id, self.code_n, self.code, self.message)
+
+
+def new_client_sdk_request_error(raw, req_id=""):
+ return MaasException(1709701, "ClientSDKRequestError", "MaaS SDK request error: {}".format(raw), req_id)
+
+
+class BinaryResponseContent:
+ def __init__(self, response, request_id) -> None:
+ self.response = response
+ self.request_id = request_id
+
+ def stream_to_file(
+ self,
+ file: str
+ ) -> None:
+        is_first = True
+        error_bytes = b''
+        with open(file, mode="wb") as f:
+            for data in self.response:
+                # only the first chunk can signal an API error payload
+                if len(error_bytes) > 0 or (is_first and "\"error\":" in str(data)):
+                    error_bytes += data
+                else:
+                    f.write(data)
+                is_first = False
+
+ if len(error_bytes) > 0:
+ resp = json_to_object(
+ str(error_bytes, encoding="utf-8"), req_id=self.request_id)
+ raise MaasException(
+ resp.error.code_n, resp.error.code, resp.error.message, self.request_id
+ )
+
+ def iter_bytes(self) -> Iterator[bytes]:
+ yield from self.response
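
Taken together, a hedged usage sketch of `MaasService`: the host and region mirror the defaults in the provider YAML further below, while the endpoint id, keys, and request body are placeholders whose shape follows the Volcengine MaaS chat API rather than anything defined in this diff:

```python
from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasService

maas = MaasService('maas-api.ml-platform-cn-beijing.volces.com', 'cn-beijing')
maas.set_ak('<access-key>')  # placeholder credentials
maas.set_sk('<secret-key>')

req = {
    'parameters': {'max_new_tokens': 256, 'temperature': 0.7},
    'messages': [{'role': 'user', 'content': 'Hello'}],
}

resp = maas.chat('<endpoint-id>', req)  # blocking call; forces req["stream"] = False
print(resp['choices'][0]['message']['content'])

for chunk in maas.stream_chat('<endpoint-id>', req):  # SSE streaming variant
    print(chunk['choices'][0]['message']['content'], end='')
```
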
diff --git a/api/core/model_runtime/model_providers/volcengine_maas/volcengine_maas.py b/api/core/model_runtime/model_providers/volcengine_maas/volcengine_maas.py
new file mode 100644
index 0000000000..10f9be2d08
--- /dev/null
+++ b/api/core/model_runtime/model_providers/volcengine_maas/volcengine_maas.py
@@ -0,0 +1,10 @@
+import logging
+
+from core.model_runtime.model_providers.__base.model_provider import ModelProvider
+
+logger = logging.getLogger(__name__)
+
+
+class VolcengineMaaSProvider(ModelProvider):
+ def validate_provider_credentials(self, credentials: dict) -> None:
+ pass
diff --git a/api/core/model_runtime/model_providers/volcengine_maas/volcengine_maas.yaml b/api/core/model_runtime/model_providers/volcengine_maas/volcengine_maas.yaml
new file mode 100644
index 0000000000..d7bcbd43f8
--- /dev/null
+++ b/api/core/model_runtime/model_providers/volcengine_maas/volcengine_maas.yaml
@@ -0,0 +1,188 @@
+provider: volcengine_maas
+label:
+ en_US: Volcengine
+description:
+ en_US: Volcengine MaaS models.
+icon_small:
+ en_US: icon_s_en.svg
+icon_large:
+ en_US: icon_l_en.svg
+ zh_Hans: icon_l_zh.svg
+background: "#F9FAFB"
+help:
+ title:
+ en_US: Get your Access Key and Secret Access Key from Volcengine Console
+ url:
+ en_US: https://console.volcengine.com/iam/keymanage/
+supported_model_types:
+ - llm
+ - text-embedding
+configurate_methods:
+ - customizable-model
+model_credential_schema:
+ model:
+ label:
+ en_US: Model Name
+ zh_Hans: 模型名称
+ placeholder:
+ en_US: Enter your Model Name
+ zh_Hans: 输入模型名称
+ credential_form_schemas:
+ - variable: volc_access_key_id
+ required: true
+ label:
+ en_US: Access Key
+ zh_Hans: Access Key
+ type: secret-input
+ placeholder:
+ en_US: Enter your Access Key
+ zh_Hans: 输入您的 Access Key
+ - variable: volc_secret_access_key
+ required: true
+ label:
+ en_US: Secret Access Key
+ zh_Hans: Secret Access Key
+ type: secret-input
+ placeholder:
+ en_US: Enter your Secret Access Key
+ zh_Hans: 输入您的 Secret Access Key
+ - variable: volc_region
+ required: true
+ label:
+ en_US: Volcengine Region
+ zh_Hans: 火山引擎地区
+ type: text-input
+ default: cn-beijing
+ placeholder:
+ en_US: Enter Volcengine Region
+ zh_Hans: 输入火山引擎地域
+ - variable: api_endpoint_host
+ required: true
+ label:
+ en_US: API Endpoint Host
+ zh_Hans: API Endpoint Host
+ type: text-input
+ default: maas-api.ml-platform-cn-beijing.volces.com
+ placeholder:
+ en_US: Enter your API Endpoint Host
+ zh_Hans: 输入 API Endpoint Host
+ - variable: endpoint_id
+ required: true
+ label:
+ en_US: Endpoint ID
+ zh_Hans: Endpoint ID
+ type: text-input
+ placeholder:
+ en_US: Enter your Endpoint ID
+ zh_Hans: 输入您的 Endpoint ID
+ - variable: base_model_name
+ label:
+ en_US: Base Model
+ zh_Hans: 基础模型
+ type: select
+ required: true
+ options:
+ - label:
+ en_US: Doubao-pro-4k
+ value: Doubao-pro-4k
+ show_on:
+ - variable: __model_type
+ value: llm
+ - label:
+ en_US: Doubao-lite-4k
+ value: Doubao-lite-4k
+ show_on:
+ - variable: __model_type
+ value: llm
+ - label:
+ en_US: Doubao-pro-32k
+ value: Doubao-pro-32k
+ show_on:
+ - variable: __model_type
+ value: llm
+ - label:
+ en_US: Doubao-lite-32k
+ value: Doubao-lite-32k
+ show_on:
+ - variable: __model_type
+ value: llm
+ - label:
+ en_US: Doubao-pro-128k
+ value: Doubao-pro-128k
+ show_on:
+ - variable: __model_type
+ value: llm
+ - label:
+ en_US: Doubao-lite-128k
+ value: Doubao-lite-128k
+ show_on:
+ - variable: __model_type
+ value: llm
+ - label:
+ en_US: Skylark2-pro-4k
+ value: Skylark2-pro-4k
+ show_on:
+ - variable: __model_type
+ value: llm
+ - label:
+ en_US: Doubao-embedding
+ value: Doubao-embedding
+ show_on:
+ - variable: __model_type
+ value: text-embedding
+ - label:
+ en_US: Custom
+ zh_Hans: 自定义
+ value: Custom
+ - variable: mode
+ required: true
+ show_on:
+ - variable: __model_type
+ value: llm
+ - variable: base_model_name
+ value: Custom
+ label:
+ zh_Hans: 模型类型
+ en_US: Completion Mode
+ type: select
+ default: chat
+ placeholder:
+ zh_Hans: 选择对话类型
+ en_US: Select Completion Mode
+ options:
+ - value: completion
+ label:
+ en_US: Completion
+ zh_Hans: 补全
+ - value: chat
+ label:
+ en_US: Chat
+ zh_Hans: 对话
+ - variable: context_size
+ required: true
+ show_on:
+ - variable: base_model_name
+ value: Custom
+ label:
+ zh_Hans: 模型上下文长度
+ en_US: Model Context Size
+ type: text-input
+ default: '4096'
+ placeholder:
+ zh_Hans: 输入您的模型上下文长度
+ en_US: Enter your Model Context Size
+ - variable: max_tokens
+ required: true
+ show_on:
+ - variable: __model_type
+ value: llm
+ - variable: base_model_name
+ value: Custom
+ label:
+ zh_Hans: 最大 token 上限
+ en_US: Upper Bound for Max Tokens
+ default: '4096'
+ type: text-input
+ placeholder:
+ zh_Hans: 输入您的模型最大 token 上限
+ en_US: Enter your model Upper Bound for Max Tokens
diff --git a/api/core/model_runtime/model_providers/wenxin/llm/ernie-3.5-128k.yaml b/api/core/model_runtime/model_providers/wenxin/llm/ernie-3.5-128k.yaml
new file mode 100644
index 0000000000..b1b1ba1f69
--- /dev/null
+++ b/api/core/model_runtime/model_providers/wenxin/llm/ernie-3.5-128k.yaml
@@ -0,0 +1,37 @@
+model: ernie-3.5-128k
+label:
+ en_US: Ernie-3.5-128K
+model_type: llm
+features:
+ - agent-thought
+model_properties:
+ mode: chat
+ context_size: 128000
+parameter_rules:
+ - name: temperature
+ use_template: temperature
+ min: 0.1
+ max: 1.0
+ default: 0.8
+ - name: top_p
+ use_template: top_p
+ - name: max_tokens
+ use_template: max_tokens
+ default: 4096
+ min: 2
+ max: 4096
+ - name: presence_penalty
+ use_template: presence_penalty
+ - name: frequency_penalty
+ use_template: frequency_penalty
+ - name: response_format
+ use_template: response_format
+ - name: disable_search
+ label:
+ zh_Hans: 禁用搜索
+ en_US: Disable Search
+ type: boolean
+ help:
+ zh_Hans: 禁用模型自行进行外部搜索。
+ en_US: Disable the model to perform external search.
+ required: false
diff --git a/api/core/model_runtime/model_providers/wenxin/llm/ernie-3.5-4k-0205.yaml b/api/core/model_runtime/model_providers/wenxin/llm/ernie-3.5-4k-0205.yaml
index 9487342a1d..1e8cf96440 100644
--- a/api/core/model_runtime/model_providers/wenxin/llm/ernie-3.5-4k-0205.yaml
+++ b/api/core/model_runtime/model_providers/wenxin/llm/ernie-3.5-4k-0205.yaml
@@ -35,3 +35,4 @@ parameter_rules:
zh_Hans: 禁用模型自行进行外部搜索。
en_US: Disable the model to perform external search.
required: false
+deprecated: true
diff --git a/api/core/model_runtime/model_providers/wenxin/llm/ernie-3.5-8k-1222.yaml b/api/core/model_runtime/model_providers/wenxin/llm/ernie-3.5-8k-1222.yaml
index 5dfcd5825b..c43588cfe1 100644
--- a/api/core/model_runtime/model_providers/wenxin/llm/ernie-3.5-8k-1222.yaml
+++ b/api/core/model_runtime/model_providers/wenxin/llm/ernie-3.5-8k-1222.yaml
@@ -35,3 +35,4 @@ parameter_rules:
zh_Hans: 禁用模型自行进行外部搜索。
en_US: Disable the model to perform external search.
required: false
+deprecated: true
diff --git a/api/core/model_runtime/model_providers/wenxin/llm/ernie-character-8k-0321.yaml b/api/core/model_runtime/model_providers/wenxin/llm/ernie-character-8k-0321.yaml
new file mode 100644
index 0000000000..52e1dc832d
--- /dev/null
+++ b/api/core/model_runtime/model_providers/wenxin/llm/ernie-character-8k-0321.yaml
@@ -0,0 +1,30 @@
+model: ernie-character-8k-0321
+label:
+ en_US: ERNIE-Character-8K-0321
+model_type: llm
+features:
+ - agent-thought
+model_properties:
+ mode: chat
+ context_size: 8192
+parameter_rules:
+ - name: temperature
+ use_template: temperature
+ min: 0.1
+ max: 1.0
+ default: 0.95
+ - name: top_p
+ use_template: top_p
+ min: 0
+ max: 1.0
+ default: 0.7
+ - name: max_tokens
+ use_template: max_tokens
+ default: 1024
+ min: 2
+ max: 1024
+ - name: presence_penalty
+ use_template: presence_penalty
+ default: 1.0
+ min: 1.0
+ max: 2.0
diff --git a/api/core/model_runtime/model_providers/wenxin/llm/ernie-lite-8k-0308.yaml b/api/core/model_runtime/model_providers/wenxin/llm/ernie-lite-8k-0308.yaml
index 3f09f10d1a..78325c1d64 100644
--- a/api/core/model_runtime/model_providers/wenxin/llm/ernie-lite-8k-0308.yaml
+++ b/api/core/model_runtime/model_providers/wenxin/llm/ernie-lite-8k-0308.yaml
@@ -22,7 +22,7 @@ parameter_rules:
use_template: max_tokens
default: 1024
min: 2
- max: 1024
+ max: 2048
- name: presence_penalty
use_template: presence_penalty
default: 1.0
diff --git a/api/core/model_runtime/model_providers/wenxin/llm/ernie-speed-128k.yaml b/api/core/model_runtime/model_providers/wenxin/llm/ernie-speed-128k.yaml
index 3b8885c862..331639624c 100644
--- a/api/core/model_runtime/model_providers/wenxin/llm/ernie-speed-128k.yaml
+++ b/api/core/model_runtime/model_providers/wenxin/llm/ernie-speed-128k.yaml
@@ -20,9 +20,9 @@ parameter_rules:
default: 0.7
- name: max_tokens
use_template: max_tokens
- default: 1024
+ default: 4096
min: 2
- max: 1024
+ max: 4096
- name: presence_penalty
use_template: presence_penalty
default: 1.0
diff --git a/api/core/model_runtime/model_providers/wenxin/llm/ernie-speed-8k.yaml b/api/core/model_runtime/model_providers/wenxin/llm/ernie-speed-8k.yaml
index 25d32c9f8a..304c6d1f7e 100644
--- a/api/core/model_runtime/model_providers/wenxin/llm/ernie-speed-8k.yaml
+++ b/api/core/model_runtime/model_providers/wenxin/llm/ernie-speed-8k.yaml
@@ -22,7 +22,7 @@ parameter_rules:
use_template: max_tokens
default: 1024
min: 2
- max: 1024
+ max: 2048
- name: presence_penalty
use_template: presence_penalty
default: 1.0
diff --git a/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py b/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py
index 091337c33d..4646ba384a 100644
--- a/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py
+++ b/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py
@@ -129,12 +129,14 @@ class ErnieBotModel:
'ernie-3.5-8k-0205': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-0205',
'ernie-3.5-8k-1222': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-1222',
'ernie-3.5-4k-0205': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205',
+ 'ernie-3.5-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-128k',
'ernie-4.0-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro',
'ernie-speed-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed',
'ernie-speed-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-speed-128k',
'ernie-speed-appbuilder': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ai_apaas',
'ernie-lite-8k-0922': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant',
'ernie-lite-8k-0308': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-lite-8k',
+ 'ernie-character-8k-0321': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k',
}
function_calling_supports = [
@@ -143,7 +145,9 @@ class ErnieBotModel:
'ernie-3.5-8k',
'ernie-3.5-8k-0205',
'ernie-3.5-8k-1222',
- 'ernie-3.5-4k-0205'
+ 'ernie-3.5-4k-0205',
+ 'ernie-3.5-128k',
+ 'ernie-4.0-8k'
]
api_key: str = ''
diff --git a/api/core/model_runtime/model_providers/xinference/llm/llm.py b/api/core/model_runtime/model_providers/xinference/llm/llm.py
index 602d0b749f..cc3ce17975 100644
--- a/api/core/model_runtime/model_providers/xinference/llm/llm.py
+++ b/api/core/model_runtime/model_providers/xinference/llm/llm.py
@@ -28,7 +28,10 @@ from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
+ ImagePromptMessageContent,
PromptMessage,
+ PromptMessageContent,
+ PromptMessageContentType,
PromptMessageTool,
SystemPromptMessage,
ToolPromptMessage,
@@ -61,8 +64,8 @@ from core.model_runtime.utils import helper
class XinferenceAILargeLanguageModel(LargeLanguageModel):
- def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
- model_parameters: dict, tools: list[PromptMessageTool] | None = None,
+ def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+ model_parameters: dict, tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
-> LLMResult | Generator:
"""
@@ -99,7 +102,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
try:
if "/" in credentials['model_uid'] or "?" in credentials['model_uid'] or "#" in credentials['model_uid']:
raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")
-
+
extra_param = XinferenceHelper.get_xinference_extra_parameter(
server_url=credentials['server_url'],
model_uid=credentials['model_uid']
@@ -111,10 +114,13 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
credentials['completion_type'] = 'completion'
else:
raise ValueError(f'xinference model ability {extra_param.model_ability} is not supported, check if you have the right model type')
-
+
if extra_param.support_function_call:
credentials['support_function_call'] = True
+ if extra_param.support_vision:
+ credentials['support_vision'] = True
+
if extra_param.context_length:
credentials['context_length'] = extra_param.context_length
@@ -135,7 +141,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
"""
return self._num_tokens_from_messages(prompt_messages, tools)
- def _num_tokens_from_messages(self, messages: list[PromptMessage], tools: list[PromptMessageTool],
+ def _num_tokens_from_messages(self, messages: list[PromptMessage], tools: list[PromptMessageTool],
is_completion_model: bool = False) -> int:
def tokens(text: str):
return self._get_num_tokens_by_gpt2(text)
@@ -155,7 +161,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
text = ''
for item in value:
if isinstance(item, dict) and item['type'] == 'text':
- text += item.text
+ text += item['text']
value = text
@@ -191,7 +197,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
num_tokens += self._num_tokens_for_tools(tools)
return num_tokens
-
+
def _num_tokens_for_tools(self, tools: list[PromptMessageTool]) -> int:
"""
Calculate num tokens for tool calling
@@ -234,7 +240,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
num_tokens += tokens(required_field)
return num_tokens
-
+
def _convert_prompt_message_to_text(self, message: list[PromptMessage]) -> str:
"""
convert prompt message to text
@@ -260,7 +266,26 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
if isinstance(message.content, str):
message_dict = {"role": "user", "content": message.content}
else:
- raise ValueError("User message content must be str")
+ sub_messages = []
+ for message_content in message.content:
+ if message_content.type == PromptMessageContentType.TEXT:
+ message_content = cast(PromptMessageContent, message_content)
+ sub_message_dict = {
+ "type": "text",
+ "text": message_content.data
+ }
+ sub_messages.append(sub_message_dict)
+ elif message_content.type == PromptMessageContentType.IMAGE:
+ message_content = cast(ImagePromptMessageContent, message_content)
+ sub_message_dict = {
+ "type": "image_url",
+ "image_url": {
+ "url": message_content.data,
+ "detail": message_content.detail.value
+ }
+ }
+ sub_messages.append(sub_message_dict)
+ message_dict = {"role": "user", "content": sub_messages}
elif isinstance(message, AssistantPromptMessage):
message = cast(AssistantPromptMessage, message)
message_dict = {"role": "assistant", "content": message.content}
@@ -277,7 +302,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
message_dict = {"tool_call_id": message.tool_call_id, "role": "tool", "content": message.content}
else:
raise ValueError(f"Unknown message type {type(message)}")
-
+
return message_dict
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
@@ -338,8 +363,18 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
completion_type = LLMMode.COMPLETION.value
else:
raise ValueError(f'xinference model ability {extra_args.model_ability} is not supported')
-
+
+
+ features = []
+
support_function_call = credentials.get('support_function_call', False)
+ if support_function_call:
+ features.append(ModelFeature.TOOL_CALL)
+
+ support_vision = credentials.get('support_vision', False)
+ if support_vision:
+ features.append(ModelFeature.VISION)
+
context_length = credentials.get('context_length', 2048)
entity = AIModelEntity(
@@ -349,10 +384,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.LLM,
- features=[
- ModelFeature.TOOL_CALL
- ] if support_function_call else [],
- model_properties={
+ features=features,
+ model_properties={
ModelPropertyKey.MODE: completion_type,
ModelPropertyKey.CONTEXT_SIZE: context_length
},
@@ -360,22 +393,22 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
)
return entity
-
- def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+
+ def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
model_parameters: dict, extra_model_kwargs: XinferenceModelExtraParameter,
- tools: list[PromptMessageTool] | None = None,
+ tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
-> LLMResult | Generator:
"""
generate text from LLM
see `core.model_runtime.model_providers.__base.large_language_model.LargeLanguageModel._generate`
-
+
extra_model_kwargs can be got by `XinferenceHelper.get_xinference_extra_parameter`
"""
if 'server_url' not in credentials:
raise CredentialsValidateFailedError('server_url is required in credentials')
-
+
if credentials['server_url'].endswith('/'):
credentials['server_url'] = credentials['server_url'][:-1]
@@ -408,11 +441,11 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
'function': helper.dump_model(tool)
} for tool in tools
]
-
+ vision = credentials.get('support_vision', False)
if isinstance(xinference_model, RESTfulChatModelHandle | RESTfulChatglmCppChatModelHandle):
resp = client.chat.completions.create(
model=credentials['model_uid'],
- messages=[self._convert_prompt_message_to_dict(message) for message in prompt_messages],
+ messages=[self._convert_prompt_message_to_dict(message) for message in prompt_messages],
stream=stream,
user=user,
**generate_config,
@@ -497,7 +530,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
"""
if len(resp.choices) == 0:
raise InvokeServerUnavailableError("Empty response")
-
+
assistant_message = resp.choices[0].message
# convert tool call to assistant message tool call
@@ -527,7 +560,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
)
return response
-
+
def _handle_chat_stream_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: list[PromptMessageTool],
resp: Iterator[ChatCompletionChunk]) -> Generator:
@@ -544,7 +577,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ''):
continue
-
+
# check if there is a tool call in the response
function_call = None
tool_calls = []
@@ -573,9 +606,9 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools)
completion_tokens = self._num_tokens_from_messages(messages=[temp_assistant_prompt_message], tools=[])
- usage = self._calc_response_usage(model=model, credentials=credentials,
+ usage = self._calc_response_usage(model=model, credentials=credentials,
prompt_tokens=prompt_tokens, completion_tokens=completion_tokens)
-
+
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
@@ -608,7 +641,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
"""
if len(resp.choices) == 0:
raise InvokeServerUnavailableError("Empty response")
-
+
assistant_message = resp.choices[0].text
# transform assistant message to prompt message
@@ -670,9 +703,9 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
completion_tokens = self._num_tokens_from_messages(
messages=[temp_assistant_prompt_message], tools=[], is_completion_model=True
)
- usage = self._calc_response_usage(model=model, credentials=credentials,
+ usage = self._calc_response_usage(model=model, credentials=credentials,
prompt_tokens=prompt_tokens, completion_tokens=completion_tokens)
-
+
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
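
With vision support detected, `_convert_prompt_message_to_dict` emits an OpenAI-style multimodal content list for user messages instead of a plain string. A sketch of the resulting payload; the image URL and detail level are placeholders:

```python
message_dict = {
    "role": "user",
    "content": [
        {"type": "text", "text": "What is in this picture?"},
        {
            "type": "image_url",
            "image_url": {
                "url": "data:image/png;base64,...",  # placeholder image data
                "detail": "low",                     # from message_content.detail.value
            },
        },
    ],
}
```
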
diff --git a/api/core/model_runtime/model_providers/xinference/xinference_helper.py b/api/core/model_runtime/model_providers/xinference/xinference_helper.py
index 66dab65804..9a3fc9b193 100644
--- a/api/core/model_runtime/model_providers/xinference/xinference_helper.py
+++ b/api/core/model_runtime/model_providers/xinference/xinference_helper.py
@@ -14,13 +14,15 @@ class XinferenceModelExtraParameter:
max_tokens: int = 512
context_length: int = 2048
support_function_call: bool = False
+ support_vision: bool = False
- def __init__(self, model_format: str, model_handle_type: str, model_ability: list[str],
- support_function_call: bool, max_tokens: int, context_length: int) -> None:
+ def __init__(self, model_format: str, model_handle_type: str, model_ability: list[str],
+ support_function_call: bool, support_vision: bool, max_tokens: int, context_length: int) -> None:
self.model_format = model_format
self.model_handle_type = model_handle_type
self.model_ability = model_ability
self.support_function_call = support_function_call
+ self.support_vision = support_vision
self.max_tokens = max_tokens
self.context_length = context_length
@@ -71,7 +73,7 @@ class XinferenceHelper:
raise RuntimeError(f'get xinference model extra parameter failed, url: {url}, error: {e}')
if response.status_code != 200:
raise RuntimeError(f'get xinference model extra parameter failed, status code: {response.status_code}, response: {response.text}')
-
+
response_json = response.json()
model_format = response_json.get('model_format', 'ggmlv3')
@@ -87,17 +89,19 @@ class XinferenceHelper:
model_handle_type = 'chat'
else:
raise NotImplementedError(f'xinference model handle type {model_handle_type} is not supported')
-
+
support_function_call = 'tools' in model_ability
+ support_vision = 'vision' in model_ability
max_tokens = response_json.get('max_tokens', 512)
context_length = response_json.get('context_length', 2048)
-
+
return XinferenceModelExtraParameter(
model_format=model_format,
model_handle_type=model_handle_type,
model_ability=model_ability,
support_function_call=support_function_call,
+ support_vision=support_vision,
max_tokens=max_tokens,
context_length=context_length
)
\ No newline at end of file
diff --git a/api/core/model_runtime/model_providers/yi/llm/_position.yaml b/api/core/model_runtime/model_providers/yi/llm/_position.yaml
index 12838d670f..e876893b41 100644
--- a/api/core/model_runtime/model_providers/yi/llm/_position.yaml
+++ b/api/core/model_runtime/model_providers/yi/llm/_position.yaml
@@ -1,3 +1,9 @@
- yi-34b-chat-0205
- yi-34b-chat-200k
- yi-vl-plus
+- yi-large
+- yi-medium
+- yi-vision
+- yi-medium-200k
+- yi-spark
+- yi-large-turbo
diff --git a/api/core/model_runtime/model_providers/yi/llm/yi-large-turbo.yaml b/api/core/model_runtime/model_providers/yi/llm/yi-large-turbo.yaml
new file mode 100644
index 0000000000..1d00eca2ca
--- /dev/null
+++ b/api/core/model_runtime/model_providers/yi/llm/yi-large-turbo.yaml
@@ -0,0 +1,43 @@
+model: yi-large-turbo
+label:
+ zh_Hans: yi-large-turbo
+ en_US: yi-large-turbo
+model_type: llm
+features:
+ - agent-thought
+model_properties:
+ mode: chat
+ context_size: 16384
+parameter_rules:
+ - name: temperature
+ use_template: temperature
+ type: float
+ default: 0.3
+ min: 0.0
+ max: 2.0
+ help:
+ zh_Hans: 控制生成结果的多样性和随机性。数值越小,越严谨;数值越大,越发散。
+ en_US: Control the diversity and randomness of generated results. The smaller the value, the more rigorous it is; the larger the value, the more divergent it is.
+ - name: max_tokens
+ use_template: max_tokens
+ type: int
+ default: 1024
+ min: 1
+ max: 16384
+ help:
+ zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。
+ en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter.
+ - name: top_p
+ use_template: top_p
+ type: float
+ default: 0.9
+ min: 0.01
+ max: 1.00
+ help:
+ zh_Hans: 控制生成结果的随机性。数值越小,随机性越弱;数值越大,随机性越强。一般而言,top_p 和 temperature 两个参数选择一个进行调整即可。
+ en_US: Control the randomness of generated results. The smaller the value, the weaker the randomness; the larger the value, the stronger the randomness. Generally speaking, you can adjust one of the two parameters top_p and temperature.
+pricing:
+ input: '12'
+ output: '12'
+ unit: '0.000001'
+ currency: RMB
diff --git a/api/core/model_runtime/model_providers/yi/llm/yi-large.yaml b/api/core/model_runtime/model_providers/yi/llm/yi-large.yaml
new file mode 100644
index 0000000000..347f511280
--- /dev/null
+++ b/api/core/model_runtime/model_providers/yi/llm/yi-large.yaml
@@ -0,0 +1,43 @@
+model: yi-large
+label:
+ zh_Hans: yi-large
+ en_US: yi-large
+model_type: llm
+features:
+ - agent-thought
+model_properties:
+ mode: chat
+ context_size: 16384
+parameter_rules:
+ - name: temperature
+ use_template: temperature
+ type: float
+ default: 0.3
+ min: 0.0
+ max: 2.0
+ help:
+ zh_Hans: 控制生成结果的多样性和随机性。数值越小,越严谨;数值越大,越发散。
+ en_US: Control the diversity and randomness of generated results. The smaller the value, the more rigorous it is; the larger the value, the more divergent it is.
+ - name: max_tokens
+ use_template: max_tokens
+ type: int
+ default: 1024
+ min: 1
+ max: 16384
+ help:
+ zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。
+ en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter.
+ - name: top_p
+ use_template: top_p
+ type: float
+ default: 0.9
+ min: 0.01
+ max: 1.00
+ help:
+ zh_Hans: 控制生成结果的随机性。数值越小,随机性越弱;数值越大,随机性越强。一般而言,top_p 和 temperature 两个参数选择一个进行调整即可。
+ en_US: Control the randomness of generated results. The smaller the value, the weaker the randomness; the larger the value, the stronger the randomness. Generally speaking, you can adjust one of the two parameters top_p and temperature.
+pricing:
+ input: '20'
+ output: '20'
+ unit: '0.000001'
+ currency: RMB
diff --git a/api/core/model_runtime/model_providers/yi/llm/yi-medium-200k.yaml b/api/core/model_runtime/model_providers/yi/llm/yi-medium-200k.yaml
new file mode 100644
index 0000000000..e8ddbcba97
--- /dev/null
+++ b/api/core/model_runtime/model_providers/yi/llm/yi-medium-200k.yaml
@@ -0,0 +1,43 @@
+model: yi-medium-200k
+label:
+ zh_Hans: yi-medium-200k
+ en_US: yi-medium-200k
+model_type: llm
+features:
+ - agent-thought
+model_properties:
+ mode: chat
+ context_size: 204800
+parameter_rules:
+ - name: temperature
+ use_template: temperature
+ type: float
+ default: 0.3
+ min: 0.0
+ max: 2.0
+ help:
+ zh_Hans: 控制生成结果的多样性和随机性。数值越小,越严谨;数值越大,越发散。
+ en_US: Control the diversity and randomness of generated results. The smaller the value, the more rigorous it is; the larger the value, the more divergent it is.
+ - name: max_tokens
+ use_template: max_tokens
+ type: int
+ default: 1024
+ min: 1
+ max: 204800
+ help:
+ zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。
+ en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter.
+ - name: top_p
+ use_template: top_p
+ type: float
+ default: 0.9
+ min: 0.01
+ max: 1.00
+ help:
+ zh_Hans: 控制生成结果的随机性。数值越小,随机性越弱;数值越大,随机性越强。一般而言,top_p 和 temperature 两个参数选择一个进行调整即可。
+ en_US: Control the randomness of generated results. The smaller the value, the weaker the randomness; the larger the value, the stronger the randomness. Generally speaking, you can adjust one of the two parameters top_p and temperature.
+pricing:
+ input: '12'
+ output: '12'
+ unit: '0.000001'
+ currency: RMB
diff --git a/api/core/model_runtime/model_providers/yi/llm/yi-medium.yaml b/api/core/model_runtime/model_providers/yi/llm/yi-medium.yaml
new file mode 100644
index 0000000000..4f0244d1f5
--- /dev/null
+++ b/api/core/model_runtime/model_providers/yi/llm/yi-medium.yaml
@@ -0,0 +1,43 @@
+model: yi-medium
+label:
+ zh_Hans: yi-medium
+ en_US: yi-medium
+model_type: llm
+features:
+ - agent-thought
+model_properties:
+ mode: chat
+ context_size: 16384
+parameter_rules:
+ - name: temperature
+ use_template: temperature
+ type: float
+ default: 0.3
+ min: 0.0
+ max: 2.0
+ help:
+ zh_Hans: 控制生成结果的多样性和随机性。数值越小,越严谨;数值越大,越发散。
+ en_US: Control the diversity and randomness of generated results. The smaller the value, the more rigorous it is; the larger the value, the more divergent it is.
+ - name: max_tokens
+ use_template: max_tokens
+ type: int
+ default: 1024
+ min: 1
+ max: 16384
+ help:
+ zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。
+ en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter.
+ - name: top_p
+ use_template: top_p
+ type: float
+ default: 0.9
+ min: 0.01
+ max: 1.00
+ help:
+ zh_Hans: 控制生成结果的随机性。数值越小,随机性越弱;数值越大,随机性越强。一般而言,top_p 和 temperature 两个参数选择一个进行调整即可。
+ en_US: Control the randomness of generated results. The smaller the value, the weaker the randomness; the larger the value, the stronger the randomness. Generally speaking, you can adjust one of the two parameters top_p and temperature.
+pricing:
+ input: '2.5'
+ output: '2.5'
+ unit: '0.000001'
+ currency: RMB
diff --git a/api/core/model_runtime/model_providers/yi/llm/yi-spark.yaml b/api/core/model_runtime/model_providers/yi/llm/yi-spark.yaml
new file mode 100644
index 0000000000..e28e9fd8c0
--- /dev/null
+++ b/api/core/model_runtime/model_providers/yi/llm/yi-spark.yaml
@@ -0,0 +1,43 @@
+model: yi-spark
+label:
+ zh_Hans: yi-spark
+ en_US: yi-spark
+model_type: llm
+features:
+ - agent-thought
+model_properties:
+ mode: chat
+ context_size: 16384
+parameter_rules:
+ - name: temperature
+ use_template: temperature
+ type: float
+ default: 0.3
+ min: 0.0
+ max: 2.0
+ help:
+ zh_Hans: 控制生成结果的多样性和随机性。数值越小,越严谨;数值越大,越发散。
+ en_US: Control the diversity and randomness of generated results. The smaller the value, the more rigorous it is; the larger the value, the more divergent it is.
+ - name: max_tokens
+ use_template: max_tokens
+ type: int
+ default: 1024
+ min: 1
+ max: 16384
+ help:
+ zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。
+ en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter.
+ - name: top_p
+ use_template: top_p
+ type: float
+ default: 0.9
+ min: 0.01
+ max: 1.00
+ help:
+ zh_Hans: 控制生成结果的随机性。数值越小,随机性越弱;数值越大,随机性越强。一般而言,top_p 和 temperature 两个参数选择一个进行调整即可。
+ en_US: Control the randomness of generated results. The smaller the value, the weaker the randomness; the larger the value, the stronger the randomness. Generally speaking, you can adjust one of the two parameters top_p and temperature.
+pricing:
+ input: '1'
+ output: '1'
+ unit: '0.000001'
+ currency: RMB
diff --git a/api/core/model_runtime/model_providers/yi/llm/yi-vision.yaml b/api/core/model_runtime/model_providers/yi/llm/yi-vision.yaml
new file mode 100644
index 0000000000..bce34f5836
--- /dev/null
+++ b/api/core/model_runtime/model_providers/yi/llm/yi-vision.yaml
@@ -0,0 +1,44 @@
+model: yi-vision
+label:
+ zh_Hans: yi-vision
+ en_US: yi-vision
+model_type: llm
+features:
+ - agent-thought
+ - vision
+model_properties:
+ mode: chat
+ context_size: 4096
+parameter_rules:
+ - name: temperature
+ use_template: temperature
+ type: float
+ default: 0.3
+ min: 0.0
+ max: 2.0
+ help:
+ zh_Hans: 控制生成结果的多样性和随机性。数值越小,越严谨;数值越大,越发散。
+ en_US: Control the diversity and randomness of generated results. The smaller the value, the more rigorous it is; the larger the value, the more divergent it is.
+ - name: max_tokens
+ use_template: max_tokens
+ type: int
+ default: 1024
+ min: 1
+ max: 4096
+ help:
+ zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。
+ en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter.
+ - name: top_p
+ use_template: top_p
+ type: float
+ default: 0.9
+ min: 0.01
+ max: 1.00
+ help:
+ zh_Hans: 控制生成结果的随机性。数值越小,随机性越弱;数值越大,随机性越强。一般而言,top_p 和 temperature 两个参数选择一个进行调整即可。
+ en_US: Control the randomness of generated results. The smaller the value, the weaker the randomness; the larger the value, the stronger the randomness. Generally speaking, you can adjust one of the two parameters top_p and temperature.
+pricing:
+ input: '6'
+ output: '6'
+ unit: '0.000001'
+ currency: RMB
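
Note on the pricing blocks: `unit` scales the per-token price, so `input: '12'` with `unit: '0.000001'` works out to 12 RMB per million tokens. A minimal sketch of that arithmetic, assuming the cost = tokens x price x unit semantics (the function name is illustrative):

    # Sketch only: assumes cost = token_count * price * unit; names are illustrative.
    from decimal import Decimal

    def estimate_cost(token_count: int, price: str, unit: str) -> Decimal:
        return Decimal(token_count) * Decimal(price) * Decimal(unit)

    # yi-large: 1M input tokens at price '20', unit '0.000001' -> 20 RMB
    assert estimate_cost(1_000_000, "20", "0.000001") == Decimal("20")
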
diff --git a/api/core/model_runtime/model_providers/yi/yi.yaml b/api/core/model_runtime/model_providers/yi/yi.yaml
index a8c0d857b6..de741afb10 100644
--- a/api/core/model_runtime/model_providers/yi/yi.yaml
+++ b/api/core/model_runtime/model_providers/yi/yi.yaml
@@ -33,7 +33,7 @@ provider_credential_schema:
- variable: endpoint_url
label:
zh_Hans: 自定义 API endpoint 地址
- en_US: CUstom API endpoint URL
+ en_US: Custom API endpoint URL
type: text-input
required: false
placeholder:
diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py
index 29b516ac02..22420fea2c 100644
--- a/api/core/prompt/advanced_prompt_transform.py
+++ b/api/core/prompt/advanced_prompt_transform.py
@@ -2,6 +2,7 @@ from typing import Optional, Union
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.file.file_obj import FileVar
+from core.helper.code_executor.jinja2.jinja2_formatter import Jinja2Formatter
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
@@ -80,29 +81,35 @@ class AdvancedPromptTransform(PromptTransform):
prompt_messages = []
- prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
- prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
+ if prompt_template.edition_type == 'basic' or not prompt_template.edition_type:
+ prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
+ prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
- prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs)
+ prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs)
- if memory and memory_config:
- role_prefix = memory_config.role_prefix
- prompt_inputs = self._set_histories_variable(
- memory=memory,
- memory_config=memory_config,
- raw_prompt=raw_prompt,
- role_prefix=role_prefix,
- prompt_template=prompt_template,
- prompt_inputs=prompt_inputs,
- model_config=model_config
+ if memory and memory_config:
+ role_prefix = memory_config.role_prefix
+ prompt_inputs = self._set_histories_variable(
+ memory=memory,
+ memory_config=memory_config,
+ raw_prompt=raw_prompt,
+ role_prefix=role_prefix,
+ prompt_template=prompt_template,
+ prompt_inputs=prompt_inputs,
+ model_config=model_config
+ )
+
+ if query:
+ prompt_inputs = self._set_query_variable(query, prompt_template, prompt_inputs)
+
+ prompt = prompt_template.format(
+ prompt_inputs
)
+ else:
+ prompt = raw_prompt
+ prompt_inputs = inputs
- if query:
- prompt_inputs = self._set_query_variable(query, prompt_template, prompt_inputs)
-
- prompt = prompt_template.format(
- prompt_inputs
- )
+ prompt = Jinja2Formatter.format(prompt, prompt_inputs)
if files:
prompt_message_contents = [TextPromptMessageContent(data=prompt)]
@@ -135,14 +142,22 @@ class AdvancedPromptTransform(PromptTransform):
for prompt_item in raw_prompt_list:
raw_prompt = prompt_item.text
- prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
- prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
+ if prompt_item.edition_type == 'basic' or not prompt_item.edition_type:
+ prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
+ prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
- prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs)
+ prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs)
- prompt = prompt_template.format(
- prompt_inputs
- )
+ prompt = prompt_template.format(
+ prompt_inputs
+ )
+ elif prompt_item.edition_type == 'jinja2':
+ prompt = raw_prompt
+ prompt_inputs = inputs
+
+ prompt = Jinja2Formatter.format(prompt, prompt_inputs)
+ else:
+ raise ValueError(f'Invalid edition type: {prompt_item.edition_type}')
if prompt_item.role == PromptMessageRole.USER:
prompt_messages.append(UserPromptMessage(content=prompt))
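
The `jinja2` edition type skips the `{{variable}}` template parser and renders the raw prompt with all inputs via `Jinja2Formatter`. As a rough approximation of what that step produces (Dify routes rendering through a sandboxed code executor, so this direct in-process render is a sketch, not the actual execution path):

    # Approximation of jinja2 prompt rendering; the real path goes through
    # the code-executor sandbox rather than rendering in-process.
    from jinja2 import Template

    raw_prompt = "Summarize for {{ audience }}:\n{% for d in docs %}- {{ d }}\n{% endfor %}"
    prompt = Template(raw_prompt).render(audience="executives", docs=["Q1 report", "Q2 report"])
    print(prompt)
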
diff --git a/api/core/prompt/agent_history_prompt_transform.py b/api/core/prompt/agent_history_prompt_transform.py
new file mode 100644
index 0000000000..af0075ea91
--- /dev/null
+++ b/api/core/prompt/agent_history_prompt_transform.py
@@ -0,0 +1,82 @@
+from typing import Optional, cast
+
+from core.app.entities.app_invoke_entities import (
+ ModelConfigWithCredentialsEntity,
+)
+from core.memory.token_buffer_memory import TokenBufferMemory
+from core.model_runtime.entities.message_entities import (
+ PromptMessage,
+ SystemPromptMessage,
+ UserPromptMessage,
+)
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from core.prompt.prompt_transform import PromptTransform
+
+
+class AgentHistoryPromptTransform(PromptTransform):
+ """
+ History Prompt Transform for Agent App
+ """
+ def __init__(self,
+ model_config: ModelConfigWithCredentialsEntity,
+ prompt_messages: list[PromptMessage],
+ history_messages: list[PromptMessage],
+ memory: Optional[TokenBufferMemory] = None,
+ ):
+ self.model_config = model_config
+ self.prompt_messages = prompt_messages
+ self.history_messages = history_messages
+ self.memory = memory
+
+ def get_prompt(self) -> list[PromptMessage]:
+ prompt_messages = []
+ num_system = 0
+ for prompt_message in self.history_messages:
+ if isinstance(prompt_message, SystemPromptMessage):
+ prompt_messages.append(prompt_message)
+ num_system += 1
+
+ if not self.memory:
+ return prompt_messages
+
+ max_token_limit = self._calculate_rest_token(self.prompt_messages, self.model_config)
+
+ model_type_instance = self.model_config.provider_model_bundle.model_type_instance
+ model_type_instance = cast(LargeLanguageModel, model_type_instance)
+
+ curr_message_tokens = model_type_instance.get_num_tokens(
+ self.memory.model_instance.model,
+ self.memory.model_instance.credentials,
+ self.history_messages
+ )
+ if curr_message_tokens <= max_token_limit:
+ return self.history_messages
+
+ # number of prompts appended to the current message
+ num_prompt = 0
+ # append prompt messages in descending order
+ for prompt_message in self.history_messages[::-1]:
+ if isinstance(prompt_message, SystemPromptMessage):
+ continue
+ prompt_messages.append(prompt_message)
+ num_prompt += 1
+ # a message starts with a UserPromptMessage
+ if isinstance(prompt_message, UserPromptMessage):
+ curr_message_tokens = model_type_instance.get_num_tokens(
+ self.memory.model_instance.model,
+ self.memory.model_instance.credentials,
+ prompt_messages
+ )
+ # if the current token count overflows the limit, drop all prompts in the current message and break
+ if curr_message_tokens > max_token_limit:
+ prompt_messages = prompt_messages[:-num_prompt]
+ break
+ num_prompt = 0
+ # return prompt messages in ascending order
+ message_prompts = prompt_messages[num_system:]
+ message_prompts.reverse()
+
+ # merge system prompts and message prompts
+ prompt_messages = prompt_messages[:num_system]
+ prompt_messages.extend(message_prompts)
+ return prompt_messages
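
`AgentHistoryPromptTransform` walks the history newest-first and only re-checks the token budget at each `UserPromptMessage`, so a user turn plus the assistant/tool messages that answer it are kept or dropped as one unit. A standalone sketch of that grouping idea, with message length standing in for a real token count:

    # Toy version of "trim whole turns from the oldest end"; len() fakes tokens.
    def trim_turns(messages: list[tuple[str, str]], budget: int) -> list[tuple[str, str]]:
        kept, pending = [], []
        for role, text in reversed(messages):        # newest first
            pending.append((role, text))
            if role == "user":                       # a turn starts at a user message
                if sum(len(t) for _, t in kept + pending) > budget:
                    break                            # drop this turn and everything older
                kept.extend(pending)
                pending = []
        kept.reverse()                               # back to chronological order
        return kept
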
diff --git a/api/core/prompt/entities/advanced_prompt_entities.py b/api/core/prompt/entities/advanced_prompt_entities.py
index 2be00bdf0e..23a8602bea 100644
--- a/api/core/prompt/entities/advanced_prompt_entities.py
+++ b/api/core/prompt/entities/advanced_prompt_entities.py
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Literal, Optional
from pydantic import BaseModel
@@ -11,6 +11,7 @@ class ChatModelMessage(BaseModel):
"""
text: str
role: PromptMessageRole
+ edition_type: Optional[Literal['basic', 'jinja2']]
class CompletionModelPromptTemplate(BaseModel):
@@ -18,6 +19,7 @@ class CompletionModelPromptTemplate(BaseModel):
Completion Model Prompt Template.
"""
text: str
+ edition_type: Optional[Literal['basic', 'jinja2']]
class MemoryConfig(BaseModel):
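
Because `edition_type` is optional, payloads that omit it still validate and fall through to the `basic` branch. A quick construction sketch (import paths follow this file's own imports; leaving `edition_type` unset behaves like 'basic'):

    from core.model_runtime.entities.message_entities import PromptMessageRole
    from core.prompt.entities.advanced_prompt_entities import ChatModelMessage

    basic_msg = ChatModelMessage(text="Hello, {{name}}", role=PromptMessageRole.USER)
    jinja_msg = ChatModelMessage(text="Hello, {{ name }}", role=PromptMessageRole.USER,
                                 edition_type="jinja2")
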
diff --git a/api/core/prompt/prompt_transform.py b/api/core/prompt/prompt_transform.py
index 9bf2ae090f..d8e2d2f76d 100644
--- a/api/core/prompt/prompt_transform.py
+++ b/api/core/prompt/prompt_transform.py
@@ -1,10 +1,10 @@
-from typing import Optional, cast
+from typing import Optional
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.memory.token_buffer_memory import TokenBufferMemory
+from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import PromptMessage
from core.model_runtime.entities.model_entities import ModelPropertyKey
-from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
@@ -25,12 +25,12 @@ class PromptTransform:
model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
if model_context_tokens:
- model_type_instance = model_config.provider_model_bundle.model_type_instance
- model_type_instance = cast(LargeLanguageModel, model_type_instance)
+ model_instance = ModelInstance(
+ provider_model_bundle=model_config.provider_model_bundle,
+ model=model_config.model
+ )
- curr_message_tokens = model_type_instance.get_num_tokens(
- model_config.model,
- model_config.credentials,
+ curr_message_tokens = model_instance.get_llm_num_tokens(
prompt_messages
)
diff --git a/api/core/prompt/utils/prompt_message_util.py b/api/core/prompt/utils/prompt_message_util.py
index 5fceeb3595..befdceeda5 100644
--- a/api/core/prompt/utils/prompt_message_util.py
+++ b/api/core/prompt/utils/prompt_message_util.py
@@ -1,6 +1,7 @@
from typing import cast
from core.model_runtime.entities.message_entities import (
+ AssistantPromptMessage,
ImagePromptMessageContent,
PromptMessage,
PromptMessageContentType,
@@ -21,13 +22,25 @@ class PromptMessageUtil:
"""
prompts = []
if model_mode == ModelMode.CHAT.value:
for prompt_message in prompt_messages:
+ tool_calls = []
if prompt_message.role == PromptMessageRole.USER:
role = 'user'
elif prompt_message.role == PromptMessageRole.ASSISTANT:
role = 'assistant'
+ if isinstance(prompt_message, AssistantPromptMessage):
+ tool_calls = [{
+ 'id': tool_call.id,
+ 'type': 'function',
+ 'function': {
+ 'name': tool_call.function.name,
+ 'arguments': tool_call.function.arguments,
+ }
+ } for tool_call in prompt_message.tool_calls]
elif prompt_message.role == PromptMessageRole.SYSTEM:
role = 'system'
+ elif prompt_message.role == PromptMessageRole.TOOL:
+ role = 'tool'
else:
continue
@@ -48,11 +61,16 @@ class PromptMessageUtil:
else:
text = prompt_message.content
- prompts.append({
+ prompt = {
"role": role,
"text": text,
"files": files
- })
+ }
+
+ if tool_calls:
+ prompt['tool_calls'] = tool_calls
+
+ prompts.append(prompt)
else:
prompt_message = prompt_messages[0]
text = ''
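
With this change, an assistant turn that invoked tools is serialized with an OpenAI-style `tool_calls` entry alongside its text. The resulting shape looks roughly like this (ids and arguments are made up):

    # Illustrative output for one saved assistant message.
    prompt = {
        "role": "assistant",
        "text": "Let me check the weather.",
        "files": [],
        "tool_calls": [{
            "id": "call_0",
            "type": "function",
            "function": {"name": "get_weather", "arguments": '{"city": "Berlin"}'},
        }],
    }
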
diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py
index 0db84d3b69..c9447a79df 100644
--- a/api/core/provider_manager.py
+++ b/api/core/provider_manager.py
@@ -11,6 +11,8 @@ from core.entities.provider_entities import (
CustomConfiguration,
CustomModelConfiguration,
CustomProviderConfiguration,
+ ModelLoadBalancingConfiguration,
+ ModelSettings,
QuotaConfiguration,
SystemConfiguration,
)
@@ -25,14 +27,18 @@ from core.model_runtime.entities.provider_entities import (
from core.model_runtime.model_providers import model_provider_factory
from extensions import ext_hosting_provider
from extensions.ext_database import db
+from extensions.ext_redis import redis_client
from models.provider import (
+ LoadBalancingModelConfig,
Provider,
ProviderModel,
+ ProviderModelSetting,
ProviderQuotaType,
ProviderType,
TenantDefaultModel,
TenantPreferredModelProvider,
)
+from services.feature_service import FeatureService
class ProviderManager:
@@ -98,6 +104,13 @@ class ProviderManager:
# Get All preferred provider types of the workspace
provider_name_to_preferred_model_provider_records_dict = self._get_all_preferred_model_providers(tenant_id)
+ # Get All provider model settings
+ provider_name_to_provider_model_settings_dict = self._get_all_provider_model_settings(tenant_id)
+
+ # Get All load balancing configs
+ provider_name_to_provider_load_balancing_model_configs_dict \
+ = self._get_all_provider_load_balancing_configs(tenant_id)
+
provider_configurations = ProviderConfigurations(
tenant_id=tenant_id
)
@@ -105,14 +118,8 @@ class ProviderManager:
# Construct ProviderConfiguration objects for each provider
for provider_entity in provider_entities:
provider_name = provider_entity.provider
-
- provider_records = provider_name_to_provider_records_dict.get(provider_entity.provider)
- if not provider_records:
- provider_records = []
-
- provider_model_records = provider_name_to_provider_model_records_dict.get(provider_entity.provider)
- if not provider_model_records:
- provider_model_records = []
+ provider_records = provider_name_to_provider_records_dict.get(provider_entity.provider, [])
+ provider_model_records = provider_name_to_provider_model_records_dict.get(provider_entity.provider, [])
# Convert to custom configuration
custom_configuration = self._to_custom_configuration(
@@ -134,38 +141,38 @@ class ProviderManager:
if preferred_provider_type_record:
preferred_provider_type = ProviderType.value_of(preferred_provider_type_record.preferred_provider_type)
+ elif custom_configuration.provider or custom_configuration.models:
+ preferred_provider_type = ProviderType.CUSTOM
+ elif system_configuration.enabled:
+ preferred_provider_type = ProviderType.SYSTEM
else:
- if custom_configuration.provider or custom_configuration.models:
- preferred_provider_type = ProviderType.CUSTOM
- elif system_configuration.enabled:
- preferred_provider_type = ProviderType.SYSTEM
- else:
- preferred_provider_type = ProviderType.CUSTOM
+ preferred_provider_type = ProviderType.CUSTOM
using_provider_type = preferred_provider_type
+ has_valid_quota = any(quota_conf.is_valid for quota_conf in system_configuration.quota_configurations)
+
if preferred_provider_type == ProviderType.SYSTEM:
- if not system_configuration.enabled:
+ if not system_configuration.enabled or not has_valid_quota:
using_provider_type = ProviderType.CUSTOM
- has_valid_quota = False
- for quota_configuration in system_configuration.quota_configurations:
- if quota_configuration.is_valid:
- has_valid_quota = True
- break
-
- if not has_valid_quota:
- using_provider_type = ProviderType.CUSTOM
else:
if not custom_configuration.provider and not custom_configuration.models:
- if system_configuration.enabled:
- has_valid_quota = False
- for quota_configuration in system_configuration.quota_configurations:
- if quota_configuration.is_valid:
- has_valid_quota = True
- break
+ if system_configuration.enabled and has_valid_quota:
+ using_provider_type = ProviderType.SYSTEM
- if has_valid_quota:
- using_provider_type = ProviderType.SYSTEM
+ # Get provider model settings
+ provider_model_settings = provider_name_to_provider_model_settings_dict.get(provider_name)
+
+ # Get provider load balancing configs
+ provider_load_balancing_configs \
+ = provider_name_to_provider_load_balancing_model_configs_dict.get(provider_name)
+
+ # Convert to model settings
+ model_settings = self._to_model_settings(
+ provider_entity=provider_entity,
+ provider_model_settings=provider_model_settings,
+ load_balancing_model_configs=provider_load_balancing_configs
+ )
provider_configuration = ProviderConfiguration(
tenant_id=tenant_id,
@@ -173,7 +180,8 @@ class ProviderManager:
preferred_provider_type=preferred_provider_type,
using_provider_type=using_provider_type,
system_configuration=system_configuration,
- custom_configuration=custom_configuration
+ custom_configuration=custom_configuration,
+ model_settings=model_settings
)
provider_configurations[provider_name] = provider_configuration
@@ -233,30 +241,17 @@ class ProviderManager:
)
if available_models:
- found = False
- for available_model in available_models:
- if available_model.model == "gpt-4":
- default_model = TenantDefaultModel(
- tenant_id=tenant_id,
- model_type=model_type.to_origin_model_type(),
- provider_name=available_model.provider.provider,
- model_name=available_model.model
- )
- db.session.add(default_model)
- db.session.commit()
- found = True
- break
+ available_model = next((model for model in available_models if model.model == "gpt-4"),
+ available_models[0])
- if not found:
- available_model = available_models[0]
- default_model = TenantDefaultModel(
- tenant_id=tenant_id,
- model_type=model_type.to_origin_model_type(),
- provider_name=available_model.provider.provider,
- model_name=available_model.model
- )
- db.session.add(default_model)
- db.session.commit()
+ default_model = TenantDefaultModel(
+ tenant_id=tenant_id,
+ model_type=model_type.to_origin_model_type(),
+ provider_name=available_model.provider.provider,
+ model_name=available_model.model
+ )
+ db.session.add(default_model)
+ db.session.commit()
if not default_model:
return None
@@ -371,7 +366,7 @@ class ProviderManager:
"""
Get All preferred provider types of the workspace.
- :param tenant_id:
+ :param tenant_id: workspace id
:return:
"""
preferred_provider_types = db.session.query(TenantPreferredModelProvider) \
@@ -386,6 +381,56 @@ class ProviderManager:
return provider_name_to_preferred_provider_type_records_dict
+ def _get_all_provider_model_settings(self, tenant_id: str) -> dict[str, list[ProviderModelSetting]]:
+ """
+ Get All provider model settings of the workspace.
+
+ :param tenant_id: workspace id
+ :return:
+ """
+ provider_model_settings = db.session.query(ProviderModelSetting) \
+ .filter(
+ ProviderModelSetting.tenant_id == tenant_id
+ ).all()
+
+ provider_name_to_provider_model_settings_dict = defaultdict(list)
+ for provider_model_setting in provider_model_settings:
+ (provider_name_to_provider_model_settings_dict[provider_model_setting.provider_name]
+ .append(provider_model_setting))
+
+ return provider_name_to_provider_model_settings_dict
+
+ def _get_all_provider_load_balancing_configs(self, tenant_id: str) -> dict[str, list[LoadBalancingModelConfig]]:
+ """
+ Get All provider load balancing configs of the workspace.
+
+ :param tenant_id: workspace id
+ :return:
+ """
+ cache_key = f"tenant:{tenant_id}:model_load_balancing_enabled"
+ cache_result = redis_client.get(cache_key)
+ if cache_result is None:
+ model_load_balancing_enabled = FeatureService.get_features(tenant_id).model_load_balancing_enabled
+ redis_client.setex(cache_key, 120, str(model_load_balancing_enabled))
+ else:
+ cache_result = cache_result.decode('utf-8')
+ model_load_balancing_enabled = cache_result == 'True'
+
+ if not model_load_balancing_enabled:
+ return dict()
+
+ provider_load_balancing_configs = db.session.query(LoadBalancingModelConfig) \
+ .filter(
+ LoadBalancingModelConfig.tenant_id == tenant_id
+ ).all()
+
+ provider_name_to_provider_load_balancing_model_configs_dict = defaultdict(list)
+ for provider_load_balancing_config in provider_load_balancing_configs:
+ (provider_name_to_provider_load_balancing_model_configs_dict[provider_load_balancing_config.provider_name]
+ .append(provider_load_balancing_config))
+
+ return provider_name_to_provider_load_balancing_model_configs_dict
+
def _init_trial_provider_records(self, tenant_id: str,
provider_name_to_provider_records_dict: dict[str, list]) -> dict[str, list]:
"""
@@ -759,3 +804,97 @@ class ProviderManager:
secret_input_form_variables.append(credential_form_schema.variable)
return secret_input_form_variables
+
+ def _to_model_settings(self, provider_entity: ProviderEntity,
+ provider_model_settings: Optional[list[ProviderModelSetting]] = None,
+ load_balancing_model_configs: Optional[list[LoadBalancingModelConfig]] = None) \
+ -> list[ModelSettings]:
+ """
+ Convert to model settings.
+
+ :param provider_model_settings: provider model settings, including the enabled and load-balancing-enabled flags
+ :param load_balancing_model_configs: load balancing model configs
+ :return:
+ """
+ # Get provider model credential secret variables
+ model_credential_secret_variables = self._extract_secret_variables(
+ provider_entity.model_credential_schema.credential_form_schemas
+ if provider_entity.model_credential_schema else []
+ )
+
+ model_settings = []
+ if not provider_model_settings:
+ return model_settings
+
+ for provider_model_setting in provider_model_settings:
+ load_balancing_configs = []
+ if provider_model_setting.load_balancing_enabled and load_balancing_model_configs:
+ for load_balancing_model_config in load_balancing_model_configs:
+ if (load_balancing_model_config.model_name == provider_model_setting.model_name
+ and load_balancing_model_config.model_type == provider_model_setting.model_type):
+ if not load_balancing_model_config.enabled:
+ continue
+
+ if not load_balancing_model_config.encrypted_config:
+ if load_balancing_model_config.name == "__inherit__":
+ load_balancing_configs.append(ModelLoadBalancingConfiguration(
+ id=load_balancing_model_config.id,
+ name=load_balancing_model_config.name,
+ credentials={}
+ ))
+ continue
+
+ provider_model_credentials_cache = ProviderCredentialsCache(
+ tenant_id=load_balancing_model_config.tenant_id,
+ identity_id=load_balancing_model_config.id,
+ cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL
+ )
+
+ # Get cached provider model credentials
+ cached_provider_model_credentials = provider_model_credentials_cache.get()
+
+ if not cached_provider_model_credentials:
+ try:
+ provider_model_credentials = json.loads(load_balancing_model_config.encrypted_config)
+ except JSONDecodeError:
+ continue
+
+ # Get decoding rsa key and cipher for decrypting credentials
+ if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None:
+ self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding(
+ load_balancing_model_config.tenant_id)
+
+ for variable in model_credential_secret_variables:
+ if variable in provider_model_credentials:
+ try:
+ provider_model_credentials[variable] = encrypter.decrypt_token_with_decoding(
+ provider_model_credentials.get(variable),
+ self.decoding_rsa_key,
+ self.decoding_cipher_rsa
+ )
+ except ValueError:
+ pass
+
+ # cache provider model credentials
+ provider_model_credentials_cache.set(
+ credentials=provider_model_credentials
+ )
+ else:
+ provider_model_credentials = cached_provider_model_credentials
+
+ load_balancing_configs.append(ModelLoadBalancingConfiguration(
+ id=load_balancing_model_config.id,
+ name=load_balancing_model_config.name,
+ credentials=provider_model_credentials
+ ))
+
+ model_settings.append(
+ ModelSettings(
+ model=provider_model_setting.model_name,
+ model_type=ModelType.value_of(provider_model_setting.model_type),
+ enabled=provider_model_setting.enabled,
+ load_balancing_configs=load_balancing_configs if len(load_balancing_configs) > 1 else []
+ )
+ )
+
+ return model_settings
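
`_get_all_provider_load_balancing_configs` gates the query behind a per-tenant feature flag cached in Redis for 120 seconds, so `FeatureService` is consulted at most once per tenant per window. The pattern in isolation (a sketch; `fetch_flag` stands in for the feature-service call):

    # Cached-boolean-flag sketch; fetch_flag stands in for
    # FeatureService.get_features(tenant_id).model_load_balancing_enabled.
    def is_load_balancing_enabled(redis_client, tenant_id: str, fetch_flag) -> bool:
        cache_key = f"tenant:{tenant_id}:model_load_balancing_enabled"
        cached = redis_client.get(cache_key)
        if cached is None:
            enabled = fetch_flag(tenant_id)
            redis_client.setex(cache_key, 120, str(enabled))  # expire after 120s
            return enabled
        return cached.decode("utf-8") == "True"
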
diff --git a/api/core/rag/data_post_processor/data_post_processor.py b/api/core/rag/data_post_processor/data_post_processor.py
index bdd69c27b1..a0f2947784 100644
--- a/api/core/rag/data_post_processor/data_post_processor.py
+++ b/api/core/rag/data_post_processor/data_post_processor.py
@@ -5,7 +5,7 @@ from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.rag.data_post_processor.reorder import ReorderRunner
from core.rag.models.document import Document
-from core.rerank.rerank import RerankRunner
+from core.rag.rerank.rerank import RerankRunner
class DataPostProcessor:
diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py
index 0f9c753056..dd74406f30 100644
--- a/api/core/rag/datasource/retrieval_service.py
+++ b/api/core/rag/datasource/retrieval_service.py
@@ -33,6 +33,7 @@ class RetrievalService:
return []
all_documents = []
threads = []
+ exceptions = []
# retrieval_model source with keyword
if retrival_method == 'keyword_search':
keyword_thread = threading.Thread(target=RetrievalService.keyword_search, kwargs={
@@ -40,7 +41,8 @@ class RetrievalService:
'dataset_id': dataset_id,
'query': query,
'top_k': top_k,
- 'all_documents': all_documents
+ 'all_documents': all_documents,
+ 'exceptions': exceptions,
})
threads.append(keyword_thread)
keyword_thread.start()
@@ -54,7 +56,8 @@ class RetrievalService:
'score_threshold': score_threshold,
'reranking_model': reranking_model,
'all_documents': all_documents,
- 'retrival_method': retrival_method
+ 'retrival_method': retrival_method,
+ 'exceptions': exceptions,
})
threads.append(embedding_thread)
embedding_thread.start()
@@ -69,7 +72,8 @@ class RetrievalService:
'score_threshold': score_threshold,
'top_k': top_k,
'reranking_model': reranking_model,
- 'all_documents': all_documents
+ 'all_documents': all_documents,
+ 'exceptions': exceptions,
})
threads.append(full_text_index_thread)
full_text_index_thread.start()
@@ -77,6 +81,10 @@ class RetrievalService:
for thread in threads:
thread.join()
+ if exceptions:
+ exception_message = ';\n'.join(exceptions)
+ raise Exception(exception_message)
+
if retrival_method == 'hybrid_search':
data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False)
all_documents = data_post_processor.invoke(
@@ -89,82 +97,91 @@ class RetrievalService:
@classmethod
def keyword_search(cls, flask_app: Flask, dataset_id: str, query: str,
- top_k: int, all_documents: list):
+ top_k: int, all_documents: list, exceptions: list):
with flask_app.app_context():
- dataset = db.session.query(Dataset).filter(
- Dataset.id == dataset_id
- ).first()
+ try:
+ dataset = db.session.query(Dataset).filter(
+ Dataset.id == dataset_id
+ ).first()
- keyword = Keyword(
- dataset=dataset
- )
+ keyword = Keyword(
+ dataset=dataset
+ )
- documents = keyword.search(
- query,
- top_k=top_k
- )
- all_documents.extend(documents)
+ documents = keyword.search(
+ query,
+ top_k=top_k
+ )
+ all_documents.extend(documents)
+ except Exception as e:
+ exceptions.append(str(e))
@classmethod
def embedding_search(cls, flask_app: Flask, dataset_id: str, query: str,
top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict],
- all_documents: list, retrival_method: str):
+ all_documents: list, retrival_method: str, exceptions: list):
with flask_app.app_context():
- dataset = db.session.query(Dataset).filter(
- Dataset.id == dataset_id
- ).first()
+ try:
+ dataset = db.session.query(Dataset).filter(
+ Dataset.id == dataset_id
+ ).first()
- vector = Vector(
- dataset=dataset
- )
+ vector = Vector(
+ dataset=dataset
+ )
- documents = vector.search_by_vector(
- query,
- search_type='similarity_score_threshold',
- top_k=top_k,
- score_threshold=score_threshold,
- filter={
- 'group_id': [dataset.id]
- }
- )
+ documents = vector.search_by_vector(
+ query,
+ search_type='similarity_score_threshold',
+ top_k=top_k,
+ score_threshold=score_threshold,
+ filter={
+ 'group_id': [dataset.id]
+ }
+ )
- if documents:
- if reranking_model and retrival_method == 'semantic_search':
- data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False)
- all_documents.extend(data_post_processor.invoke(
- query=query,
- documents=documents,
- score_threshold=score_threshold,
- top_n=len(documents)
- ))
- else:
- all_documents.extend(documents)
+ if documents:
+ if reranking_model and retrival_method == 'semantic_search':
+ data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False)
+ all_documents.extend(data_post_processor.invoke(
+ query=query,
+ documents=documents,
+ score_threshold=score_threshold,
+ top_n=len(documents)
+ ))
+ else:
+ all_documents.extend(documents)
+ except Exception as e:
+ exceptions.append(str(e))
@classmethod
def full_text_index_search(cls, flask_app: Flask, dataset_id: str, query: str,
top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict],
- all_documents: list, retrival_method: str):
+ all_documents: list, retrival_method: str, exceptions: list):
with flask_app.app_context():
- dataset = db.session.query(Dataset).filter(
- Dataset.id == dataset_id
- ).first()
+ try:
+ dataset = db.session.query(Dataset).filter(
+ Dataset.id == dataset_id
+ ).first()
- vector_processor = Vector(
- dataset=dataset,
- )
+ vector_processor = Vector(
+ dataset=dataset,
+ )
- documents = vector_processor.search_by_full_text(
- query,
- top_k=top_k
- )
- if documents:
- if reranking_model and retrival_method == 'full_text_search':
- data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False)
- all_documents.extend(data_post_processor.invoke(
- query=query,
- documents=documents,
- score_threshold=score_threshold,
- top_n=len(documents)
- ))
- else:
- all_documents.extend(documents)
+ documents = vector_processor.search_by_full_text(
+ query,
+ top_k=top_k
+ )
+ if documents:
+ if reranking_model and retrival_method == 'full_text_search':
+ data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False)
+ all_documents.extend(data_post_processor.invoke(
+ query=query,
+ documents=documents,
+ score_threshold=score_threshold,
+ top_n=len(documents)
+ ))
+ else:
+ all_documents.extend(documents)
+ except Exception as e:
+ exceptions.append(str(e))
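
Each retrieval thread now records failures in a shared `exceptions` list instead of dying silently, and the caller re-raises after `join()`. The underlying pattern, reduced to a sketch with a stand-in worker:

    # Collect-exceptions-from-worker-threads pattern used above.
    import threading

    def worker(all_documents: list, exceptions: list):
        try:
            all_documents.append("result")   # the real worker runs a search here
        except Exception as e:
            exceptions.append(str(e))

    all_documents, exceptions = [], []
    threads = [threading.Thread(target=worker, args=(all_documents, exceptions))
               for _ in range(3)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    if exceptions:
        raise Exception(";\n".join(exceptions))
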
diff --git a/api/core/rag/datasource/vdb/field.py b/api/core/rag/datasource/vdb/field.py
index dc400dafbb..1c16e4d9cd 100644
--- a/api/core/rag/datasource/vdb/field.py
+++ b/api/core/rag/datasource/vdb/field.py
@@ -8,3 +8,4 @@ class Field(Enum):
VECTOR = "vector"
TEXT_KEY = "text"
PRIMARY_KEY = "id"
+ DOC_ID = "metadata.doc_id"
diff --git a/api/core/rag/datasource/vdb/milvus/milvus_vector.py b/api/core/rag/datasource/vdb/milvus/milvus_vector.py
index c90fe3b188..0586e279d3 100644
--- a/api/core/rag/datasource/vdb/milvus/milvus_vector.py
+++ b/api/core/rag/datasource/vdb/milvus/milvus_vector.py
@@ -259,5 +259,5 @@ class MilvusVector(BaseVector):
uri = "https://" + str(config.host) + ":" + str(config.port)
else:
uri = "http://" + str(config.host) + ":" + str(config.port)
- client = MilvusClient(uri=uri, user=config.user, password=config.password)
+ client = MilvusClient(uri=uri, user=config.user, password=config.password, db_name=config.database)
return client
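
Passing `db_name` scopes the client to a specific Milvus database instead of the default one. A connection sketch with placeholder values (assumes a pymilvus release where `MilvusClient` accepts `db_name`, i.e. 2.3+):

    # Placeholder credentials; assumes pymilvus >= 2.3 (MilvusClient with db_name).
    from pymilvus import MilvusClient

    client = MilvusClient(
        uri="http://localhost:19530",
        user="root",
        password="Milvus",
        db_name="dify",
    )
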
diff --git a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py b/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py
index 5735b79b6e..3842aee6c7 100644
--- a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py
+++ b/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py
@@ -54,7 +54,7 @@ class PGVectoRS(BaseVector):
class _Table(CollectionORM):
__tablename__ = collection_name
- __table_args__ = {"extend_existing": True} # noqa: RUF012
+ __table_args__ = {"extend_existing": True}
id: Mapped[UUID] = mapped_column(
postgresql.UUID(as_uuid=True),
primary_key=True,
diff --git a/api/core/rag/datasource/vdb/pgvector/__init__.py b/api/core/rag/datasource/vdb/pgvector/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/rag/datasource/vdb/pgvector/pgvector.py b/api/core/rag/datasource/vdb/pgvector/pgvector.py
new file mode 100644
index 0000000000..22cf790bfa
--- /dev/null
+++ b/api/core/rag/datasource/vdb/pgvector/pgvector.py
@@ -0,0 +1,169 @@
+import json
+import uuid
+from contextlib import contextmanager
+from typing import Any
+
+import psycopg2.extras
+import psycopg2.pool
+from pydantic import BaseModel, root_validator
+
+from core.rag.datasource.vdb.vector_base import BaseVector
+from core.rag.models.document import Document
+from extensions.ext_redis import redis_client
+
+
+class PGVectorConfig(BaseModel):
+ host: str
+ port: int
+ user: str
+ password: str
+ database: str
+
+ @root_validator()
+ def validate_config(cls, values: dict) -> dict:
+ if not values["host"]:
+ raise ValueError("config PGVECTOR_HOST is required")
+ if not values["port"]:
+ raise ValueError("config PGVECTOR_PORT is required")
+ if not values["user"]:
+ raise ValueError("config PGVECTOR_USER is required")
+ if not values["password"]:
+ raise ValueError("config PGVECTOR_PASSWORD is required")
+ if not values["database"]:
+ raise ValueError("config PGVECTOR_DATABASE is required")
+ return values
+
+
+SQL_CREATE_TABLE = """
+CREATE TABLE IF NOT EXISTS {table_name} (
+ id UUID PRIMARY KEY,
+ text TEXT NOT NULL,
+ meta JSONB NOT NULL,
+ embedding vector({dimension}) NOT NULL
+) using heap;
+"""
+
+
+class PGVector(BaseVector):
+ def __init__(self, collection_name: str, config: PGVectorConfig):
+ super().__init__(collection_name)
+ self.pool = self._create_connection_pool(config)
+ self.table_name = f"embedding_{collection_name}"
+
+ def get_type(self) -> str:
+ return "pgvector"
+
+ def _create_connection_pool(self, config: PGVectorConfig):
+ return psycopg2.pool.SimpleConnectionPool(
+ 1,
+ 5,
+ host=config.host,
+ port=config.port,
+ user=config.user,
+ password=config.password,
+ database=config.database,
+ )
+
+ @contextmanager
+ def _get_cursor(self):
+ conn = self.pool.getconn()
+ cur = conn.cursor()
+ try:
+ yield cur
+ finally:
+ cur.close()
+ conn.commit()
+ self.pool.putconn(conn)
+
+ def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
+ dimension = len(embeddings[0])
+ self._create_collection(dimension)
+ return self.add_texts(texts, embeddings)
+
+ def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
+ values = []
+ pks = []
+ for i, doc in enumerate(documents):
+ doc_id = doc.metadata.get("doc_id", str(uuid.uuid4()))
+ pks.append(doc_id)
+ values.append(
+ (
+ doc_id,
+ doc.page_content,
+ json.dumps(doc.metadata),
+ embeddings[i],
+ )
+ )
+ with self._get_cursor() as cur:
+ psycopg2.extras.execute_values(
+ cur, f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES %s", values
+ )
+ return pks
+
+ def text_exists(self, id: str) -> bool:
+ with self._get_cursor() as cur:
+ cur.execute(f"SELECT id FROM {self.table_name} WHERE id = %s", (id,))
+ return cur.fetchone() is not None
+
+ def get_by_ids(self, ids: list[str]) -> list[Document]:
+ with self._get_cursor() as cur:
+ cur.execute(f"SELECT meta, text FROM {self.table_name} WHERE id IN %s", (tuple(ids),))
+ docs = []
+ for record in cur:
+ docs.append(Document(page_content=record[1], metadata=record[0]))
+ return docs
+
+ def delete_by_ids(self, ids: list[str]) -> None:
+ with self._get_cursor() as cur:
+ cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s", (tuple(ids),))
+
+ def delete_by_metadata_field(self, key: str, value: str) -> None:
+ with self._get_cursor() as cur:
+ cur.execute(f"DELETE FROM {self.table_name} WHERE meta->>%s = %s", (key, value))
+
+ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
+ """
+ Search the nearest neighbors to a vector.
+
+ :param query_vector: The input vector to search for similar items.
+ :param top_k: The number of nearest neighbors to return, default is 5.
+ :return: List of Documents that are nearest to the query vector.
+ """
+ top_k = kwargs.get("top_k", 5)
+
+ with self._get_cursor() as cur:
+ cur.execute(
+ f"SELECT meta, text, embedding <=> %s AS distance FROM {self.table_name} ORDER BY distance LIMIT {top_k}",
+ (json.dumps(query_vector),),
+ )
+ docs = []
+ score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
+ for record in cur:
+ metadata, text, distance = record
+ score = 1 - distance
+ metadata["score"] = score
+ if score > score_threshold:
+ docs.append(Document(page_content=text, metadata=metadata))
+ return docs
+
+ def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
+ # pgvector does not support bm25 full-text search
+ return []
+
+ def delete(self) -> None:
+ with self._get_cursor() as cur:
+ cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")
+
+ def _create_collection(self, dimension: int):
+ cache_key = f"vector_indexing_{self._collection_name}"
+ lock_name = f"{cache_key}_lock"
+ with redis_client.lock(lock_name, timeout=20):
+ collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
+ if redis_client.get(collection_exist_cache_key):
+ return
+
+ with self._get_cursor() as cur:
+ cur.execute("CREATE EXTENSION IF NOT EXISTS vector")
+ cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name, dimension=dimension))
+ # TODO: create index https://github.com/pgvector/pgvector?tab=readme-ov-file#indexing
+ redis_client.set(collection_exist_cache_key, 1, ex=3600)
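
A usage sketch for the new `PGVector` store; connection values are placeholders and the 3-dimensional embedding is purely illustrative:

    import uuid

    from core.rag.datasource.vdb.pgvector.pgvector import PGVector, PGVectorConfig
    from core.rag.models.document import Document

    store = PGVector(
        collection_name="demo",
        config=PGVectorConfig(host="localhost", port=5432, user="postgres",
                              password="postgres", database="dify"),
    )
    # doc_id doubles as the UUID primary key of the embedding table
    docs = [Document(page_content="hello pgvector", metadata={"doc_id": str(uuid.uuid4())})]
    store.create(docs, embeddings=[[0.1, 0.2, 0.3]])   # creates the table, then inserts
    hits = store.search_by_vector([0.1, 0.2, 0.3], top_k=1)
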
diff --git a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py
index e6e83c66d8..7a92314542 100644
--- a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py
+++ b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py
@@ -115,9 +115,12 @@ class QdrantVector(BaseVector):
timeout=int(self._client_config.timeout),
)
- # create payload index
+ # create group_id payload index
self._client.create_payload_index(collection_name, Field.GROUP_KEY.value,
field_schema=PayloadSchemaType.KEYWORD)
+ # create doc_id payload index
+ self._client.create_payload_index(collection_name, Field.DOC_ID.value,
+ field_schema=PayloadSchemaType.KEYWORD)
# create full text index
text_index_params = TextIndexParams(
type=TextIndexType.TEXT,
diff --git a/api/core/rag/datasource/vdb/relyt/relyt_vector.py b/api/core/rag/datasource/vdb/relyt/relyt_vector.py
index 74b91db27e..ee88d9fa29 100644
--- a/api/core/rag/datasource/vdb/relyt/relyt_vector.py
+++ b/api/core/rag/datasource/vdb/relyt/relyt_vector.py
@@ -190,7 +190,7 @@ class RelytVector(BaseVector):
conn.execute(chunks_table.delete().where(delete_condition))
return True
except Exception as e:
- print("Delete operation failed:", str(e)) # noqa: T201
+ print("Delete operation failed:", str(e))
return False
def delete_by_metadata_field(self, key: str, value: str):
diff --git a/api/core/rag/datasource/vdb/tidb_vector/__init__.py b/api/core/rag/datasource/vdb/tidb_vector/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py b/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py
new file mode 100644
index 0000000000..107d17bb47
--- /dev/null
+++ b/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py
@@ -0,0 +1,216 @@
+import json
+import logging
+from typing import Any
+
+import sqlalchemy
+from pydantic import BaseModel, root_validator
+from sqlalchemy import JSON, TEXT, Column, DateTime, String, Table, create_engine, insert
+from sqlalchemy import text as sql_text
+from sqlalchemy.orm import Session, declarative_base
+
+from core.rag.datasource.vdb.vector_base import BaseVector
+from core.rag.models.document import Document
+from extensions.ext_redis import redis_client
+
+logger = logging.getLogger(__name__)
+
+
+class TiDBVectorConfig(BaseModel):
+ host: str
+ port: int
+ user: str
+ password: str
+ database: str
+
+ @root_validator()
+ def validate_config(cls, values: dict) -> dict:
+ if not values['host']:
+ raise ValueError("config TIDB_VECTOR_HOST is required")
+ if not values['port']:
+ raise ValueError("config TIDB_VECTOR_PORT is required")
+ if not values['user']:
+ raise ValueError("config TIDB_VECTOR_USER is required")
+ if not values['password']:
+ raise ValueError("config TIDB_VECTOR_PASSWORD is required")
+ if not values['database']:
+ raise ValueError("config TIDB_VECTOR_DATABASE is required")
+ return values
+
+
+class TiDBVector(BaseVector):
+
+ def _table(self, dim: int) -> Table:
+ from tidb_vector.sqlalchemy import VectorType
+ return Table(
+ self._collection_name,
+ self._orm_base.metadata,
+ Column('id', String(36), primary_key=True, nullable=False),
+ Column("vector", VectorType(dim), nullable=False, comment="" if self._distance_func is None else f"hnsw(distance={self._distance_func})"),
+ Column("text", TEXT, nullable=False),
+ Column("meta", JSON, nullable=False),
+ Column("create_time", DateTime, server_default=sqlalchemy.text("CURRENT_TIMESTAMP")),
+ Column("update_time", DateTime, server_default=sqlalchemy.text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP")),
+ extend_existing=True
+ )
+
+ def __init__(self, collection_name: str, config: TiDBVectorConfig, distance_func: str = 'cosine'):
+ super().__init__(collection_name)
+ self._client_config = config
+ self._url = (f"mysql+pymysql://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}?"
+ f"ssl_verify_cert=true&ssl_verify_identity=true")
+ self._distance_func = distance_func.lower()
+ self._engine = create_engine(self._url)
+ self._orm_base = declarative_base()
+ self._dimension = 1536
+
+ def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
+ logger.info("create collection and add texts, collection_name: " + self._collection_name)
+ self._create_collection(len(embeddings[0]))
+ self.add_texts(texts, embeddings)
+ self._dimension = len(embeddings[0])
+ pass
+
+ def _create_collection(self, dimension: int):
+ logger.info("_create_collection, collection_name " + self._collection_name)
+ lock_name = 'vector_indexing_lock_{}'.format(self._collection_name)
+ with redis_client.lock(lock_name, timeout=20):
+ collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
+ if redis_client.get(collection_exist_cache_key):
+ return
+ with Session(self._engine) as session:
+ session.begin()
+ create_statement = sql_text(f"""
+ CREATE TABLE IF NOT EXISTS {self._collection_name} (
+ id CHAR(36) PRIMARY KEY,
+ text TEXT NOT NULL,
+ meta JSON NOT NULL,
+ doc_id VARCHAR(64) AS (JSON_UNQUOTE(JSON_EXTRACT(meta, '$.doc_id'))) STORED,
+ KEY (doc_id),
+ vector VECTOR({dimension}) NOT NULL COMMENT "hnsw(distance={self._distance_func})",
+ create_time DATETIME DEFAULT CURRENT_TIMESTAMP,
+ update_time DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
+ );
+ """)
+ session.execute(create_statement)
+ # TiDB Vector does not support 'CREATE/ADD INDEX' yet
+ session.commit()
+ redis_client.set(collection_exist_cache_key, 1, ex=3600)
+
+ def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
+ table = self._table(len(embeddings[0]))
+ ids = self._get_uuids(documents)
+ metas = [d.metadata for d in documents]
+ texts = [d.page_content for d in documents]
+
+ chunks_table_data = []
+ with self._engine.connect() as conn:
+ with conn.begin():
+ for id, text, meta, embedding in zip(
+ ids, texts, metas, embeddings
+ ):
+ chunks_table_data.append({"id": id, "vector": embedding, "text": text, "meta": meta})
+
+ # Execute the batch insert when the batch size is reached
+ if len(chunks_table_data) == 500:
+ conn.execute(insert(table).values(chunks_table_data))
+ # Clear the chunks_table_data list for the next batch
+ chunks_table_data.clear()
+
+ # Insert any remaining records that didn't make up a full batch
+ if chunks_table_data:
+ conn.execute(insert(table).values(chunks_table_data))
+ return ids
+
+ def text_exists(self, id: str) -> bool:
+ result = self.get_ids_by_metadata_field('doc_id', id)
+ return bool(result)
+
+ def delete_by_ids(self, ids: list[str]) -> None:
+ with Session(self._engine) as session:
+ ids_str = ','.join(f"'{doc_id}'" for doc_id in ids)
+ select_statement = sql_text(
+ f"""SELECT id FROM {self._collection_name} WHERE meta->>'$.doc_id' in ({ids_str}); """
+ )
+ result = session.execute(select_statement).fetchall()
+ if result:
+ ids = [item[0] for item in result]
+ self._delete_by_ids(ids)
+
+ def _delete_by_ids(self, ids: list[str]) -> bool:
+ if ids is None:
+ raise ValueError("No ids provided to delete.")
+ table = self._table(self._dimension)
+ try:
+ with self._engine.connect() as conn:
+ with conn.begin():
+ delete_condition = table.c.id.in_(ids)
+ conn.execute(table.delete().where(delete_condition))
+ return True
+ except Exception as e:
+ print("Delete operation failed:", str(e))
+ return False
+
+ def delete_by_document_id(self, document_id: str):
+ ids = self.get_ids_by_metadata_field('document_id', document_id)
+ if ids:
+ self._delete_by_ids(ids)
+
+ def get_ids_by_metadata_field(self, key: str, value: str):
+ with Session(self._engine) as session:
+ select_statement = sql_text(
+ f"""SELECT id FROM {self._collection_name} WHERE meta->>'$.{key}' = '{value}'; """
+ )
+ result = session.execute(select_statement).fetchall()
+ if result:
+ return [item[0] for item in result]
+ else:
+ return None
+
+ def delete_by_metadata_field(self, key: str, value: str) -> None:
+ ids = self.get_ids_by_metadata_field(key, value)
+ if ids:
+ self._delete_by_ids(ids)
+
+ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
+ top_k = kwargs.get("top_k", 5)
+ score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
+ filter = kwargs.get('filter')
+ distance = 1 - score_threshold
+
+ query_vector_str = ", ".join(format(x) for x in query_vector)
+ query_vector_str = "[" + query_vector_str + "]"
+ logger.debug(f"_collection_name: {self._collection_name}, score_threshold: {score_threshold}, distance: {distance}")
+
+ docs = []
+ if self._distance_func == 'l2':
+ tidb_func = 'Vec_l2_distance'
+ elif self._distance_func == 'cosine':
+ tidb_func = 'Vec_Cosine_distance'
+ else:
+ tidb_func = 'Vec_Cosine_distance'
+
+ with Session(self._engine) as session:
+ select_statement = sql_text(
+ f"""SELECT meta, text, distance FROM (
+ SELECT meta, text, {tidb_func}(vector, "{query_vector_str}") as distance
+ FROM {self._collection_name}
+ ORDER BY distance
+ LIMIT {top_k}
+ ) t WHERE distance < {distance};"""
+ )
+ res = session.execute(select_statement)
+ results = [(row[0], row[1], row[2]) for row in res]
+ for meta, text, distance in results:
+ metadata = json.loads(meta)
+ metadata['score'] = 1 - distance
+ docs.append(Document(page_content=text, metadata=metadata))
+ return docs
+
+ def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
+ # tidb doesn't support bm25 search
+ return []
+
+ def delete(self) -> None:
+ with Session(self._engine) as session:
+ session.execute(sql_text(f"""DROP TABLE IF EXISTS {self._collection_name};"""))
+ session.commit()
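
Note how `search_by_vector` translates the caller's `score_threshold` into a distance cut-off (`distance = 1 - score_threshold`) and maps each row back with `score = 1 - distance`; the conversion is only meaningful for the cosine distance function. A worked example:

    # Worked example of the score/distance conversion (cosine distance only).
    score_threshold = 0.75
    distance_cutoff = 1 - score_threshold   # rows must satisfy distance < 0.25
    row_distance = 0.2
    score = 1 - row_distance                # 0.8, above the threshold
    assert score > score_threshold
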
diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py
index 2405d16b1d..b500b37d60 100644
--- a/api/core/rag/datasource/vdb/vector_factory.py
+++ b/api/core/rag/datasource/vdb/vector_factory.py
@@ -164,6 +164,54 @@ class Vector:
),
dim=dim
)
+ elif vector_type == "pgvector":
+ from core.rag.datasource.vdb.pgvector.pgvector import PGVector, PGVectorConfig
+
+ if self._dataset.index_struct_dict:
+ class_prefix: str = self._dataset.index_struct_dict["vector_store"]["class_prefix"]
+ collection_name = class_prefix
+ else:
+ dataset_id = self._dataset.id
+ collection_name = Dataset.gen_collection_name_by_id(dataset_id)
+ index_struct_dict = {
+ "type": "pgvector",
+ "vector_store": {"class_prefix": collection_name}}
+ self._dataset.index_struct = json.dumps(index_struct_dict)
+ return PGVector(
+ collection_name=collection_name,
+ config=PGVectorConfig(
+ host=config.get("PGVECTOR_HOST"),
+ port=config.get("PGVECTOR_PORT"),
+ user=config.get("PGVECTOR_USER"),
+ password=config.get("PGVECTOR_PASSWORD"),
+ database=config.get("PGVECTOR_DATABASE"),
+ ),
+ )
+ elif vector_type == "tidb_vector":
+ from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVector, TiDBVectorConfig
+
+ if self._dataset.index_struct_dict:
+ class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix']
+ collection_name = class_prefix.lower()
+ else:
+ dataset_id = self._dataset.id
+ collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
+ index_struct_dict = {
+ "type": 'tidb_vector',
+ "vector_store": {"class_prefix": collection_name}
+ }
+ self._dataset.index_struct = json.dumps(index_struct_dict)
+
+ return TiDBVector(
+ collection_name=collection_name,
+ config=TiDBVectorConfig(
+ host=config.get('TIDB_VECTOR_HOST'),
+ port=config.get('TIDB_VECTOR_PORT'),
+ user=config.get('TIDB_VECTOR_USER'),
+ password=config.get('TIDB_VECTOR_PASSWORD'),
+ database=config.get('TIDB_VECTOR_DATABASE'),
+ ),
+ )
else:
raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")
diff --git a/api/core/docstore/dataset_docstore.py b/api/core/rag/docstore/dataset_docstore.py
similarity index 91%
rename from api/core/docstore/dataset_docstore.py
rename to api/core/rag/docstore/dataset_docstore.py
index 7567493b9f..96a15be742 100644
--- a/api/core/docstore/dataset_docstore.py
+++ b/api/core/rag/docstore/dataset_docstore.py
@@ -1,11 +1,10 @@
from collections.abc import Sequence
-from typing import Any, Optional, cast
+from typing import Any, Optional
from sqlalchemy import func
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
-from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.rag.models.document import Document
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment
@@ -95,11 +94,7 @@ class DatasetDocumentStore:
# calc embedding use tokens
if embedding_model:
- model_type_instance = embedding_model.model_type_instance
- model_type_instance = cast(TextEmbeddingModel, model_type_instance)
- tokens = model_type_instance.get_num_tokens(
- model=embedding_model.model,
- credentials=embedding_model.credentials,
+ tokens = embedding_model.get_text_embedding_num_tokens(
texts=[doc.page_content]
)
else:
@@ -121,13 +116,13 @@ class DatasetDocumentStore:
enabled=False,
created_by=self._user_id,
)
- if 'answer' in doc.metadata and doc.metadata['answer']:
+ if doc.metadata.get('answer'):
segment_document.answer = doc.metadata.pop('answer', '')
db.session.add(segment_document)
else:
segment_document.content = doc.page_content
- if 'answer' in doc.metadata and doc.metadata['answer']:
+ if doc.metadata.get('answer'):
segment_document.answer = doc.metadata.pop('answer', '')
segment_document.index_node_hash = doc.metadata['doc_hash']
segment_document.word_count = len(doc.page_content)
diff --git a/api/core/rag/extractor/csv_extractor.py b/api/core/rag/extractor/csv_extractor.py
index 09a1cddd1e..0470569f39 100644
--- a/api/core/rag/extractor/csv_extractor.py
+++ b/api/core/rag/extractor/csv_extractor.py
@@ -57,7 +57,7 @@ class CSVExtractor(BaseExtractor):
docs = []
try:
# load csv file into pandas dataframe
- df = pd.read_csv(csvfile, error_bad_lines=False, **self.csv_args)
+ df = pd.read_csv(csvfile, on_bad_lines='skip', **self.csv_args)
# check source column exists
if self.source_column and self.source_column not in df.columns:
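
`error_bad_lines` was deprecated in pandas 1.3 and removed in 2.0; `on_bad_lines='skip'` is the current way to drop malformed rows:

    import pandas as pd

    # Skips malformed rows instead of raising (pandas >= 1.3).
    df = pd.read_csv("data.csv", on_bad_lines="skip")
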
diff --git a/api/core/rag/extractor/excel_extractor.py b/api/core/rag/extractor/excel_extractor.py
index 2b0066448e..4d2f61139a 100644
--- a/api/core/rag/extractor/excel_extractor.py
+++ b/api/core/rag/extractor/excel_extractor.py
@@ -39,8 +39,8 @@ class ExcelExtractor(BaseExtractor):
documents = []
# loop over all sheets
for sheet in wb.sheets():
- for row_index, row in enumerate(sheet.get_rows(), start=1):
- row_header = None
+ row_header = None
+ for row_index, row in enumerate(sheet.get_rows(), start=1):
if self.is_blank_row(row):
continue
if row_header is None:
@@ -49,8 +49,8 @@ class ExcelExtractor(BaseExtractor):
item_arr = []
for index, cell in enumerate(row):
txt_value = str(cell.value)
- item_arr.append(f'{row_header[index].value}:{txt_value}')
- item_str = "\n".join(item_arr)
+ item_arr.append(f'"{row_header[index].value}":"{txt_value}"')
+ item_str = ",".join(item_arr)
document = Document(page_content=item_str, metadata={'source': self._file_path})
documents.append(document)
return documents
@@ -68,7 +68,7 @@ class ExcelExtractor(BaseExtractor):
# transform each row into a Document
for _, row in df.iterrows():
- item = ';'.join(f'{k}:{v}' for k, v in row.items() if pd.notna(v))
+ item = ';'.join(f'"{k}":"{v}"' for k, v in row.items() if pd.notna(v))
document = Document(page_content=item, metadata={'source': self._file_path})
data.append(document)
return data
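The Excel change above swaps plain `header:value` pairs for quoted, JSON-like pairs, so values containing colons or separators no longer blur pair boundaries. For example:

```python
# before/after comparison of the row serialization
row = {"Name": "Ada Lovelace", "Role": "Mathematician"}

old_item = ";".join(f"{k}:{v}" for k, v in row.items())
new_item = ";".join(f'"{k}":"{v}"' for k, v in row.items())

print(old_item)  # Name:Ada Lovelace;Role:Mathematician
print(new_item)  # "Name":"Ada Lovelace";"Role":"Mathematician"
```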
diff --git a/api/core/rag/extractor/extract_processor.py b/api/core/rag/extractor/extract_processor.py
index 1136e11f76..09d192d410 100644
--- a/api/core/rag/extractor/extract_processor.py
+++ b/api/core/rag/extractor/extract_processor.py
@@ -1,6 +1,8 @@
+import re
import tempfile
from pathlib import Path
from typing import Union
+from urllib.parse import unquote
import requests
from flask import current_app
@@ -14,7 +16,6 @@ from core.rag.extractor.markdown_extractor import MarkdownExtractor
from core.rag.extractor.notion_extractor import NotionExtractor
from core.rag.extractor.pdf_extractor import PdfExtractor
from core.rag.extractor.text_extractor import TextExtractor
-from core.rag.extractor.unstructured.unstructured_doc_extractor import UnstructuredWordExtractor
from core.rag.extractor.unstructured.unstructured_eml_extractor import UnstructuredEmailExtractor
from core.rag.extractor.unstructured.unstructured_epub_extractor import UnstructuredEpubExtractor
from core.rag.extractor.unstructured.unstructured_markdown_extractor import UnstructuredMarkdownExtractor
@@ -28,7 +29,7 @@ from core.rag.models.document import Document
from extensions.ext_storage import storage
from models.model import UploadFile
-SUPPORT_URL_CONTENT_TYPES = ['application/pdf', 'text/plain']
+SUPPORT_URL_CONTENT_TYPES = ['application/pdf', 'text/plain', 'application/json']
USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
@@ -55,6 +56,17 @@ class ExtractProcessor:
with tempfile.TemporaryDirectory() as temp_dir:
suffix = Path(url).suffix
+ if not suffix and suffix != '.':
+ # get content-type
+ if response.headers.get('Content-Type'):
+ suffix = '.' + response.headers.get('Content-Type').split('/')[-1]
+ else:
+ content_disposition = response.headers.get('Content-Disposition', '')
+ filename_match = re.search(r'filename="([^"]+)"', content_disposition)
+ if filename_match:
+ filename = unquote(filename_match.group(1))
+ ext_match = re.search(r'\.(\w+)$', filename)
+ if ext_match:
+ suffix = '.' + ext_match.group(1)
+
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
with open(file_path, 'wb') as file:
file.write(response.content)
@@ -83,6 +95,7 @@ class ExtractProcessor:
file_extension = input_file.suffix.lower()
etl_type = current_app.config['ETL_TYPE']
unstructured_api_url = current_app.config['UNSTRUCTURED_API_URL']
+ unstructured_api_key = current_app.config['UNSTRUCTURED_API_KEY']
if etl_type == 'Unstructured':
if file_extension == '.xlsx' or file_extension == '.xls':
extractor = ExcelExtractor(file_path)
@@ -94,7 +107,7 @@ class ExtractProcessor:
elif file_extension in ['.htm', '.html']:
extractor = HtmlExtractor(file_path)
elif file_extension in ['.docx']:
- extractor = UnstructuredWordExtractor(file_path, unstructured_api_url)
+ extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
elif file_extension == '.csv':
extractor = CSVExtractor(file_path, autodetect_encoding=True)
elif file_extension == '.msg':
@@ -102,7 +115,7 @@ class ExtractProcessor:
elif file_extension == '.eml':
extractor = UnstructuredEmailExtractor(file_path, unstructured_api_url)
elif file_extension == '.ppt':
- extractor = UnstructuredPPTExtractor(file_path, unstructured_api_url)
+ extractor = UnstructuredPPTExtractor(file_path, unstructured_api_url, unstructured_api_key)
elif file_extension == '.pptx':
extractor = UnstructuredPPTXExtractor(file_path, unstructured_api_url)
elif file_extension == '.xml':
@@ -123,7 +136,7 @@ class ExtractProcessor:
elif file_extension in ['.htm', '.html']:
extractor = HtmlExtractor(file_path)
elif file_extension in ['.docx']:
- extractor = WordExtractor(file_path)
+ extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
elif file_extension == '.csv':
extractor = CSVExtractor(file_path, autodetect_encoding=True)
elif file_extension == 'epub':
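The suffix-inference logic added above, restated as a self-contained function over a plain headers dict (mirroring the guarded lookups in the patch):

```python
# infer a file suffix from response headers when the URL path has none
import re
from urllib.parse import unquote

def infer_suffix(headers: dict) -> str:
    if headers.get("Content-Type"):
        return "." + headers["Content-Type"].split("/")[-1]
    content_disposition = headers.get("Content-Disposition", "")
    filename_match = re.search(r'filename="([^"]+)"', content_disposition)
    if filename_match:
        filename = unquote(filename_match.group(1))
        ext_match = re.search(r"\.(\w+)$", filename)
        if ext_match:
            return "." + ext_match.group(1)
    return ""

print(infer_suffix({"Content-Type": "application/pdf"}))  # .pdf
print(infer_suffix({"Content-Disposition": 'attachment; filename="a%20b.json"'}))  # .json
```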
diff --git a/api/core/rag/extractor/notion_extractor.py b/api/core/rag/extractor/notion_extractor.py
index c40064fd1d..1885ad3aca 100644
--- a/api/core/rag/extractor/notion_extractor.py
+++ b/api/core/rag/extractor/notion_extractor.py
@@ -19,8 +19,12 @@ SEARCH_URL = "https://api.notion.com/v1/search"
RETRIEVE_PAGE_URL_TMPL = "https://api.notion.com/v1/pages/{page_id}"
RETRIEVE_DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}"
-HEADING_TYPE = ['heading_1', 'heading_2', 'heading_3']
-
+# if the user wants to split by headings, use the corresponding splitter
+HEADING_SPLITTER = {
+ 'heading_1': '# ',
+ 'heading_2': '## ',
+ 'heading_3': '### ',
+}
class NotionExtractor(BaseExtractor):
@@ -73,8 +77,7 @@ class NotionExtractor(BaseExtractor):
docs.extend(page_text_documents)
elif notion_page_type == 'page':
page_text_list = self._get_notion_block_data(notion_obj_id)
- for page_text in page_text_list:
- docs.append(Document(page_content=page_text))
+ docs.append(Document(page_content='\n'.join(page_text_list)))
else:
raise ValueError("notion page type not supported")
@@ -96,7 +99,7 @@ class NotionExtractor(BaseExtractor):
data = res.json()
- database_content_list = []
+ database_content = []
if 'results' not in data or data["results"] is None:
return []
for result in data["results"]:
@@ -131,10 +134,9 @@ class NotionExtractor(BaseExtractor):
row_content = row_content + f'{key}:{value_content}\n'
else:
row_content = row_content + f'{key}:{value}\n'
- document = Document(page_content=row_content)
- database_content_list.append(document)
+ database_content.append(row_content)
- return database_content_list
+ return [Document(page_content='\n'.join(database_content))]
def _get_notion_block_data(self, page_id: str) -> list[str]:
result_lines_arr = []
@@ -154,8 +156,6 @@ class NotionExtractor(BaseExtractor):
json=query_dict
)
data = res.json()
- # current block's heading
- heading = ''
for result in data["results"]:
result_type = result["type"]
result_obj = result[result_type]
@@ -172,8 +172,6 @@ class NotionExtractor(BaseExtractor):
if "text" in rich_text:
text = rich_text["text"]["content"]
cur_result_text_arr.append(text)
- if result_type in HEADING_TYPE:
- heading = text
result_block_id = result["id"]
has_children = result["has_children"]
@@ -185,11 +183,10 @@ class NotionExtractor(BaseExtractor):
cur_result_text_arr.append(children_text)
cur_result_text = "\n".join(cur_result_text_arr)
- cur_result_text += "\n\n"
- if result_type in HEADING_TYPE:
- result_lines_arr.append(cur_result_text)
+ if result_type in HEADING_SPLITTER:
+ result_lines_arr.append(f"{HEADING_SPLITTER[result_type]}{cur_result_text}")
else:
- result_lines_arr.append(f'{heading}\n{cur_result_text}')
+ result_lines_arr.append(cur_result_text + '\n\n')
if data["next_cursor"] is None:
break
@@ -218,7 +215,6 @@ class NotionExtractor(BaseExtractor):
data = res.json()
if 'results' not in data or data["results"] is None:
break
- heading = ''
for result in data["results"]:
result_type = result["type"]
result_obj = result[result_type]
@@ -235,8 +231,6 @@ class NotionExtractor(BaseExtractor):
text = rich_text["text"]["content"]
prefix = "\t" * num_tabs
cur_result_text_arr.append(prefix + text)
- if result_type in HEADING_TYPE:
- heading = text
result_block_id = result["id"]
has_children = result["has_children"]
block_type = result["type"]
@@ -247,10 +241,10 @@ class NotionExtractor(BaseExtractor):
cur_result_text_arr.append(children_text)
cur_result_text = "\n".join(cur_result_text_arr)
- if result_type in HEADING_TYPE:
- result_lines_arr.append(cur_result_text)
+ if result_type in HEADING_SPLITTER:
+ result_lines_arr.append(f'{HEADING_SPLITTER[result_type]}{cur_result_text}')
else:
- result_lines_arr.append(f'{heading}\n{cur_result_text}')
+ result_lines_arr.append(cur_result_text + '\n\n')
if data["next_cursor"] is None:
break
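The `HEADING_SPLITTER` change turns Notion heading blocks into Markdown headings instead of repeating the last-seen heading before every block, so a downstream Markdown-aware splitter can chunk the page by section. A compressed illustration:

```python
# heading blocks become markdown headings; other blocks keep a paragraph break
HEADING_SPLITTER = {
    'heading_1': '# ',
    'heading_2': '## ',
    'heading_3': '### ',
}

blocks = [('heading_1', 'Intro'), ('paragraph', 'Some text.'), ('heading_2', 'Details')]
lines = []
for block_type, text in blocks:
    if block_type in HEADING_SPLITTER:
        lines.append(f"{HEADING_SPLITTER[block_type]}{text}")
    else:
        lines.append(text + '\n\n')
print('\n'.join(lines))
```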
diff --git a/api/core/rag/extractor/unstructured/unstructured_ppt_extractor.py b/api/core/rag/extractor/unstructured/unstructured_ppt_extractor.py
index 6d3ffe6589..d354b593ed 100644
--- a/api/core/rag/extractor/unstructured/unstructured_ppt_extractor.py
+++ b/api/core/rag/extractor/unstructured/unstructured_ppt_extractor.py
@@ -17,16 +17,18 @@ class UnstructuredPPTExtractor(BaseExtractor):
def __init__(
self,
file_path: str,
- api_url: str
+ api_url: str,
+ api_key: str
):
"""Initialize with file path."""
self._file_path = file_path
self._api_url = api_url
+ self._api_key = api_key
def extract(self) -> list[Document]:
from unstructured.partition.api import partition_via_api
- elements = partition_via_api(filename=self._file_path, api_url=self._api_url)
+ elements = partition_via_api(filename=self._file_path, api_url=self._api_url, api_key=self._api_key)
text_by_page = {}
for element in elements:
page = element.metadata.page_number
diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py
index 6bf47a76f0..5b858c6c4c 100644
--- a/api/core/rag/extractor/word_extractor.py
+++ b/api/core/rag/extractor/word_extractor.py
@@ -1,12 +1,20 @@
"""Abstract interface for document loader implementations."""
+import datetime
+import mimetypes
import os
import tempfile
+import uuid
from urllib.parse import urlparse
import requests
+from docx import Document as DocxDocument
+from flask import current_app
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
+from extensions.ext_database import db
+from extensions.ext_storage import storage
+from models.model import UploadFile
class WordExtractor(BaseExtractor):
@@ -17,9 +25,12 @@ class WordExtractor(BaseExtractor):
file_path: Path to the file to load.
"""
- def __init__(self, file_path: str):
+ def __init__(self, file_path: str, tenant_id: str, user_id: str):
"""Initialize with file path."""
self.file_path = file_path
+ self.tenant_id = tenant_id
+ self.user_id = user_id
+
if "~" in self.file_path:
self.file_path = os.path.expanduser(self.file_path)
@@ -45,12 +56,7 @@ class WordExtractor(BaseExtractor):
def extract(self) -> list[Document]:
"""Load given path as single page."""
- from docx import Document as docx_Document
-
- document = docx_Document(self.file_path)
- doc_texts = [paragraph.text for paragraph in document.paragraphs]
- content = '\n'.join(doc_texts)
-
+ content = self.parse_docx(self.file_path, 'storage')
return [Document(
page_content=content,
metadata={"source": self.file_path},
@@ -61,3 +67,111 @@ class WordExtractor(BaseExtractor):
"""Check if the url is valid."""
parsed = urlparse(url)
return bool(parsed.netloc) and bool(parsed.scheme)
+
+ def _extract_images_from_docx(self, doc, image_folder):
+ os.makedirs(image_folder, exist_ok=True)
+ image_count = 0
+ image_map = {}
+
+ for rel in doc.part.rels.values():
+ if "image" in rel.target_ref:
+ image_count += 1
+ image_ext = rel.target_ref.split('.')[-1]
+ # use uuid as file name
+ file_uuid = str(uuid.uuid4())
+ file_key = 'image_files/' + self.tenant_id + '/' + file_uuid + '.' + image_ext
+ mime_type, _ = mimetypes.guess_type(file_key)
+
+ storage.save(file_key, rel.target_part.blob)
+ # save file to db
+ config = current_app.config
+ upload_file = UploadFile(
+ tenant_id=self.tenant_id,
+ storage_type=config['STORAGE_TYPE'],
+ key=file_key,
+ name=file_key,
+ size=0,
+ extension=image_ext,
+ mime_type=mime_type,
+ created_by=self.user_id,
+ created_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
+ used=True,
+ used_by=self.user_id,
+ used_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+ )
+
+ db.session.add(upload_file)
+ db.session.commit()
+ image_map[rel.target_part] = f"![image]({current_app.config.get('CONSOLE_API_URL')}/files/{upload_file.id}/image-preview)"
+
+ return image_map
+
+ def _table_to_markdown(self, table):
+ markdown = ""
+ # deal with table headers
+ header_row = table.rows[0]
+ headers = [cell.text for cell in header_row.cells]
+ markdown += "| " + " | ".join(headers) + " |\n"
+ markdown += "| " + " | ".join(["---"] * len(headers)) + " |\n"
+ # deal with table rows
+ for row in table.rows[1:]:
+ row_cells = [cell.text for cell in row.cells]
+ markdown += "| " + " | ".join(row_cells) + " |\n"
+
+ return markdown
+
+ def _parse_paragraph(self, paragraph, image_map):
+ paragraph_content = []
+ for run in paragraph.runs:
+ if run.element.xpath('.//a:blip'):
+ for blip in run.element.xpath('.//a:blip'):
+ embed_id = blip.get('{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed')
+ if embed_id:
+ rel_target = run.part.rels[embed_id].target_ref
+ if rel_target in image_map:
+ paragraph_content.append(image_map[rel_target])
+ if run.text.strip():
+ paragraph_content.append(run.text.strip())
+ return ' '.join(paragraph_content) if paragraph_content else ''
+
+ def parse_docx(self, docx_path, image_folder):
+ doc = DocxDocument(docx_path)
+ os.makedirs(image_folder, exist_ok=True)
+
+ content = []
+
+ image_map = self._extract_images_from_docx(doc, image_folder)
+
+ def parse_paragraph(paragraph):
+ paragraph_content = []
+ for run in paragraph.runs:
+ if run.element.tag.endswith('r'):
+ drawing_elements = run.element.findall(
+ './/{http://schemas.openxmlformats.org/wordprocessingml/2006/main}drawing')
+ for drawing in drawing_elements:
+ blip_elements = drawing.findall(
+ './/{http://schemas.openxmlformats.org/drawingml/2006/main}blip')
+ for blip in blip_elements:
+ embed_id = blip.get(
+ '{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed')
+ if embed_id:
+ image_part = doc.part.related_parts.get(embed_id)
+ if image_part in image_map:
+ paragraph_content.append(image_map[image_part])
+ if run.text.strip():
+ paragraph_content.append(run.text.strip())
+ return ''.join(paragraph_content) if paragraph_content else ''
+
+ paragraphs = doc.paragraphs.copy()
+ tables = doc.tables.copy()
+ for element in doc.element.body:
+ if element.tag.endswith('p'): # paragraph
+ para = paragraphs.pop(0)
+ parsed_paragraph = parse_paragraph(para)
+ if parsed_paragraph:
+ content.append(parsed_paragraph)
+ elif element.tag.endswith('tbl'): # table
+ table = tables.pop(0)
+ content.append(self._table_to_markdown(table))
+ return '\n'.join(content)
+
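A stand-alone sketch of the `_table_to_markdown` conversion above, using plain lists of cell strings in place of python-docx table objects:

```python
# first row becomes the markdown header, remaining rows become body rows
def table_to_markdown(rows: list[list[str]]) -> str:
    headers, *body = rows
    markdown = "| " + " | ".join(headers) + " |\n"
    markdown += "| " + " | ".join(["---"] * len(headers)) + " |\n"
    for row in body:
        markdown += "| " + " | ".join(row) + " |\n"
    return markdown

print(table_to_markdown([["Name", "Score"], ["Ada", "97"], ["Alan", "95"]]))
```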
diff --git a/api/core/rag/index_processor/index_processor_base.py b/api/core/rag/index_processor/index_processor_base.py
index 509a1a189b..edc16c821a 100644
--- a/api/core/rag/index_processor/index_processor_base.py
+++ b/api/core/rag/index_processor/index_processor_base.py
@@ -2,11 +2,16 @@
from abc import ABC, abstractmethod
from typing import Optional
+from flask import current_app
+
from core.model_manager import ModelInstance
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.models.document import Document
-from core.splitter.fixed_text_splitter import EnhanceRecursiveCharacterTextSplitter, FixedRecursiveCharacterTextSplitter
-from core.splitter.text_splitter import TextSplitter
+from core.rag.splitter.fixed_text_splitter import (
+ EnhanceRecursiveCharacterTextSplitter,
+ FixedRecursiveCharacterTextSplitter,
+)
+from core.rag.splitter.text_splitter import TextSplitter
from models.dataset import Dataset, DatasetProcessRule
@@ -43,8 +48,9 @@ class BaseIndexProcessor(ABC):
# The user-defined segmentation rule
rules = processing_rule['rules']
segmentation = rules["segmentation"]
- if segmentation["max_tokens"] < 50 or segmentation["max_tokens"] > 1000:
- raise ValueError("Custom segment length should be between 50 and 1000.")
+ max_segmentation_tokens_length = int(current_app.config['INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH'])
+ if segmentation["max_tokens"] < 50 or segmentation["max_tokens"] > max_segmentation_tokens_length:
+ raise ValueError(f"Custom segment length should be between 50 and {max_segmentation_tokens_length}.")
separator = segmentation["separator"]
if separator:
@@ -54,7 +60,7 @@ class BaseIndexProcessor(ABC):
chunk_size=segmentation["max_tokens"],
chunk_overlap=segmentation.get('chunk_overlap', 0),
fixed_separator=separator,
- separators=["\n\n", "。", ".", " ", ""],
+ separators=["\n\n", "。", ". ", " ", ""],
embedding_model_instance=embedding_model_instance
)
else:
@@ -62,7 +68,7 @@ class BaseIndexProcessor(ABC):
character_splitter = EnhanceRecursiveCharacterTextSplitter.from_encoder(
chunk_size=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens'],
chunk_overlap=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['chunk_overlap'],
- separators=["\n\n", "。", ".", " ", ""],
+ separators=["\n\n", "。", ". ", " ", ""],
embedding_model_instance=embedding_model_instance
)
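The hard-coded 1000-token ceiling becomes configurable above. A minimal sketch of the same validation, stubbing the config value through an environment variable rather than Flask's `current_app.config`:

```python
# the env-var stub is an assumption for a runnable example; the real code
# reads INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH from app config
import os

def validate_max_tokens(max_tokens: int) -> None:
    ceiling = int(os.environ.get("INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH", "1000"))
    if max_tokens < 50 or max_tokens > ceiling:
        raise ValueError(f"Custom segment length should be between 50 and {ceiling}.")

os.environ["INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH"] = "4000"
validate_max_tokens(2000)  # fine now; would have raised under the old fixed limit
```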
diff --git a/api/core/rag/models/document.py b/api/core/rag/models/document.py
index 221318c2c3..7bb675b149 100644
--- a/api/core/rag/models/document.py
+++ b/api/core/rag/models/document.py
@@ -50,7 +50,7 @@ class BaseDocumentTransformer(ABC):
) -> Sequence[Document]:
raise NotImplementedError
- """ # noqa: E501
+ """
@abstractmethod
def transform_documents(
diff --git a/api/core/rag/rerank/__init__.py b/api/core/rag/rerank/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/rerank/rerank.py b/api/core/rag/rerank/rerank.py
similarity index 100%
rename from api/core/rerank/rerank.py
rename to api/core/rag/rerank/rerank.py
diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py
index 155b8be06c..b42a441a3f 100644
--- a/api/core/rag/retrieval/dataset_retrieval.py
+++ b/api/core/rag/retrieval/dataset_retrieval.py
@@ -14,9 +14,9 @@ from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.models.document import Document
+from core.rag.rerank.rerank import RerankRunner
from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter
from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter
-from core.rerank.rerank import RerankRunner
from core.tools.tool.dataset_retriever.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool
@@ -124,7 +124,7 @@ class DatasetRetrieval:
document_score_list = {}
for item in all_documents:
- if 'score' in item.metadata and item.metadata['score']:
+ if item.metadata.get('score'):
document_score_list[item.metadata['doc_id']] = item.metadata['score']
document_context_list = []
@@ -144,9 +144,9 @@ class DatasetRetrieval:
float('inf')))
for segment in sorted_segments:
if segment.answer:
- document_context_list.append(f'question:{segment.content} answer:{segment.answer}')
+ document_context_list.append(f'question:{segment.get_sign_content()} answer:{segment.answer}')
else:
- document_context_list.append(segment.content)
+ document_context_list.append(segment.get_sign_content())
if show_retrieve_source:
context_list = []
resource_number = 1
@@ -329,6 +329,7 @@ class DatasetRetrieval:
"""
if not query:
return
+ dataset_queries = []
for dataset_id in dataset_ids:
dataset_query = DatasetQuery(
dataset_id=dataset_id,
@@ -338,7 +339,9 @@ class DatasetRetrieval:
created_by_role=user_from,
created_by=user_id
)
- db.session.add(dataset_query)
+ dataset_queries.append(dataset_query)
+ if dataset_queries:
+ db.session.add_all(dataset_queries)
db.session.commit()
def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list):
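Collecting the `DatasetQuery` rows and flushing them with a single `add_all` avoids one `session.add` call per dataset. A runnable sketch against an in-memory SQLite database, with a tiny stand-in model (assumes SQLAlchemy 2.x):

```python
# DatasetQuery here is a minimal stand-in, not the real Dify model
from sqlalchemy import String, create_engine, func, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

class Base(DeclarativeBase):
    pass

class DatasetQuery(Base):
    __tablename__ = "dataset_queries"
    id: Mapped[int] = mapped_column(primary_key=True)
    dataset_id: Mapped[str] = mapped_column(String(36))
    content: Mapped[str] = mapped_column(String(255))

engine = create_engine("sqlite+pysqlite:///:memory:")
Base.metadata.create_all(engine)

with Session(engine) as session:
    # build one object per dataset, then add them in a single batch
    dataset_queries = [
        DatasetQuery(dataset_id=dataset_id, content="what is RAG?")
        for dataset_id in ("ds-1", "ds-2", "ds-3")
    ]
    if dataset_queries:
        session.add_all(dataset_queries)
    session.commit()
    print(session.scalar(select(func.count()).select_from(DatasetQuery)))  # 3
```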
diff --git a/api/core/splitter/fixed_text_splitter.py b/api/core/rag/splitter/fixed_text_splitter.py
similarity index 86%
rename from api/core/splitter/fixed_text_splitter.py
rename to api/core/rag/splitter/fixed_text_splitter.py
index a1510259ac..fd714edf5e 100644
--- a/api/core/splitter/fixed_text_splitter.py
+++ b/api/core/rag/splitter/fixed_text_splitter.py
@@ -1,12 +1,11 @@
"""Functionality for splitting text."""
from __future__ import annotations
-from typing import Any, Optional, cast
+from typing import Any, Optional
from core.model_manager import ModelInstance
-from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
-from core.splitter.text_splitter import (
+from core.rag.splitter.text_splitter import (
TS,
Collection,
Literal,
@@ -35,11 +34,7 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
return 0
if embedding_model_instance:
- embedding_model_type_instance = embedding_model_instance.model_type_instance
- embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)
- return embedding_model_type_instance.get_num_tokens(
- model=embedding_model_instance.model,
- credentials=embedding_model_instance.credentials,
+ return embedding_model_instance.get_text_embedding_num_tokens(
texts=[text]
)
else:
diff --git a/api/core/splitter/text_splitter.py b/api/core/rag/splitter/text_splitter.py
similarity index 64%
rename from api/core/splitter/text_splitter.py
rename to api/core/rag/splitter/text_splitter.py
index 5eeb237a96..b3adcedc76 100644
--- a/api/core/splitter/text_splitter.py
+++ b/api/core/rag/splitter/text_splitter.py
@@ -6,7 +6,6 @@ import re
from abc import ABC, abstractmethod
from collections.abc import Callable, Collection, Iterable, Sequence, Set
from dataclasses import dataclass
-from enum import Enum
from typing import (
Any,
Literal,
@@ -94,7 +93,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
documents.append(new_doc)
return documents
- def split_documents(self, documents: Iterable[Document] ) -> list[Document]:
+ def split_documents(self, documents: Iterable[Document]) -> list[Document]:
"""Split documents."""
texts, metadatas = [], []
for doc in documents:
@@ -477,27 +476,6 @@ class TokenTextSplitter(TextSplitter):
return split_text_on_tokens(text=text, tokenizer=tokenizer)
-class Language(str, Enum):
- """Enum of the programming languages."""
-
- CPP = "cpp"
- GO = "go"
- JAVA = "java"
- JS = "js"
- PHP = "php"
- PROTO = "proto"
- PYTHON = "python"
- RST = "rst"
- RUBY = "ruby"
- RUST = "rust"
- SCALA = "scala"
- SWIFT = "swift"
- MARKDOWN = "markdown"
- LATEX = "latex"
- HTML = "html"
- SOL = "sol"
-
-
class RecursiveCharacterTextSplitter(TextSplitter):
"""Splitting text by recursively look at characters.
@@ -554,350 +532,3 @@ class RecursiveCharacterTextSplitter(TextSplitter):
def split_text(self, text: str) -> list[str]:
return self._split_text(text, self._separators)
-
- @classmethod
- def from_language(
- cls, language: Language, **kwargs: Any
- ) -> RecursiveCharacterTextSplitter:
- separators = cls.get_separators_for_language(language)
- return cls(separators=separators, **kwargs)
-
- @staticmethod
- def get_separators_for_language(language: Language) -> list[str]:
- if language == Language.CPP:
- return [
- # Split along class definitions
- "\nclass ",
- # Split along function definitions
- "\nvoid ",
- "\nint ",
- "\nfloat ",
- "\ndouble ",
- # Split along control flow statements
- "\nif ",
- "\nfor ",
- "\nwhile ",
- "\nswitch ",
- "\ncase ",
- # Split by the normal type of lines
- "\n\n",
- "\n",
- " ",
- "",
- ]
- elif language == Language.GO:
- return [
- # Split along function definitions
- "\nfunc ",
- "\nvar ",
- "\nconst ",
- "\ntype ",
- # Split along control flow statements
- "\nif ",
- "\nfor ",
- "\nswitch ",
- "\ncase ",
- # Split by the normal type of lines
- "\n\n",
- "\n",
- " ",
- "",
- ]
- elif language == Language.JAVA:
- return [
- # Split along class definitions
- "\nclass ",
- # Split along method definitions
- "\npublic ",
- "\nprotected ",
- "\nprivate ",
- "\nstatic ",
- # Split along control flow statements
- "\nif ",
- "\nfor ",
- "\nwhile ",
- "\nswitch ",
- "\ncase ",
- # Split by the normal type of lines
- "\n\n",
- "\n",
- " ",
- "",
- ]
- elif language == Language.JS:
- return [
- # Split along function definitions
- "\nfunction ",
- "\nconst ",
- "\nlet ",
- "\nvar ",
- "\nclass ",
- # Split along control flow statements
- "\nif ",
- "\nfor ",
- "\nwhile ",
- "\nswitch ",
- "\ncase ",
- "\ndefault ",
- # Split by the normal type of lines
- "\n\n",
- "\n",
- " ",
- "",
- ]
- elif language == Language.PHP:
- return [
- # Split along function definitions
- "\nfunction ",
- # Split along class definitions
- "\nclass ",
- # Split along control flow statements
- "\nif ",
- "\nforeach ",
- "\nwhile ",
- "\ndo ",
- "\nswitch ",
- "\ncase ",
- # Split by the normal type of lines
- "\n\n",
- "\n",
- " ",
- "",
- ]
- elif language == Language.PROTO:
- return [
- # Split along message definitions
- "\nmessage ",
- # Split along service definitions
- "\nservice ",
- # Split along enum definitions
- "\nenum ",
- # Split along option definitions
- "\noption ",
- # Split along import statements
- "\nimport ",
- # Split along syntax declarations
- "\nsyntax ",
- # Split by the normal type of lines
- "\n\n",
- "\n",
- " ",
- "",
- ]
- elif language == Language.PYTHON:
- return [
- # First, try to split along class definitions
- "\nclass ",
- "\ndef ",
- "\n\tdef ",
- # Now split by the normal type of lines
- "\n\n",
- "\n",
- " ",
- "",
- ]
- elif language == Language.RST:
- return [
- # Split along section titles
- "\n=+\n",
- "\n-+\n",
- "\n\*+\n",
- # Split along directive markers
- "\n\n.. *\n\n",
- # Split by the normal type of lines
- "\n\n",
- "\n",
- " ",
- "",
- ]
- elif language == Language.RUBY:
- return [
- # Split along method definitions
- "\ndef ",
- "\nclass ",
- # Split along control flow statements
- "\nif ",
- "\nunless ",
- "\nwhile ",
- "\nfor ",
- "\ndo ",
- "\nbegin ",
- "\nrescue ",
- # Split by the normal type of lines
- "\n\n",
- "\n",
- " ",
- "",
- ]
- elif language == Language.RUST:
- return [
- # Split along function definitions
- "\nfn ",
- "\nconst ",
- "\nlet ",
- # Split along control flow statements
- "\nif ",
- "\nwhile ",
- "\nfor ",
- "\nloop ",
- "\nmatch ",
- "\nconst ",
- # Split by the normal type of lines
- "\n\n",
- "\n",
- " ",
- "",
- ]
- elif language == Language.SCALA:
- return [
- # Split along class definitions
- "\nclass ",
- "\nobject ",
- # Split along method definitions
- "\ndef ",
- "\nval ",
- "\nvar ",
- # Split along control flow statements
- "\nif ",
- "\nfor ",
- "\nwhile ",
- "\nmatch ",
- "\ncase ",
- # Split by the normal type of lines
- "\n\n",
- "\n",
- " ",
- "",
- ]
- elif language == Language.SWIFT:
- return [
- # Split along function definitions
- "\nfunc ",
- # Split along class definitions
- "\nclass ",
- "\nstruct ",
- "\nenum ",
- # Split along control flow statements
- "\nif ",
- "\nfor ",
- "\nwhile ",
- "\ndo ",
- "\nswitch ",
- "\ncase ",
- # Split by the normal type of lines
- "\n\n",
- "\n",
- " ",
- "",
- ]
- elif language == Language.MARKDOWN:
- return [
- # First, try to split along Markdown headings (starting with level 2)
- "\n#{1,6} ",
- # Note the alternative syntax for headings (below) is not handled here
- # Heading level 2
- # ---------------
- # End of code block
- "```\n",
- # Horizontal lines
- "\n\*\*\*+\n",
- "\n---+\n",
- "\n___+\n",
- # Note that this splitter doesn't handle horizontal lines defined
- # by *three or more* of ***, ---, or ___, but this is not handled
- "\n\n",
- "\n",
- " ",
- "",
- ]
- elif language == Language.LATEX:
- return [
- # First, try to split along Latex sections
- "\n\\\chapter{",
- "\n\\\section{",
- "\n\\\subsection{",
- "\n\\\subsubsection{",
- # Now split by environments
- "\n\\\begin{enumerate}",
- "\n\\\begin{itemize}",
- "\n\\\begin{description}",
- "\n\\\begin{list}",
- "\n\\\begin{quote}",
- "\n\\\begin{quotation}",
- "\n\\\begin{verse}",
- "\n\\\begin{verbatim}",
- # Now split by math environments
- "\n\\\begin{align}",
- "$$",
- "$",
- # Now split by the normal type of lines
- " ",
- "",
- ]
- elif language == Language.HTML:
- return [
- # First, try to split along HTML tags
- " dict:
+ # -------------
+ # overwrite tool parameter types for temp fix
+ tools = jsonable_encoder(self.tools)
+ for tool in tools:
+ if tool.get('parameters'):
+ for parameter in tool.get('parameters'):
+ if parameter.get('type') == ToolParameter.ToolParameterType.FILE.value:
+ parameter['type'] = 'files'
+ # -------------
+
return {
'id': self.id,
'author': self.author,
@@ -47,7 +57,8 @@ class UserToolProvider(BaseModel):
'team_credentials': self.masked_credentials,
'is_team_authorization': self.is_team_authorization,
'allow_delete': self.allow_delete,
- 'tools': self.tools
+ 'tools': tools,
+ 'labels': self.labels,
}
class UserToolProviderCredentials(BaseModel):
diff --git a/api/core/tools/entities/constant.py b/api/core/tools/entities/constant.py
deleted file mode 100644
index 2e75fedf99..0000000000
--- a/api/core/tools/entities/constant.py
+++ /dev/null
@@ -1,3 +0,0 @@
-class DEFAULT_PROVIDERS:
- API_BASED = '__api_based'
- APP_BASED = '__app_based'
\ No newline at end of file
diff --git a/api/core/tools/entities/tool_bundle.py b/api/core/tools/entities/tool_bundle.py
index efa10e792c..d18d27fb02 100644
--- a/api/core/tools/entities/tool_bundle.py
+++ b/api/core/tools/entities/tool_bundle.py
@@ -1,11 +1,11 @@
-from typing import Any, Optional
+from typing import Optional
from pydantic import BaseModel
-from core.tools.entities.tool_entities import ToolParameter, ToolProviderType
+from core.tools.entities.tool_entities import ToolParameter
-class ApiBasedToolBundle(BaseModel):
+class ApiToolBundle(BaseModel):
"""
This class is used to store the schema information of an API-based tool, such as the url, the method, the parameters, etc.
"""
@@ -25,12 +25,3 @@ class ApiBasedToolBundle(BaseModel):
icon: Optional[str] = None
# openapi operation
openapi: dict
-
-class AppToolBundle(BaseModel):
- """
- This class is used to store the schema information of a tool for an app.
- """
- type: ToolProviderType
- credential: Optional[dict[str, Any]] = None
- provider_id: str
- tool_name: str
\ No newline at end of file
diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py
index fad91baf83..55ef8e8291 100644
--- a/api/core/tools/entities/tool_entities.py
+++ b/api/core/tools/entities/tool_entities.py
@@ -6,14 +6,33 @@ from pydantic import BaseModel, Field
from core.tools.entities.common_entities import I18nObject
+class ToolLabelEnum(Enum):
+ SEARCH = 'search'
+ IMAGE = 'image'
+ VIDEOS = 'videos'
+ WEATHER = 'weather'
+ FINANCE = 'finance'
+ DESIGN = 'design'
+ TRAVEL = 'travel'
+ SOCIAL = 'social'
+ NEWS = 'news'
+ MEDICAL = 'medical'
+ PRODUCTIVITY = 'productivity'
+ EDUCATION = 'education'
+ BUSINESS = 'business'
+ ENTERTAINMENT = 'entertainment'
+ UTILITIES = 'utilities'
+ OTHER = 'other'
+
class ToolProviderType(Enum):
"""
Enum class for tool provider
"""
- BUILT_IN = "built-in"
+ BUILT_IN = "builtin"
+ WORKFLOW = "workflow"
+ API = "api"
+ APP = "app"
DATASET_RETRIEVAL = "dataset-retrieval"
- APP_BASED = "app-based"
- API_BASED = "api-based"
@classmethod
def value_of(cls, value: str) -> 'ToolProviderType':
@@ -77,6 +96,7 @@ class ToolInvokeMessage(BaseModel):
LINK = "link"
BLOB = "blob"
IMAGE_LINK = "image_link"
+ FILE_VAR = "file_var"
type: MessageType = MessageType.TEXT
"""
@@ -90,18 +110,21 @@ class ToolInvokeMessageBinary(BaseModel):
mimetype: str = Field(..., description="The mimetype of the binary")
url: str = Field(..., description="The url of the binary")
save_as: str = ''
+ file_var: Optional[dict[str, Any]] = None
class ToolParameterOption(BaseModel):
value: str = Field(..., description="The value of the option")
label: I18nObject = Field(..., description="The label of the option")
+
class ToolParameter(BaseModel):
- class ToolParameterType(Enum):
+ class ToolParameterType(str, Enum):
STRING = "string"
NUMBER = "number"
BOOLEAN = "boolean"
SELECT = "select"
SECRET_INPUT = "secret-input"
+ FILE = "file"
class ToolParameterForm(Enum):
SCHEMA = "schema" # should be set while adding tool
@@ -153,6 +176,7 @@ class ToolProviderIdentity(BaseModel):
description: I18nObject = Field(..., description="The description of the tool")
icon: str = Field(..., description="The icon of the tool")
label: I18nObject = Field(..., description="The label of the tool")
+ tags: Optional[list[ToolLabelEnum]] = Field(default=[], description="The tags of the tool")
class ToolDescription(BaseModel):
human: I18nObject = Field(..., description="The description presented to the user")
@@ -331,6 +355,15 @@ class ModelToolProviderConfiguration(BaseModel):
models: list[ModelToolConfiguration] = Field(..., description="The models of the model tool")
label: I18nObject = Field(..., description="The label of the model tool")
+
+class WorkflowToolParameterConfiguration(BaseModel):
+ """
+ Workflow tool configuration
+ """
+ name: str = Field(..., description="The name of the parameter")
+ description: str = Field(..., description="The description of the parameter")
+ form: ToolParameter.ToolParameterForm = Field(..., description="The form of the parameter")
+
class ToolInvokeMeta(BaseModel):
"""
Tool invoke meta
@@ -358,4 +391,19 @@ class ToolInvokeMeta(BaseModel):
'time_cost': self.time_cost,
'error': self.error,
'tool_config': self.tool_config,
- }
\ No newline at end of file
+ }
+
+class ToolLabel(BaseModel):
+ """
+ Tool label
+ """
+ name: str = Field(..., description="The name of the tool")
+ label: I18nObject = Field(..., description="The label of the tool")
+ icon: str = Field(..., description="The icon of the tool")
+
+class ToolInvokeFrom(Enum):
+ """
+ Enum class for tool invoke
+ """
+ WORKFLOW = "workflow"
+ AGENT = "agent"
\ No newline at end of file
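Making `ToolParameterType` inherit from `str` is what lets the `to_dict` temp fix earlier in this patch compare the encoded `parameter['type']` against `ToolParameter.ToolParameterType.FILE.value`, and it keeps JSON encoding clean:

```python
# a str-mixin Enum member compares equal to its raw value and JSON-encodes
# as a plain string; a plain Enum member does neither
import json
from enum import Enum

class ToolParameterType(str, Enum):
    STRING = "string"
    FILE = "file"

assert ToolParameterType.FILE == "file"
print(json.dumps({"type": ToolParameterType.FILE}))  # {"type": "file"}
```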
diff --git a/api/core/tools/entities/values.py b/api/core/tools/entities/values.py
new file mode 100644
index 0000000000..d0be5e9355
--- /dev/null
+++ b/api/core/tools/entities/values.py
@@ -0,0 +1,75 @@
+from core.tools.entities.common_entities import I18nObject
+from core.tools.entities.tool_entities import ToolLabel, ToolLabelEnum
+
+ICONS = {
+ ToolLabelEnum.SEARCH: '''''',
+ ToolLabelEnum.IMAGE: '''''',
+ ToolLabelEnum.VIDEOS: '''''',
+ ToolLabelEnum.WEATHER: '''''',
+ ToolLabelEnum.FINANCE: '''''',
+ ToolLabelEnum.DESIGN: '''''',
+ ToolLabelEnum.TRAVEL: '''''',
+ ToolLabelEnum.SOCIAL: '''''',
+ ToolLabelEnum.NEWS: '''''',
+ ToolLabelEnum.MEDICAL: '''''',
+ ToolLabelEnum.PRODUCTIVITY: '''''',
+ ToolLabelEnum.EDUCATION: '''''',
+ ToolLabelEnum.BUSINESS: '''''',
+ ToolLabelEnum.ENTERTAINMENT: '''''',
+ ToolLabelEnum.UTILITIES: '''''',
+ ToolLabelEnum.OTHER: ''''''
+}
+
+default_tool_label_dict = {
+ ToolLabelEnum.SEARCH: ToolLabel(name='search', label=I18nObject(en_US='Search', zh_Hans='搜索'), icon=ICONS[ToolLabelEnum.SEARCH]),
+ ToolLabelEnum.IMAGE: ToolLabel(name='image', label=I18nObject(en_US='Image', zh_Hans='图片'), icon=ICONS[ToolLabelEnum.IMAGE]),
+ ToolLabelEnum.VIDEOS: ToolLabel(name='videos', label=I18nObject(en_US='Videos', zh_Hans='视频'), icon=ICONS[ToolLabelEnum.VIDEOS]),
+ ToolLabelEnum.WEATHER: ToolLabel(name='weather', label=I18nObject(en_US='Weather', zh_Hans='天气'), icon=ICONS[ToolLabelEnum.WEATHER]),
+ ToolLabelEnum.FINANCE: ToolLabel(name='finance', label=I18nObject(en_US='Finance', zh_Hans='金融'), icon=ICONS[ToolLabelEnum.FINANCE]),
+ ToolLabelEnum.DESIGN: ToolLabel(name='design', label=I18nObject(en_US='Design', zh_Hans='设计'), icon=ICONS[ToolLabelEnum.DESIGN]),
+ ToolLabelEnum.TRAVEL: ToolLabel(name='travel', label=I18nObject(en_US='Travel', zh_Hans='旅行'), icon=ICONS[ToolLabelEnum.TRAVEL]),
+ ToolLabelEnum.SOCIAL: ToolLabel(name='social', label=I18nObject(en_US='Social', zh_Hans='社交'), icon=ICONS[ToolLabelEnum.SOCIAL]),
+ ToolLabelEnum.NEWS: ToolLabel(name='news', label=I18nObject(en_US='News', zh_Hans='新闻'), icon=ICONS[ToolLabelEnum.NEWS]),
+ ToolLabelEnum.MEDICAL: ToolLabel(name='medical', label=I18nObject(en_US='Medical', zh_Hans='医疗'), icon=ICONS[ToolLabelEnum.MEDICAL]),
+ ToolLabelEnum.PRODUCTIVITY: ToolLabel(name='productivity', label=I18nObject(en_US='Productivity', zh_Hans='生产力'), icon=ICONS[ToolLabelEnum.PRODUCTIVITY]),
+ ToolLabelEnum.EDUCATION: ToolLabel(name='education', label=I18nObject(en_US='Education', zh_Hans='教育'), icon=ICONS[ToolLabelEnum.EDUCATION]),
+ ToolLabelEnum.BUSINESS: ToolLabel(name='business', label=I18nObject(en_US='Business', zh_Hans='商业'), icon=ICONS[ToolLabelEnum.BUSINESS]),
+ ToolLabelEnum.ENTERTAINMENT: ToolLabel(name='entertainment', label=I18nObject(en_US='Entertainment', zh_Hans='娱乐'), icon=ICONS[ToolLabelEnum.ENTERTAINMENT]),
+ ToolLabelEnum.UTILITIES: ToolLabel(name='utilities', label=I18nObject(en_US='Utilities', zh_Hans='工具'), icon=ICONS[ToolLabelEnum.UTILITIES]),
+ ToolLabelEnum.OTHER: ToolLabel(name='other', label=I18nObject(en_US='Other', zh_Hans='其他'), icon=ICONS[ToolLabelEnum.OTHER]),
+}
+
+default_tool_labels = [v for k, v in default_tool_label_dict.items()]
+default_tool_label_name_list = [label.name for label in default_tool_labels]
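A hedged usage sketch for the new label registry; the imports match the module layout this patch introduces, but the exact lookup below is illustrative, not taken from the PR:

```python
# resolve a provider's tag names to ToolLabel objects via the registry
from core.tools.entities.tool_entities import ToolLabelEnum
from core.tools.entities.values import (
    default_tool_label_dict,
    default_tool_label_name_list,
)

tags = ['search', 'image']
labels = [
    default_tool_label_dict[ToolLabelEnum(tag)]
    for tag in tags
    if tag in default_tool_label_name_list
]
print([label.label.en_US for label in labels])  # ['Search', 'Image']
```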
diff --git a/api/core/tools/model/errors.py b/api/core/tools/model/errors.py
deleted file mode 100644
index 6e242b349a..0000000000
--- a/api/core/tools/model/errors.py
+++ /dev/null
@@ -1,2 +0,0 @@
-class InvokeModelError(Exception):
- pass
\ No newline at end of file
diff --git a/api/core/tools/model_tools/anthropic.yaml b/api/core/tools/model_tools/anthropic.yaml
deleted file mode 100644
index 4ccb973df5..0000000000
--- a/api/core/tools/model_tools/anthropic.yaml
+++ /dev/null
@@ -1,20 +0,0 @@
-provider: anthropic
-label:
- en_US: Anthropic Model Tools
- zh_Hans: Anthropic 模型能力
- pt_BR: Anthropic Model Tools
-models:
- - type: llm
- model: claude-3-sonnet-20240229
- label:
- zh_Hans: Claude3 Sonnet 视觉
- en_US: Claude3 Sonnet Vision
- properties:
- image_parameter_name: image_id
- - type: llm
- model: claude-3-opus-20240229
- label:
- zh_Hans: Claude3 Opus 视觉
- en_US: Claude3 Opus Vision
- properties:
- image_parameter_name: image_id
diff --git a/api/core/tools/model_tools/google.yaml b/api/core/tools/model_tools/google.yaml
deleted file mode 100644
index d81e1b0735..0000000000
--- a/api/core/tools/model_tools/google.yaml
+++ /dev/null
@@ -1,13 +0,0 @@
-provider: google
-label:
- en_US: Google Model Tools
- zh_Hans: Google 模型能力
- pt_BR: Google Model Tools
-models:
- - type: llm
- model: gemini-pro-vision
- label:
- zh_Hans: Gemini Pro 视觉
- en_US: Gemini Pro Vision
- properties:
- image_parameter_name: image_id
diff --git a/api/core/tools/model_tools/openai.yaml b/api/core/tools/model_tools/openai.yaml
deleted file mode 100644
index 45cbb295a9..0000000000
--- a/api/core/tools/model_tools/openai.yaml
+++ /dev/null
@@ -1,13 +0,0 @@
-provider: openai
-label:
- en_US: OpenAI Model Tools
- zh_Hans: OpenAI 模型能力
- pt_BR: OpenAI Model Tools
-models:
- - type: llm
- model: gpt-4-vision-preview
- label:
- zh_Hans: GPT-4 视觉
- en_US: GPT-4 Vision
- properties:
- image_parameter_name: image_id
diff --git a/api/core/tools/model_tools/zhipuai.yaml b/api/core/tools/model_tools/zhipuai.yaml
deleted file mode 100644
index 19a932eb89..0000000000
--- a/api/core/tools/model_tools/zhipuai.yaml
+++ /dev/null
@@ -1,13 +0,0 @@
-provider: zhipuai
-label:
- en_US: ZhipuAI Model Tools
- zh_Hans: ZhipuAI 模型能力
- pt_BR: ZhipuAI Model Tools
-models:
- - type: llm
- model: glm-4v
- label:
- zh_Hans: GLM-4 视觉
- en_US: GLM-4 Vision
- properties:
- image_parameter_name: image_id
diff --git a/api/core/tools/provider/_position.yaml b/api/core/tools/provider/_position.yaml
index 5c9454c11c..3b0f78cc76 100644
--- a/api/core/tools/provider/_position.yaml
+++ b/api/core/tools/provider/_position.yaml
@@ -1,21 +1,18 @@
- google
- bing
- duckduckgo
+- searchapi
- searxng
- dalle
- azuredalle
- stability
- wikipedia
-- model.openai
-- model.google
-- model.anthropic
- yahoo
- arxiv
- pubmed
- stablediffusion
- webscraper
- jina
-- model.zhipuai
- aippt
- youtube
- code
diff --git a/api/core/tools/provider/api_tool_provider.py b/api/core/tools/provider/api_tool_provider.py
index 11e6e892c9..ae80ad2114 100644
--- a/api/core/tools/provider/api_tool_provider.py
+++ b/api/core/tools/provider/api_tool_provider.py
@@ -1,7 +1,6 @@
-from typing import Any
from core.tools.entities.common_entities import I18nObject
-from core.tools.entities.tool_bundle import ApiBasedToolBundle
+from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import (
ApiProviderAuthType,
ToolCredentialsOption,
@@ -15,11 +14,11 @@ from extensions.ext_database import db
from models.tools import ApiToolProvider
-class ApiBasedToolProviderController(ToolProviderController):
+class ApiToolProviderController(ToolProviderController):
provider_id: str
@staticmethod
- def from_db(db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> 'ApiBasedToolProviderController':
+ def from_db(db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> 'ApiToolProviderController':
credentials_schema = {
'auth_type': ToolProviderCredentials(
name='auth_type',
@@ -79,9 +78,11 @@ class ApiBasedToolProviderController(ToolProviderController):
else:
raise ValueError(f'invalid auth type {auth_type}')
- return ApiBasedToolProviderController(**{
+ user_name = db_provider.user.name if db_provider.user_id and db_provider.user else ''
+
+ return ApiToolProviderController(**{
'identity': {
- 'author': db_provider.user.name if db_provider.user_id and db_provider.user else '',
+ 'author': user_name,
'name': db_provider.name,
'label': {
'en_US': db_provider.name,
@@ -98,16 +99,10 @@ class ApiBasedToolProviderController(ToolProviderController):
})
@property
- def app_type(self) -> ToolProviderType:
- return ToolProviderType.API_BASED
-
- def _validate_credentials(self, tool_name: str, credentials: dict[str, Any]) -> None:
- pass
+ def provider_type(self) -> ToolProviderType:
+ return ToolProviderType.API
- def validate_parameters(self, tool_name: str, tool_parameters: dict[str, Any]) -> None:
- pass
-
- def _parse_tool_bundle(self, tool_bundle: ApiBasedToolBundle) -> ApiTool:
+ def _parse_tool_bundle(self, tool_bundle: ApiToolBundle) -> ApiTool:
"""
parse tool bundle to tool
@@ -136,7 +131,7 @@ class ApiBasedToolProviderController(ToolProviderController):
'parameters' : tool_bundle.parameters if tool_bundle.parameters else [],
})
- def load_bundled_tools(self, tools: list[ApiBasedToolBundle]) -> list[ApiTool]:
+ def load_bundled_tools(self, tools: list[ApiToolBundle]) -> list[ApiTool]:
"""
load bundled tools
diff --git a/api/core/tools/provider/app_tool_provider.py b/api/core/tools/provider/app_tool_provider.py
index 159c94bbf3..2d472e0a93 100644
--- a/api/core/tools/provider/app_tool_provider.py
+++ b/api/core/tools/provider/app_tool_provider.py
@@ -11,10 +11,10 @@ from models.tools import PublishedAppTool
logger = logging.getLogger(__name__)
-class AppBasedToolProviderEntity(ToolProviderController):
+class AppToolProviderEntity(ToolProviderController):
@property
- def app_type(self) -> ToolProviderType:
- return ToolProviderType.APP_BASED
+ def provider_type(self) -> ToolProviderType:
+ return ToolProviderType.APP
def _validate_credentials(self, tool_name: str, credentials: dict[str, Any]) -> None:
pass
diff --git a/api/core/tools/provider/builtin/_positions.py b/api/core/tools/provider/builtin/_positions.py
index 2bf70bd356..ae806eaff4 100644
--- a/api/core/tools/provider/builtin/_positions.py
+++ b/api/core/tools/provider/builtin/_positions.py
@@ -1,7 +1,7 @@
import os.path
-from core.tools.entities.user_entities import UserToolProvider
-from core.utils.position_helper import get_position_map, sort_by_position_map
+from core.helper.position_helper import get_position_map, sort_by_position_map
+from core.tools.entities.api_entities import UserToolProvider
class BuiltinToolProviderSort:
@@ -13,10 +13,7 @@ class BuiltinToolProviderSort:
cls._position = get_position_map(os.path.join(os.path.dirname(__file__), '..'))
def name_func(provider: UserToolProvider) -> str:
- if provider.type == UserToolProvider.ProviderType.MODEL:
- return f'model.{provider.name}'
- else:
- return provider.name
+ return provider.name
sorted_providers = sort_by_position_map(cls._position, providers, name_func)
diff --git a/api/core/tools/provider/builtin/aippt/aippt.yaml b/api/core/tools/provider/builtin/aippt/aippt.yaml
index b3ff1f6d98..9b1b45d0f2 100644
--- a/api/core/tools/provider/builtin/aippt/aippt.yaml
+++ b/api/core/tools/provider/builtin/aippt/aippt.yaml
@@ -8,6 +8,9 @@ identity:
en_US: AI-generated PPT with one click, input your content topic, and let AI serve you one-stop
zh_Hans: AI一键生成PPT,输入你的内容主题,让AI为你一站式服务到底
icon: icon.png
+ tags:
+ - productivity
+ - design
credentials_for_provider:
aippt_access_key:
type: secret-input
diff --git a/api/core/tools/provider/builtin/arxiv/arxiv.py b/api/core/tools/provider/builtin/arxiv/arxiv.py
index 998128522e..707fc69be3 100644
--- a/api/core/tools/provider/builtin/arxiv/arxiv.py
+++ b/api/core/tools/provider/builtin/arxiv/arxiv.py
@@ -7,7 +7,7 @@ class ArxivProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict) -> None:
try:
ArxivSearchTool().fork_tool_runtime(
- meta={
+ runtime={
"credentials": credentials,
}
).invoke(
@@ -17,4 +17,5 @@ class ArxivProvider(BuiltinToolProviderController):
},
)
except Exception as e:
- raise ToolProviderCredentialValidationError(str(e))
\ No newline at end of file
+ raise ToolProviderCredentialValidationError(str(e))
+
\ No newline at end of file
diff --git a/api/core/tools/provider/builtin/arxiv/arxiv.yaml b/api/core/tools/provider/builtin/arxiv/arxiv.yaml
index 78c5c161af..d26993b336 100644
--- a/api/core/tools/provider/builtin/arxiv/arxiv.yaml
+++ b/api/core/tools/provider/builtin/arxiv/arxiv.yaml
@@ -8,3 +8,5 @@ identity:
en_US: Access to a vast repository of scientific papers and articles in various fields of research.
zh_Hans: 访问各个研究领域大量科学论文和文章的存储库。
icon: icon.svg
+ tags:
+ - search
diff --git a/api/core/tools/provider/builtin/arxiv/tools/arxiv_search.py b/api/core/tools/provider/builtin/arxiv/tools/arxiv_search.py
index fb64e07a8c..448d2e8f84 100644
--- a/api/core/tools/provider/builtin/arxiv/tools/arxiv_search.py
+++ b/api/core/tools/provider/builtin/arxiv/tools/arxiv_search.py
@@ -68,7 +68,7 @@ class ArxivAPIWrapper(BaseModel):
Args:
query: a plaintext search query
- """ # noqa: E501
+ """
try:
results = self.arxiv_search( # type: ignore
query[: self.ARXIV_MAX_QUERY_LENGTH], max_results=self.top_k_results
diff --git a/api/core/tools/provider/builtin/azuredalle/azuredalle.py b/api/core/tools/provider/builtin/azuredalle/azuredalle.py
index 4278da54ba..2981a54d3c 100644
--- a/api/core/tools/provider/builtin/azuredalle/azuredalle.py
+++ b/api/core/tools/provider/builtin/azuredalle/azuredalle.py
@@ -9,7 +9,7 @@ class AzureDALLEProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
try:
DallE3Tool().fork_tool_runtime(
- meta={
+ runtime={
"credentials": credentials,
}
).invoke(
diff --git a/api/core/tools/provider/builtin/azuredalle/azuredalle.yaml b/api/core/tools/provider/builtin/azuredalle/azuredalle.yaml
index 62ce2c21fe..4353e0c486 100644
--- a/api/core/tools/provider/builtin/azuredalle/azuredalle.yaml
+++ b/api/core/tools/provider/builtin/azuredalle/azuredalle.yaml
@@ -10,6 +10,9 @@ identity:
zh_Hans: Azure DALL-E 绘画
pt_BR: Azure DALL-E art
icon: icon.png
+ tags:
+ - image
+ - productivity
credentials_for_provider:
azure_openai_api_key:
type: secret-input
diff --git a/api/core/tools/provider/builtin/bing/bing.py b/api/core/tools/provider/builtin/bing/bing.py
index 6e62abfc10..c71128be4a 100644
--- a/api/core/tools/provider/builtin/bing/bing.py
+++ b/api/core/tools/provider/builtin/bing/bing.py
@@ -9,7 +9,7 @@ class BingProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
try:
BingSearchTool().fork_tool_runtime(
- meta={
+ runtime={
"credentials": credentials,
}
).validate_credentials(
diff --git a/api/core/tools/provider/builtin/bing/bing.yaml b/api/core/tools/provider/builtin/bing/bing.yaml
index 35cd729208..1ab17d5294 100644
--- a/api/core/tools/provider/builtin/bing/bing.yaml
+++ b/api/core/tools/provider/builtin/bing/bing.yaml
@@ -10,6 +10,8 @@ identity:
zh_Hans: Bing 搜索
pt_BR: Bing Search
icon: icon.svg
+ tags:
+ - search
credentials_for_provider:
subscription_key:
type: secret-input
diff --git a/api/core/tools/provider/builtin/brave/brave.py b/api/core/tools/provider/builtin/brave/brave.py
index e26b28b46a..e5eada80ee 100644
--- a/api/core/tools/provider/builtin/brave/brave.py
+++ b/api/core/tools/provider/builtin/brave/brave.py
@@ -9,7 +9,7 @@ class BraveProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
try:
BraveSearchTool().fork_tool_runtime(
- meta={
+ runtime={
"credentials": credentials,
}
).invoke(
@@ -19,4 +19,5 @@ class BraveProvider(BuiltinToolProviderController):
},
)
except Exception as e:
- raise ToolProviderCredentialValidationError(str(e))
\ No newline at end of file
+ raise ToolProviderCredentialValidationError(str(e))
+
\ No newline at end of file
diff --git a/api/core/tools/provider/builtin/brave/brave.yaml b/api/core/tools/provider/builtin/brave/brave.yaml
index d1b7ff1086..93d315f839 100644
--- a/api/core/tools/provider/builtin/brave/brave.yaml
+++ b/api/core/tools/provider/builtin/brave/brave.yaml
@@ -10,6 +10,8 @@ identity:
zh_Hans: Brave
pt_BR: Brave
icon: icon.svg
+ tags:
+ - search
credentials_for_provider:
brave_search_api_key:
type: secret-input
diff --git a/api/core/tools/provider/builtin/chart/chart.py b/api/core/tools/provider/builtin/chart/chart.py
index f5e42e766d..0865bc700a 100644
--- a/api/core/tools/provider/builtin/chart/chart.py
+++ b/api/core/tools/provider/builtin/chart/chart.py
@@ -44,7 +44,7 @@ class ChartProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict) -> None:
try:
LinearChartTool().fork_tool_runtime(
- meta={
+ runtime={
"credentials": credentials,
}
).invoke(
@@ -54,4 +54,5 @@ class ChartProvider(BuiltinToolProviderController):
},
)
except Exception as e:
- raise ToolProviderCredentialValidationError(str(e))
\ No newline at end of file
+ raise ToolProviderCredentialValidationError(str(e))
+
\ No newline at end of file
diff --git a/api/core/tools/provider/builtin/chart/chart.yaml b/api/core/tools/provider/builtin/chart/chart.yaml
index 7aac32b3bb..ad0d9a6cd6 100644
--- a/api/core/tools/provider/builtin/chart/chart.yaml
+++ b/api/core/tools/provider/builtin/chart/chart.yaml
@@ -10,4 +10,8 @@ identity:
zh_Hans: 图表生成是一个用于生成可视化图表的工具,你可以通过它来生成柱状图、折线图、饼图等各类图表
pt_BR: O Gerador de gráficos é uma ferramenta para gerar gráficos estatísticos como gráfico de barras, gráfico de linhas, gráfico de pizza, etc.
icon: icon.png
+ tags:
+ - design
+ - productivity
+ - utilities
credentials_for_provider:
diff --git a/api/core/tools/provider/builtin/chart/tools/bar.yaml b/api/core/tools/provider/builtin/chart/tools/bar.yaml
index bc34f2a5ec..ee7405f681 100644
--- a/api/core/tools/provider/builtin/chart/tools/bar.yaml
+++ b/api/core/tools/provider/builtin/chart/tools/bar.yaml
@@ -21,9 +21,9 @@ parameters:
zh_Hans: 数据
pt_BR: dados
human_description:
- en_US: data for generating bar chart
- zh_Hans: 用于生成柱状图的数据
- pt_BR: dados para gerar gráfico de barras
+ en_US: data for generating chart, each number should be separated by ";"
+ zh_Hans: 用于生成柱状图的数据,每个数字之间用 ";" 分隔
+ pt_BR: dados para gerar gráfico de barras, cada número deve ser separado por ";"
llm_description: data for generating bar chart, data should be a string contains a list of numbers like "1;2;3;4;5"
form: llm
- name: x_axis
@@ -34,8 +34,8 @@ parameters:
zh_Hans: x 轴
pt_BR: Eixo X
human_description:
- en_US: X axis for bar chart
- zh_Hans: 柱状图的 x 轴
- pt_BR: Eixo X para gráfico de barras
+ en_US: X axis for chart, each text should be separated by ";"
+ zh_Hans: 柱状图的 x 轴,每个文本之间用 ";" 分隔
+ pt_BR: Eixo X para gráfico de barras, cada texto deve ser separado por ";"
llm_description: x axis for bar chart, x axis should be a string contains a list of texts like "a;b;c;1;2" in order to match the data
form: llm
diff --git a/api/core/tools/provider/builtin/chart/tools/line.yaml b/api/core/tools/provider/builtin/chart/tools/line.yaml
index 9994cbb80b..35ebe3b68b 100644
--- a/api/core/tools/provider/builtin/chart/tools/line.yaml
+++ b/api/core/tools/provider/builtin/chart/tools/line.yaml
@@ -21,9 +21,9 @@ parameters:
zh_Hans: 数据
pt_BR: dados
human_description:
- en_US: data for generating linear chart
- zh_Hans: 用于生成线性图表的数据
- pt_BR: dados para gerar gráfico linear
+ en_US: data for generating chart, each number should be separated by ";"
+ zh_Hans: 用于生成线性图表的数据,每个数字之间用 ";" 分隔
+ pt_BR: dados para gerar gráfico linear, cada número deve ser separado por ";"
llm_description: data for generating linear chart, data should be a string contains a list of numbers like "1;2;3;4;5"
form: llm
- name: x_axis
@@ -34,8 +34,8 @@ parameters:
zh_Hans: x 轴
pt_BR: Eixo X
human_description:
- en_US: X axis for linear chart
- zh_Hans: 线性图表的 x 轴
- pt_BR: Eixo X para gráfico linear
+ en_US: X axis for chart, each text should be separated by ";"
+ zh_Hans: 线性图表的 x 轴,每个文本之间用 ";" 分隔
+ pt_BR: Eixo X para gráfico linear, cada texto deve ser separado por ";"
llm_description: x axis for linear chart, x axis should be a string contains a list of texts like "a;b;c;1;2" in order to match the data
form: llm
diff --git a/api/core/tools/provider/builtin/chart/tools/pie.yaml b/api/core/tools/provider/builtin/chart/tools/pie.yaml
index 7d6647ef4a..541715cb7d 100644
--- a/api/core/tools/provider/builtin/chart/tools/pie.yaml
+++ b/api/core/tools/provider/builtin/chart/tools/pie.yaml
@@ -21,9 +21,9 @@ parameters:
zh_Hans: 数据
pt_BR: dados
human_description:
- en_US: data for generating pie chart
- zh_Hans: 用于生成饼图的数据
- pt_BR: dados para gerar gráfico de pizza
+ en_US: data for generating chart, each number should be separated by ";"
+ zh_Hans: 用于生成饼图的数据,每个数字之间用 ";" 分隔
+ pt_BR: dados para gerar gráfico de pizza, cada número deve ser separado por ";"
llm_description: data for generating pie chart, data should be a string contains a list of numbers like "1;2;3;4;5"
form: llm
- name: categories
@@ -34,8 +34,8 @@ parameters:
zh_Hans: 分类
pt_BR: Categorias
human_description:
- en_US: Categories for pie chart
- zh_Hans: 饼图的分类
- pt_BR: Categorias para gráfico de pizza
+ en_US: Categories for chart, each category should be separated by ";"
+ zh_Hans: 饼图的分类,每个分类之间用 ";" 分隔
+ pt_BR: Categorias para gráfico de pizza, cada categoria deve ser separada por ";"
llm_description: categories for pie chart, categories should be a string contains a list of texts like "a;b;c;1;2" in order to match the data, each category should be split by ";"
form: llm
diff --git a/api/core/tools/provider/builtin/code/code.py b/api/core/tools/provider/builtin/code/code.py
index fae5ecf769..211417c9a4 100644
--- a/api/core/tools/provider/builtin/code/code.py
+++ b/api/core/tools/provider/builtin/code/code.py
@@ -5,4 +5,4 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl
class CodeToolProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
- pass
\ No newline at end of file
+ pass
diff --git a/api/core/tools/provider/builtin/code/code.yaml b/api/core/tools/provider/builtin/code/code.yaml
index b0fd0dd587..2640a7087e 100644
--- a/api/core/tools/provider/builtin/code/code.yaml
+++ b/api/core/tools/provider/builtin/code/code.yaml
@@ -10,4 +10,6 @@ identity:
zh_Hans: 运行一段代码并返回结果。
pt_BR: Execute um trecho de código e obtenha o resultado de volta.
icon: icon.svg
+ tags:
+ - productivity
credentials_for_provider:
diff --git a/api/core/tools/provider/builtin/code/tools/simple_code.py b/api/core/tools/provider/builtin/code/tools/simple_code.py
index ae9b1cb612..37645bf0d0 100644
--- a/api/core/tools/provider/builtin/code/tools/simple_code.py
+++ b/api/core/tools/provider/builtin/code/tools/simple_code.py
@@ -1,6 +1,6 @@
from typing import Any
-from core.helper.code_executor.code_executor import CodeExecutor
+from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
@@ -11,10 +11,10 @@ class SimpleCode(BuiltinTool):
invoke simple code
"""
- language = tool_parameters.get('language', 'python3')
+ language = tool_parameters.get('language', CodeLanguage.PYTHON3)
code = tool_parameters.get('code', '')
- if language not in ['python3', 'javascript']:
+ if language not in [CodeLanguage.PYTHON3, CodeLanguage.JAVASCRIPT]:
raise ValueError(f'Only python3 and javascript are supported, not {language}')
result = CodeExecutor.execute_code(language, '', code)
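`simple_code.py` now compares against `CodeLanguage` members instead of bare string literals. The enum's definition is not part of this hunk; a plausible shape, assuming it is string-valued so the comparisons and the `execute_code` call keep working unchanged:

```python
from enum import Enum

# Assumed shape of CodeLanguage; the real definition lives in
# core.helper.code_executor.code_executor and is not shown in this diff.
class CodeLanguage(str, Enum):
    PYTHON3 = 'python3'
    JAVASCRIPT = 'javascript'

# String-valued members still compare equal to the old literals,
# so callers passing plain 'python3' are unaffected.
assert CodeLanguage.PYTHON3 == 'python3'
```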
diff --git a/api/core/tools/provider/builtin/dalle/dalle.py b/api/core/tools/provider/builtin/dalle/dalle.py
index 34a24a7425..1c8019364d 100644
--- a/api/core/tools/provider/builtin/dalle/dalle.py
+++ b/api/core/tools/provider/builtin/dalle/dalle.py
@@ -9,7 +9,7 @@ class DALLEProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
try:
DallE2Tool().fork_tool_runtime(
- meta={
+ runtime={
"credentials": credentials,
}
).invoke(
@@ -21,4 +21,5 @@ class DALLEProvider(BuiltinToolProviderController):
},
)
except Exception as e:
- raise ToolProviderCredentialValidationError(str(e))
\ No newline at end of file
+ raise ToolProviderCredentialValidationError(str(e))
+
\ No newline at end of file
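The `meta=` to `runtime=` keyword rename in `fork_tool_runtime` recurs in every provider below (devdocs, duckduckgo, firecrawl, google, judge0ce, pubmed, searchapi). The shared validation pattern, reduced to a skeleton in which `SomeTool` and the smoke-test parameters stand in for each provider's concrete choices:

```python
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController

class SomeProvider(BuiltinToolProviderController):
    def _validate_credentials(self, credentials: dict) -> None:
        try:
            # Fork the tool with the candidate credentials (keyword is now
            # `runtime`, previously `meta`) and fire one cheap invocation.
            # SomeTool stands in for DallE2Tool, GoogleSearchTool, etc.
            SomeTool().fork_tool_runtime(
                runtime={"credentials": credentials},
            ).invoke(
                user_id='',
                tool_parameters={"query": "test"},  # tool-specific smoke input
            )
        except Exception as e:
            # Any failure is surfaced as a credential problem.
            raise ToolProviderCredentialValidationError(str(e))
```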
diff --git a/api/core/tools/provider/builtin/dalle/dalle.yaml b/api/core/tools/provider/builtin/dalle/dalle.yaml
index a99401d82d..f09a9177f2 100644
--- a/api/core/tools/provider/builtin/dalle/dalle.yaml
+++ b/api/core/tools/provider/builtin/dalle/dalle.yaml
@@ -10,6 +10,9 @@ identity:
zh_Hans: DALL-E 绘画
pt_BR: DALL-E art
icon: icon.png
+ tags:
+ - image
+ - productivity
credentials_for_provider:
openai_api_key:
type: secret-input
diff --git a/api/core/tools/provider/builtin/dalle/tools/dalle2.py b/api/core/tools/provider/builtin/dalle/tools/dalle2.py
index e41cbd9f65..450e782281 100644
--- a/api/core/tools/provider/builtin/dalle/tools/dalle2.py
+++ b/api/core/tools/provider/builtin/dalle/tools/dalle2.py
@@ -1,8 +1,8 @@
from base64 import b64decode
-from os.path import join
from typing import Any, Union
from openai import OpenAI
+from yarl import URL
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
@@ -23,7 +23,7 @@ class DallE2Tool(BuiltinTool):
if not openai_base_url:
openai_base_url = None
else:
- openai_base_url = join(openai_base_url, 'v1')
+ openai_base_url = str(URL(openai_base_url) / 'v1')
client = OpenAI(
api_key=self.runtime.credentials['openai_api_key'],
diff --git a/api/core/tools/provider/builtin/dalle/tools/dalle3.py b/api/core/tools/provider/builtin/dalle/tools/dalle3.py
index dc53025b02..87d18f68e0 100644
--- a/api/core/tools/provider/builtin/dalle/tools/dalle3.py
+++ b/api/core/tools/provider/builtin/dalle/tools/dalle3.py
@@ -1,8 +1,8 @@
from base64 import b64decode
-from os.path import join
from typing import Any, Union
from openai import OpenAI
+from yarl import URL
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
@@ -23,7 +23,7 @@ class DallE3Tool(BuiltinTool):
if not openai_base_url:
openai_base_url = None
else:
- openai_base_url = join(openai_base_url, 'v1')
+ openai_base_url = str(URL(openai_base_url) / 'v1')
client = OpenAI(
api_key=self.runtime.credentials['openai_api_key'],
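Swapping `os.path.join` for `yarl.URL` in both DALL-E tools is a correctness fix, not cosmetics: `os.path.join` uses the OS path separator (backslash on Windows) and discards the base entirely when a segment starts with "/". A quick comparison:

```python
import posixpath
from yarl import URL

base = 'https://example.com/openai'  # illustrative base URL

# yarl's division operator appends a path segment on every platform:
assert str(URL(base) / 'v1') == 'https://example.com/openai/v1'

# os.path.join would yield 'https://example.com/openai\\v1' on Windows;
# even the POSIX flavour drops the base when a segment is absolute:
assert posixpath.join(base, '/v1') == '/v1'
```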
diff --git a/api/core/tools/provider/builtin/devdocs/devdocs.py b/api/core/tools/provider/builtin/devdocs/devdocs.py
index 25cbe4d053..95d7939d0d 100644
--- a/api/core/tools/provider/builtin/devdocs/devdocs.py
+++ b/api/core/tools/provider/builtin/devdocs/devdocs.py
@@ -7,7 +7,7 @@ class DevDocsProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict) -> None:
try:
SearchDevDocsTool().fork_tool_runtime(
- meta={
+ runtime={
"credentials": credentials,
}
).invoke(
@@ -18,4 +18,5 @@ class DevDocsProvider(BuiltinToolProviderController):
},
)
except Exception as e:
- raise ToolProviderCredentialValidationError(str(e))
\ No newline at end of file
+ raise ToolProviderCredentialValidationError(str(e))
+
\ No newline at end of file
diff --git a/api/core/tools/provider/builtin/devdocs/devdocs.yaml b/api/core/tools/provider/builtin/devdocs/devdocs.yaml
index 1db226fc4b..7552f5a497 100644
--- a/api/core/tools/provider/builtin/devdocs/devdocs.yaml
+++ b/api/core/tools/provider/builtin/devdocs/devdocs.yaml
@@ -8,3 +8,6 @@ identity:
en_US: Get official developer documentations on DevDocs.
zh_Hans: 从DevDocs获取官方开发者文档。
icon: icon.svg
+ tags:
+ - search
+ - productivity
diff --git a/api/core/tools/provider/builtin/dingtalk/dingtalk.yaml b/api/core/tools/provider/builtin/dingtalk/dingtalk.yaml
index ebe2e4fbaf..c922c140a8 100644
--- a/api/core/tools/provider/builtin/dingtalk/dingtalk.yaml
+++ b/api/core/tools/provider/builtin/dingtalk/dingtalk.yaml
@@ -10,4 +10,7 @@ identity:
zh_Hans: 钉钉群机器人
pt_BR: DingTalk group robot
icon: icon.svg
+ tags:
+ - social
+ - productivity
credentials_for_provider:
diff --git a/api/core/tools/provider/builtin/duckduckgo/duckduckgo.py b/api/core/tools/provider/builtin/duckduckgo/duckduckgo.py
index 3e9b57ece7..6df8678d30 100644
--- a/api/core/tools/provider/builtin/duckduckgo/duckduckgo.py
+++ b/api/core/tools/provider/builtin/duckduckgo/duckduckgo.py
@@ -7,7 +7,7 @@ class DuckDuckGoProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict) -> None:
try:
DuckDuckGoSearchTool().fork_tool_runtime(
- meta={
+ runtime={
"credentials": credentials,
}
).invoke(
@@ -17,4 +17,5 @@ class DuckDuckGoProvider(BuiltinToolProviderController):
},
)
except Exception as e:
- raise ToolProviderCredentialValidationError(str(e))
\ No newline at end of file
+ raise ToolProviderCredentialValidationError(str(e))
+
\ No newline at end of file
diff --git a/api/core/tools/provider/builtin/duckduckgo/duckduckgo.yaml b/api/core/tools/provider/builtin/duckduckgo/duckduckgo.yaml
index 8778dde625..f3faa06045 100644
--- a/api/core/tools/provider/builtin/duckduckgo/duckduckgo.yaml
+++ b/api/core/tools/provider/builtin/duckduckgo/duckduckgo.yaml
@@ -8,3 +8,5 @@ identity:
en_US: A privacy-focused search engine.
zh_Hans: 一个注重隐私的搜索引擎。
icon: icon.svg
+ tags:
+ - search
diff --git a/api/core/tools/provider/builtin/feishu/feishu.py b/api/core/tools/provider/builtin/feishu/feishu.py
index 13303dbe64..72a9333619 100644
--- a/api/core/tools/provider/builtin/feishu/feishu.py
+++ b/api/core/tools/provider/builtin/feishu/feishu.py
@@ -5,4 +5,3 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl
class FeishuProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict) -> None:
FeishuGroupBotTool()
- pass
diff --git a/api/core/tools/provider/builtin/feishu/feishu.yaml b/api/core/tools/provider/builtin/feishu/feishu.yaml
index a1fcd38047..a029c7edb8 100644
--- a/api/core/tools/provider/builtin/feishu/feishu.yaml
+++ b/api/core/tools/provider/builtin/feishu/feishu.yaml
@@ -10,4 +10,7 @@ identity:
zh_Hans: 飞书群机器人
pt_BR: Feishu group bot
icon: icon.svg
+ tags:
+ - social
+ - productivity
credentials_for_provider:
diff --git a/api/core/tools/provider/builtin/firecrawl/firecrawl.py b/api/core/tools/provider/builtin/firecrawl/firecrawl.py
index 20ab978b8d..adcb7ebdd6 100644
--- a/api/core/tools/provider/builtin/firecrawl/firecrawl.py
+++ b/api/core/tools/provider/builtin/firecrawl/firecrawl.py
@@ -8,7 +8,7 @@ class FirecrawlProvider(BuiltinToolProviderController):
try:
# Example validation using the Crawl tool
CrawlTool().fork_tool_runtime(
- meta={"credentials": credentials}
+ runtime={"credentials": credentials}
).invoke(
user_id='',
tool_parameters={
@@ -20,4 +20,5 @@ class FirecrawlProvider(BuiltinToolProviderController):
}
)
except Exception as e:
- raise ToolProviderCredentialValidationError(str(e))
\ No newline at end of file
+ raise ToolProviderCredentialValidationError(str(e))
+
\ No newline at end of file
diff --git a/api/core/tools/provider/builtin/firecrawl/firecrawl.yaml b/api/core/tools/provider/builtin/firecrawl/firecrawl.yaml
index 67fad19159..311283dcb5 100644
--- a/api/core/tools/provider/builtin/firecrawl/firecrawl.yaml
+++ b/api/core/tools/provider/builtin/firecrawl/firecrawl.yaml
@@ -8,6 +8,9 @@ identity:
en_US: Firecrawl API integration for web crawling and scraping.
zh_CN: Firecrawl API 集成,用于网页爬取和数据抓取。
icon: icon.svg
+ tags:
+ - search
+ - utilities
credentials_for_provider:
firecrawl_api_key:
type: secret-input
diff --git a/api/core/tools/provider/builtin/gaode/gaode.yaml b/api/core/tools/provider/builtin/gaode/gaode.yaml
index bca53b22e9..2eb3b161a2 100644
--- a/api/core/tools/provider/builtin/gaode/gaode.yaml
+++ b/api/core/tools/provider/builtin/gaode/gaode.yaml
@@ -10,6 +10,11 @@ identity:
zh_Hans: 高德开放平台服务工具包。
pt_BR: Kit de ferramentas de serviço Autonavi Open Platform.
icon: icon.svg
+ tags:
+ - utilities
+ - productivity
+ - travel
+ - weather
credentials_for_provider:
api_key:
type: secret-input
diff --git a/api/core/tools/provider/builtin/github/github.yaml b/api/core/tools/provider/builtin/github/github.yaml
index d529e639cc..c3d85fc3f6 100644
--- a/api/core/tools/provider/builtin/github/github.yaml
+++ b/api/core/tools/provider/builtin/github/github.yaml
@@ -10,6 +10,8 @@ identity:
zh_Hans: GitHub是一个在线软件源代码托管服务平台。
pt_BR: GitHub é uma plataforma online para serviços de hospedagem de código fonte de software.
icon: icon.svg
+ tags:
+ - utilities
credentials_for_provider:
access_tokens:
type: secret-input
diff --git a/api/core/tools/provider/builtin/google/google.py b/api/core/tools/provider/builtin/google/google.py
index 3900804b45..8f4b9a4a4e 100644
--- a/api/core/tools/provider/builtin/google/google.py
+++ b/api/core/tools/provider/builtin/google/google.py
@@ -9,7 +9,7 @@ class GoogleProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
try:
GoogleSearchTool().fork_tool_runtime(
- meta={
+ runtime={
"credentials": credentials,
}
).invoke(
@@ -20,4 +20,5 @@ class GoogleProvider(BuiltinToolProviderController):
},
)
except Exception as e:
- raise ToolProviderCredentialValidationError(str(e))
\ No newline at end of file
+ raise ToolProviderCredentialValidationError(str(e))
+
\ No newline at end of file
diff --git a/api/core/tools/provider/builtin/google/google.yaml b/api/core/tools/provider/builtin/google/google.yaml
index 43b75b51cd..afb4d5b214 100644
--- a/api/core/tools/provider/builtin/google/google.yaml
+++ b/api/core/tools/provider/builtin/google/google.yaml
@@ -10,6 +10,8 @@ identity:
zh_Hans: GoogleSearch
pt_BR: Google
icon: icon.svg
+ tags:
+ - search
credentials_for_provider:
serpapi_api_key:
type: secret-input
diff --git a/api/core/tools/provider/builtin/google/tools/google_search.py b/api/core/tools/provider/builtin/google/tools/google_search.py
index 0b1978ad3e..87c2cc5796 100644
--- a/api/core/tools/provider/builtin/google/tools/google_search.py
+++ b/api/core/tools/provider/builtin/google/tools/google_search.py
@@ -1,39 +1,20 @@
-import os
-import sys
from typing import Any, Union
-from serpapi import GoogleSearch
+import requests
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
-
-class HiddenPrints:
- """Context manager to hide prints."""
-
- def __enter__(self) -> None:
- """Open file to pipe stdout to."""
- self._original_stdout = sys.stdout
- sys.stdout = open(os.devnull, "w")
-
- def __exit__(self, *_: Any) -> None:
- """Close file that stdout was piped to."""
- sys.stdout.close()
- sys.stdout = self._original_stdout
+SERP_API_URL = "https://serpapi.com/search"
class SerpAPI:
"""
SerpAPI tool provider.
"""
-
- search_engine: Any #: :meta private:
- serpapi_api_key: str = None
-
def __init__(self, api_key: str) -> None:
"""Initialize SerpAPI tool provider."""
self.serpapi_api_key = api_key
- self.search_engine = GoogleSearch
def run(self, query: str, **kwargs: Any) -> str:
"""Run query through SerpAPI and parse result."""
@@ -43,114 +24,76 @@ class SerpAPI:
def results(self, query: str) -> dict:
"""Run query through SerpAPI and return the raw result."""
params = self.get_params(query)
- with HiddenPrints():
- search = self.search_engine(params)
- res = search.get_dict()
- return res
+ response = requests.get(url=SERP_API_URL, params=params)
+ response.raise_for_status()
+ return response.json()
def get_params(self, query: str) -> dict[str, str]:
"""Get parameters for SerpAPI."""
- _params = {
+ params = {
"api_key": self.serpapi_api_key,
"q": query,
- }
- params = {
"engine": "google",
"google_domain": "google.com",
"gl": "us",
- "hl": "en",
- **_params
+ "hl": "en"
}
return params
@staticmethod
def _process_response(res: dict, typ: str) -> str:
- """Process response from SerpAPI."""
- if "error" in res.keys():
+ """
+ Process response from SerpAPI.
+ SerpAPI doc: https://serpapi.com/search-api
+ Google search main results are called organic results
+ """
+ if "error" in res:
raise ValueError(f"Got error from SerpAPI: {res['error']}")
-
+ toret = ""
if typ == "text":
- toret = ""
- if "answer_box" in res.keys() and type(res["answer_box"]) == list:
- res["answer_box"] = res["answer_box"][0] + "\n"
- if "answer_box" in res.keys() and "answer" in res["answer_box"].keys():
- toret += res["answer_box"]["answer"] + "\n"
- if "answer_box" in res.keys() and "snippet" in res["answer_box"].keys():
- toret += res["answer_box"]["snippet"] + "\n"
- if (
- "answer_box" in res.keys()
- and "snippet_highlighted_words" in res["answer_box"].keys()
- ):
- for item in res["answer_box"]["snippet_highlighted_words"]:
- toret += item + "\n"
- if (
- "sports_results" in res.keys()
- and "game_spotlight" in res["sports_results"].keys()
- ):
- toret += res["sports_results"]["game_spotlight"] + "\n"
- if (
- "shopping_results" in res.keys()
- and "title" in res["shopping_results"][0].keys()
- ):
- toret += res["shopping_results"][:3] + "\n"
- if (
- "knowledge_graph" in res.keys()
- and "description" in res["knowledge_graph"].keys()
- ):
- toret = res["knowledge_graph"]["description"] + "\n"
- if "snippet" in res["organic_results"][0].keys():
- for item in res["organic_results"]:
- toret += "content: " + item["snippet"] + "\n" + "link: " + item["link"] + "\n"
- if (
- "images_results" in res.keys()
- and "thumbnail" in res["images_results"][0].keys()
- ):
- thumbnails = [item["thumbnail"] for item in res["images_results"][:10]]
- toret = thumbnails
- if toret == "":
- toret = "No good search result found"
+ if "knowledge_graph" in res and "description" in res["knowledge_graph"]:
+ toret += res["knowledge_graph"]["description"] + "\n"
+ if "organic_results" in res:
+ snippets = [
+ f"content: {item.get('snippet')}\nlink: {item.get('link')}"
+ for item in res["organic_results"]
+ if "snippet" in item
+ ]
+ toret += "\n".join(snippets)
elif typ == "link":
- if "knowledge_graph" in res.keys() and "title" in res["knowledge_graph"].keys() \
- and "description_link" in res["knowledge_graph"].keys():
- toret = res["knowledge_graph"]["description_link"]
- elif "knowledge_graph" in res.keys() and "see_results_about" in res["knowledge_graph"].keys() \
- and len(res["knowledge_graph"]["see_results_about"]) > 0:
- see_result_about = res["knowledge_graph"]["see_results_about"]
- toret = ""
- for item in see_result_about:
- if "name" not in item.keys() or "link" not in item.keys():
- continue
- toret += f"[{item['name']}]({item['link']})\n"
- elif "organic_results" in res.keys() and len(res["organic_results"]) > 0:
- organic_results = res["organic_results"]
- toret = ""
- for item in organic_results:
- if "title" not in item.keys() or "link" not in item.keys():
- continue
- toret += f"[{item['title']}]({item['link']})\n"
- elif "related_questions" in res.keys() and len(res["related_questions"]) > 0:
- related_questions = res["related_questions"]
- toret = ""
- for item in related_questions:
- if "question" not in item.keys() or "link" not in item.keys():
- continue
- toret += f"[{item['question']}]({item['link']})\n"
- elif "related_searches" in res.keys() and len(res["related_searches"]) > 0:
- related_searches = res["related_searches"]
- toret = ""
- for item in related_searches:
- if "query" not in item.keys() or "link" not in item.keys():
- continue
- toret += f"[{item['query']}]({item['link']})\n"
- else:
- toret = "No good search result found"
+ if "knowledge_graph" in res and "source" in res["knowledge_graph"]:
+ toret += res["knowledge_graph"]["source"]["link"]
+ elif "organic_results" in res:
+ links = [
+ f"[{item['title']}]({item['link']})\n"
+ for item in res["organic_results"]
+ if "title" in item and "link" in item
+ ]
+ toret += "\n".join(links)
+ elif "related_questions" in res:
+ questions = [
+ f"[{item['question']}]({item['link']})\n"
+ for item in res["related_questions"]
+ if "question" in item and "link" in item
+ ]
+ toret += "\n".join(questions)
+ elif "related_searches" in res:
+ searches = [
+ f"[{item['query']}]({item['link']})\n"
+ for item in res["related_searches"]
+ if "query" in item and "link" in item
+ ]
+ toret += "\n".join(searches)
+ if not toret:
+ toret = "No good search result found"
return toret
+
class GoogleSearchTool(BuiltinTool):
- def _invoke(self,
+ def _invoke(self,
user_id: str,
- tool_parameters: dict[str, Any],
- ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
+ tool_parameters: dict[str, Any],
+ ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
"""
invoke tools
"""
@@ -161,4 +104,3 @@ class GoogleSearchTool(BuiltinTool):
if result_type == 'text':
return self.create_text_message(text=result)
return self.create_link_message(link=result)
-
\ No newline at end of file
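The rewrite above removes the serpapi SDK dependency (and the `HiddenPrints` stdout workaround it forced) in favour of one plain HTTP call, and trims `_process_response` down to the fields SerpAPI documents. The request half of the new class, isolated so it can be exercised on its own (the API key below is a placeholder):

```python
import requests

SERP_API_URL = "https://serpapi.com/search"

def serpapi_results(api_key: str, query: str) -> dict:
    """One plain GET against SerpAPI, as the rewritten SerpAPI.results does."""
    params = {
        "api_key": api_key,  # placeholder: supply a real SerpAPI key
        "q": query,
        "engine": "google",
        "google_domain": "google.com",
        "gl": "us",
        "hl": "en",
    }
    response = requests.get(url=SERP_API_URL, params=params)
    response.raise_for_status()  # HTTP errors raise instead of being parsed
    return response.json()
```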
diff --git a/api/core/tools/provider/builtin/jina/jina.py b/api/core/tools/provider/builtin/jina/jina.py
index ed1de6f6c1..b1a8d62138 100644
--- a/api/core/tools/provider/builtin/jina/jina.py
+++ b/api/core/tools/provider/builtin/jina/jina.py
@@ -1,5 +1,6 @@
from typing import Any
+from core.tools.entities.values import ToolLabelEnum
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
@@ -9,4 +10,9 @@ class GoogleProvider(BuiltinToolProviderController):
try:
pass
except Exception as e:
- raise ToolProviderCredentialValidationError(str(e))
\ No newline at end of file
+ raise ToolProviderCredentialValidationError(str(e))
+
+ def _get_tool_labels(self) -> list[ToolLabelEnum]:
+ return [
+ ToolLabelEnum.SEARCH, ToolLabelEnum.PRODUCTIVITY
+ ]
\ No newline at end of file
diff --git a/api/core/tools/provider/builtin/jina/jina.yaml b/api/core/tools/provider/builtin/jina/jina.yaml
index 6ae3330f40..67ad32a47a 100644
--- a/api/core/tools/provider/builtin/jina/jina.yaml
+++ b/api/core/tools/provider/builtin/jina/jina.yaml
@@ -2,12 +2,15 @@ identity:
author: Dify
name: jina
label:
- en_US: JinaReader
- zh_Hans: JinaReader
- pt_BR: JinaReader
+ en_US: Jina
+ zh_Hans: Jina
+ pt_BR: Jina
description:
- en_US: Convert any URL to an LLM-friendly input. Experience improved output for your agent and RAG systems at no cost.
- zh_Hans: 将任何 URL 转换为 LLM 友好的输入。无需付费即可体验为您的 Agent 和 RAG 系统提供的改进输出。
- pt_BR: Converta qualquer URL em uma entrada amigável ao LLM. Experimente uma saída aprimorada para seus sistemas de agente e RAG sem custo.
+ en_US: Convert any URL to an LLM-friendly input or perform searches on the web for grounding information. Experience improved output for your agent and RAG systems at no cost.
+ zh_Hans: 将任何 URL 转换为 LLM 友好的输入,或在网络上搜索以获取可靠的参考信息。无需付费即可体验为您的 Agent 和 RAG 系统提供的改进输出。
+ pt_BR: Converta qualquer URL em uma entrada amigável ao LLM ou realize pesquisas na web para obter informações de grounding. Experimente uma saída aprimorada para seus sistemas de agente e RAG sem custo.
icon: icon.svg
+ tags:
+ - search
+ - productivity
credentials_for_provider:
diff --git a/api/core/tools/provider/builtin/jina/tools/jina_reader.py b/api/core/tools/provider/builtin/jina/tools/jina_reader.py
index fd29a00aa5..beb05717ea 100644
--- a/api/core/tools/provider/builtin/jina/tools/jina_reader.py
+++ b/api/core/tools/provider/builtin/jina/tools/jina_reader.py
@@ -23,6 +23,14 @@ class JinaReaderTool(BuiltinTool):
'Accept': 'application/json'
}
+ target_selector = tool_parameters.get('target_selector', None)
+ if target_selector is not None:
+ headers['X-Target-Selector'] = target_selector
+
+ wait_for_selector = tool_parameters.get('wait_for_selector', None)
+ if wait_for_selector is not None:
+ headers['X-Wait-For-Selector'] = wait_for_selector
+
response = ssrf_proxy.get(
str(URL(self._jina_reader_endpoint + url)),
headers=headers,
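The two new reader parameters are forwarded to Jina Reader as request headers only when the user sets them, so default requests are unchanged. The header assembly in isolation:

```python
def build_jina_headers(tool_parameters: dict) -> dict:
    """Mirror of the header logic above: None means 'omit the header'."""
    headers = {'Accept': 'application/json'}
    target_selector = tool_parameters.get('target_selector')
    if target_selector is not None:
        headers['X-Target-Selector'] = target_selector
    wait_for_selector = tool_parameters.get('wait_for_selector')
    if wait_for_selector is not None:
        headers['X-Wait-For-Selector'] = wait_for_selector
    return headers

assert build_jina_headers({}) == {'Accept': 'application/json'}
assert build_jina_headers({'target_selector': '#main'})['X-Target-Selector'] == '#main'
```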
diff --git a/api/core/tools/provider/builtin/jina/tools/jina_reader.yaml b/api/core/tools/provider/builtin/jina/tools/jina_reader.yaml
index 38d66292df..73cacb7fde 100644
--- a/api/core/tools/provider/builtin/jina/tools/jina_reader.yaml
+++ b/api/core/tools/provider/builtin/jina/tools/jina_reader.yaml
@@ -25,6 +25,32 @@ parameters:
pt_BR: used for linking to webpages
llm_description: url for scraping
form: llm
+ - name: target_selector
+ type: string
+ required: false
+ label:
+ en_US: Target selector
+ zh_Hans: 目标选择器
+ pt_BR: Seletor de destino
+ human_description:
+ en_US: css selector for scraping specific elements
+ zh_Hans: css 选择器用于抓取特定元素
+ pt_BR: css selector for scraping specific elements
+ llm_description: css selector of the target element to scrape
+ form: form
+ - name: wait_for_selector
+ type: string
+ required: false
+ label:
+ en_US: Wait for selector
+ zh_Hans: 等待选择器
+ pt_BR: Aguardar por seletor
+ human_description:
+ en_US: css selector for waiting for specific elements
+ zh_Hans: css 选择器用于等待特定元素
+ pt_BR: css selector for waiting for specific elements
+ llm_description: css selector of the target element to wait for
+ form: form
- name: summary
type: boolean
required: false
diff --git a/api/core/tools/provider/builtin/jina/tools/jina_search.py b/api/core/tools/provider/builtin/jina/tools/jina_search.py
new file mode 100644
index 0000000000..cfe36e6a3c
--- /dev/null
+++ b/api/core/tools/provider/builtin/jina/tools/jina_search.py
@@ -0,0 +1,30 @@
+from typing import Any, Union
+
+from yarl import URL
+
+from core.helper import ssrf_proxy
+from core.tools.entities.tool_entities import ToolInvokeMessage
+from core.tools.tool.builtin_tool import BuiltinTool
+
+
+class JinaSearchTool(BuiltinTool):
+ _jina_search_endpoint = 'https://s.jina.ai/'
+
+ def _invoke(
+ self,
+ user_id: str,
+ tool_parameters: dict[str, Any],
+ ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
+ query = tool_parameters['query']
+
+ headers = {
+ 'Accept': 'application/json'
+ }
+
+ response = ssrf_proxy.get(
+ str(URL(self._jina_search_endpoint + query)),
+ headers=headers,
+ timeout=(10, 60)
+ )
+
+ return self.create_text_message(response.text)
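The new search tool appends the raw query to `https://s.jina.ai/` and relies on yarl to produce a valid request URL. Assuming yarl's default behaviour of percent-encoding an unencoded input string, multi-word queries survive the trip:

```python
from yarl import URL

endpoint = 'https://s.jina.ai/'

# yarl percent-encodes the path when built from an unencoded string,
# so a query containing spaces still forms a valid request URL.
assert str(URL(endpoint + 'what is dify')) == 'https://s.jina.ai/what%20is%20dify'
```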
diff --git a/api/core/tools/provider/builtin/jina/tools/jina_search.yaml b/api/core/tools/provider/builtin/jina/tools/jina_search.yaml
new file mode 100644
index 0000000000..5ad70c03f3
--- /dev/null
+++ b/api/core/tools/provider/builtin/jina/tools/jina_search.yaml
@@ -0,0 +1,21 @@
+identity:
+ name: jina_search
+ author: Dify
+ label:
+ en_US: JinaSearch
+ zh_Hans: JinaSearch
+ pt_BR: JinaSearch
+description:
+ human:
+ en_US: Search on the web and get the top 5 results. Useful for grounding using information from the web.
+ llm: A tool for searching results on the web for grounding. Input should be a simple question.
+parameters:
+ - name: query
+ type: string
+ required: true
+ label:
+ en_US: Question (Query)
+ human_description:
+ en_US: used to find information on the web
+ llm_description: simple question to ask on the web
+ form: llm
diff --git a/api/core/tools/provider/builtin/judge0ce/judge0ce.py b/api/core/tools/provider/builtin/judge0ce/judge0ce.py
index c00747868b..bac6576797 100644
--- a/api/core/tools/provider/builtin/judge0ce/judge0ce.py
+++ b/api/core/tools/provider/builtin/judge0ce/judge0ce.py
@@ -9,7 +9,7 @@ class Judge0CEProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
try:
ExecuteCodeTool().fork_tool_runtime(
- meta={
+ runtime={
"credentials": credentials,
}
).invoke(
@@ -20,4 +20,5 @@ class Judge0CEProvider(BuiltinToolProviderController):
},
)
except Exception as e:
- raise ToolProviderCredentialValidationError(str(e))
\ No newline at end of file
+ raise ToolProviderCredentialValidationError(str(e))
+
\ No newline at end of file
diff --git a/api/core/tools/provider/builtin/judge0ce/judge0ce.yaml b/api/core/tools/provider/builtin/judge0ce/judge0ce.yaml
index 5f0a471827..9ff8aaac6d 100644
--- a/api/core/tools/provider/builtin/judge0ce/judge0ce.yaml
+++ b/api/core/tools/provider/builtin/judge0ce/judge0ce.yaml
@@ -10,6 +10,9 @@ identity:
zh_Hans: Judge0 CE 是一个开源的代码执行系统。支持多种语言,包括 C、C++、Java、Python、Ruby 等。
pt_BR: Judge0 CE é um sistema de execução de código de código aberto. Suporta várias linguagens, incluindo C, C++, Java, Python, Ruby, etc.
icon: icon.svg
+ tags:
+ - utilities
+ - other
credentials_for_provider:
X-RapidAPI-Key:
type: secret-input
diff --git a/api/core/tools/provider/builtin/maths/maths.yaml b/api/core/tools/provider/builtin/maths/maths.yaml
index 5bd892a927..35c2380e29 100644
--- a/api/core/tools/provider/builtin/maths/maths.yaml
+++ b/api/core/tools/provider/builtin/maths/maths.yaml
@@ -10,3 +10,6 @@ identity:
zh_Hans: 一个用于数学计算的工具。
pt_BR: A tool for maths.
icon: icon.svg
+ tags:
+ - utilities
+ - productivity
diff --git a/api/core/tools/provider/builtin/openweather/openweather.yaml b/api/core/tools/provider/builtin/openweather/openweather.yaml
index 60bb33c36d..d4b66f87f9 100644
--- a/api/core/tools/provider/builtin/openweather/openweather.yaml
+++ b/api/core/tools/provider/builtin/openweather/openweather.yaml
@@ -10,6 +10,8 @@ identity:
zh_Hans: 基于open weather的天气查询工具包
pt_BR: Kit de consulta de clima baseado no Open Weather
icon: icon.svg
+ tags:
+ - weather
credentials_for_provider:
api_key:
type: secret-input
diff --git a/api/core/tools/provider/builtin/pubmed/pubmed.py b/api/core/tools/provider/builtin/pubmed/pubmed.py
index 663617c0c1..05cd171b87 100644
--- a/api/core/tools/provider/builtin/pubmed/pubmed.py
+++ b/api/core/tools/provider/builtin/pubmed/pubmed.py
@@ -7,7 +7,7 @@ class PubMedProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict) -> None:
try:
PubMedSearchTool().fork_tool_runtime(
- meta={
+ runtime={
"credentials": credentials,
}
).invoke(
@@ -17,4 +17,5 @@ class PubMedProvider(BuiltinToolProviderController):
},
)
except Exception as e:
- raise ToolProviderCredentialValidationError(str(e))
\ No newline at end of file
+ raise ToolProviderCredentialValidationError(str(e))
+
\ No newline at end of file
diff --git a/api/core/tools/provider/builtin/pubmed/pubmed.yaml b/api/core/tools/provider/builtin/pubmed/pubmed.yaml
index 971a6fb204..5f8303147c 100644
--- a/api/core/tools/provider/builtin/pubmed/pubmed.yaml
+++ b/api/core/tools/provider/builtin/pubmed/pubmed.yaml
@@ -8,3 +8,6 @@ identity:
en_US: A search engine for biomedical literature.
zh_Hans: 一款生物医学文献搜索引擎。
icon: icon.svg
+ tags:
+ - medical
+ - search
diff --git a/api/core/tools/provider/builtin/qrcode/qrcode.yaml b/api/core/tools/provider/builtin/qrcode/qrcode.yaml
index c117c3de74..82e2a06e15 100644
--- a/api/core/tools/provider/builtin/qrcode/qrcode.yaml
+++ b/api/core/tools/provider/builtin/qrcode/qrcode.yaml
@@ -10,3 +10,5 @@ identity:
zh_Hans: 一个二维码工具
pt_BR: A tool for generating QR code (quick-response code) image.
icon: icon.svg
+ tags:
+ - utilities
diff --git a/api/core/tools/provider/builtin/searchapi/_assets/icon.svg b/api/core/tools/provider/builtin/searchapi/_assets/icon.svg
new file mode 100644
index 0000000000..7660b2f351
--- /dev/null
+++ b/api/core/tools/provider/builtin/searchapi/_assets/icon.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/api/core/tools/provider/builtin/searchapi/searchapi.py b/api/core/tools/provider/builtin/searchapi/searchapi.py
new file mode 100644
index 0000000000..6fa4f05acd
--- /dev/null
+++ b/api/core/tools/provider/builtin/searchapi/searchapi.py
@@ -0,0 +1,23 @@
+from typing import Any
+
+from core.tools.errors import ToolProviderCredentialValidationError
+from core.tools.provider.builtin.searchapi.tools.google import GoogleTool
+from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
+
+
+class SearchAPIProvider(BuiltinToolProviderController):
+ def _validate_credentials(self, credentials: dict[str, Any]) -> None:
+ try:
+ GoogleTool().fork_tool_runtime(
+ runtime={
+ "credentials": credentials,
+ }
+ ).invoke(
+ user_id='',
+ tool_parameters={
+ "query": "SearchApi dify",
+ "result_type": "link"
+ },
+ )
+ except Exception as e:
+ raise ToolProviderCredentialValidationError(str(e))
diff --git a/api/core/tools/provider/builtin/searchapi/searchapi.yaml b/api/core/tools/provider/builtin/searchapi/searchapi.yaml
new file mode 100644
index 0000000000..c2fa3f398e
--- /dev/null
+++ b/api/core/tools/provider/builtin/searchapi/searchapi.yaml
@@ -0,0 +1,34 @@
+identity:
+ author: SearchApi
+ name: searchapi
+ label:
+ en_US: SearchApi
+ zh_Hans: SearchApi
+ pt_BR: SearchApi
+ description:
+ en_US: SearchApi is a robust real-time SERP API delivering structured data from a collection of search engines including Google Search, Google Jobs, YouTube, Google News, and many more.
+ zh_Hans: SearchApi 是一个强大的实时 SERP API,可提供来自 Google 搜索、Google 招聘、YouTube、Google 新闻等搜索引擎集合的结构化数据。
+ pt_BR: SearchApi is a robust real-time SERP API delivering structured data from a collection of search engines including Google Search, Google Jobs, YouTube, Google News, and many more.
+ icon: icon.svg
+ tags:
+ - search
+ - business
+ - news
+ - productivity
+credentials_for_provider:
+ searchapi_api_key:
+ type: secret-input
+ required: true
+ label:
+ en_US: SearchApi API key
+ zh_Hans: SearchApi API key
+ pt_BR: SearchApi API key
+ placeholder:
+ en_US: Please input your SearchApi API key
+ zh_Hans: 请输入你的 SearchApi API key
+ pt_BR: Please input your SearchApi API key
+ help:
+ en_US: Get your SearchApi API key from SearchApi
+ zh_Hans: 从 SearchApi 获取您的 SearchApi API key
+ pt_BR: Get your SearchApi API key from SearchApi
+ url: https://www.searchapi.io/
diff --git a/api/core/tools/provider/builtin/searchapi/tools/google.py b/api/core/tools/provider/builtin/searchapi/tools/google.py
new file mode 100644
index 0000000000..d019fe7134
--- /dev/null
+++ b/api/core/tools/provider/builtin/searchapi/tools/google.py
@@ -0,0 +1,104 @@
+from typing import Any, Union
+
+import requests
+
+from core.tools.entities.tool_entities import ToolInvokeMessage
+from core.tools.tool.builtin_tool import BuiltinTool
+
+SEARCH_API_URL = "https://www.searchapi.io/api/v1/search"
+
+class SearchAPI:
+ """
+ SearchAPI tool provider.
+ """
+
+ def __init__(self, api_key: str) -> None:
+ """Initialize SearchAPI tool provider."""
+ self.searchapi_api_key = api_key
+
+ def run(self, query: str, **kwargs: Any) -> str:
+ """Run query through SearchAPI and parse result."""
+ type = kwargs.get("result_type", "text")
+ return self._process_response(self.results(query, **kwargs), type=type)
+
+ def results(self, query: str, **kwargs: Any) -> dict:
+ """Run query through SearchAPI and return the raw result."""
+ params = self.get_params(query, **kwargs)
+ response = requests.get(
+ url=SEARCH_API_URL,
+ params=params,
+ headers={"Authorization": f"Bearer {self.searchapi_api_key}"},
+ )
+ response.raise_for_status()
+ return response.json()
+
+ def get_params(self, query: str, **kwargs: Any) -> dict[str, str]:
+ """Get parameters for SearchAPI."""
+ return {
+ "engine": "google",
+ "q": query,
+ **{key: value for key, value in kwargs.items() if value not in [None, ""]},
+ }
+
+ @staticmethod
+ def _process_response(res: dict, type: str) -> str:
+ """Process response from SearchAPI."""
+ if "error" in res.keys():
+ raise ValueError(f"Got error from SearchApi: {res['error']}")
+
+ toret = ""
+ if type == "text":
+ if "answer_box" in res.keys() and "answer" in res["answer_box"].keys():
+ toret += res["answer_box"]["answer"] + "\n"
+ if "answer_box" in res.keys() and "snippet" in res["answer_box"].keys():
+ toret += res["answer_box"]["snippet"] + "\n"
+ if "knowledge_graph" in res.keys() and "description" in res["knowledge_graph"].keys():
+ toret += res["knowledge_graph"]["description"] + "\n"
+ if "organic_results" in res.keys() and "snippet" in res["organic_results"][0].keys():
+ for item in res["organic_results"]:
+ toret += "content: " + item["snippet"] + "\n" + "link: " + item["link"] + "\n"
+ if toret == "":
+ toret = "No good search result found"
+
+ elif type == "link":
+ if "answer_box" in res.keys() and "organic_result" in res["answer_box"].keys():
+ if "title" in res["answer_box"]["organic_result"].keys():
+ toret = f"[{res['answer_box']['organic_result']['title']}]({res['answer_box']['organic_result']['link']})\n"
+ elif "organic_results" in res.keys() and "link" in res["organic_results"][0].keys():
+ toret = ""
+ for item in res["organic_results"]:
+ toret += f"[{item['title']}]({item['link']})\n"
+ elif "related_questions" in res.keys() and "link" in res["related_questions"][0].keys():
+ toret = ""
+ for item in res["related_questions"]:
+ toret += f"[{item['title']}]({item['link']})\n"
+ elif "related_searches" in res.keys() and "link" in res["related_searches"][0].keys():
+ toret = ""
+ for item in res["related_searches"]:
+ toret += f"[{item['title']}]({item['link']})\n"
+ else:
+ toret = "No good search result found"
+ return toret
+
+class GoogleTool(BuiltinTool):
+ def _invoke(self,
+ user_id: str,
+ tool_parameters: dict[str, Any],
+ ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
+ """
+ Invoke the SearchApi tool.
+ """
+ query = tool_parameters['query']
+ result_type = tool_parameters['result_type']
+ num = tool_parameters.get("num", 10)
+ google_domain = tool_parameters.get("google_domain", "google.com")
+ gl = tool_parameters.get("gl", "us")
+ hl = tool_parameters.get("hl", "en")
+ location = tool_parameters.get("location", None)
+
+ api_key = self.runtime.credentials['searchapi_api_key']
+ result = SearchAPI(api_key).run(query, result_type=result_type, num=num, google_domain=google_domain, gl=gl, hl=hl, location=location)
+
+ if result_type == 'text':
+ return self.create_text_message(text=result)
+ return self.create_link_message(link=result)
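All three SearchApi tools in this diff share this wrapper shape; only the `engine` value and the response fields they read differ. The one subtle line is the parameter merge in `get_params`, which silently drops unset options so they never reach the API. In isolation:

```python
def build_searchapi_params(query: str, engine: str = "google", **kwargs) -> dict:
    """Mirror of SearchAPI.get_params: options left as None or "" are dropped."""
    return {
        "engine": engine,
        "q": query,
        **{key: value for key, value in kwargs.items() if value not in [None, ""]},
    }

params = build_searchapi_params("dify", num=10, location=None, hl="en")
assert params == {"engine": "google", "q": "dify", "num": 10, "hl": "en"}
```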
diff --git a/api/core/tools/provider/builtin/searchapi/tools/google.yaml b/api/core/tools/provider/builtin/searchapi/tools/google.yaml
new file mode 100644
index 0000000000..566de84b13
--- /dev/null
+++ b/api/core/tools/provider/builtin/searchapi/tools/google.yaml
@@ -0,0 +1,481 @@
+identity:
+ name: google_search_api
+ author: SearchApi
+ label:
+ en_US: Google Search API
+ zh_Hans: Google Search API
+description:
+ human:
+ en_US: A tool to retrieve answer boxes, knowledge graphs, snippets, and webpages from Google Search engine.
+ zh_Hans: 一种从 Google 搜索引擎检索答案框、知识图、片段和网页的工具。
+ llm: A tool to retrieve answer boxes, knowledge graphs, snippets, and webpages from Google Search engine.
+parameters:
+ - name: query
+ type: string
+ required: true
+ label:
+ en_US: Query
+ zh_Hans: 询问
+ human_description:
+ en_US: Defines the query you want to search.
+ zh_Hans: 定义您要搜索的查询。
+ llm_description: Defines the search query you want to search.
+ form: llm
+ - name: result_type
+ type: select
+ required: true
+ options:
+ - value: text
+ label:
+ en_US: text
+ zh_Hans: 文本
+ - value: link
+ label:
+ en_US: link
+ zh_Hans: 链接
+ default: text
+ label:
+ en_US: Result type
+ zh_Hans: 结果类型
+ human_description:
+ en_US: used for selecting the result type, text or link
+ zh_Hans: 用于选择结果类型,使用文本还是链接进行展示
+ form: form
+ - name: location
+ type: string
+ required: false
+ label:
+ en_US: Location
+ zh_Hans: 位置
+ human_description:
+ en_US: Defines from where you want the search to originate. (For example - New York)
+ zh_Hans: 定义您想要搜索的起始位置。 (例如 - 纽约)
+ llm_description: Defines from where you want the search to originate. (For example - New York)
+ form: llm
+ - name: gl
+ type: select
+ label:
+ en_US: Country
+ zh_Hans: 国家
+ required: false
+ human_description:
+ en_US: Defines the country of the search. Default is "US".
+ zh_Hans: 定义搜索的国家/地区。默认为“美国”。
+ llm_description: Defines the gl parameter of the Google search.
+ form: form
+ default: US
+ options:
+ - value: AR
+ label:
+ en_US: Argentina
+ zh_Hans: 阿根廷
+ pt_BR: Argentina
+ - value: AU
+ label:
+ en_US: Australia
+ zh_Hans: 澳大利亚
+ pt_BR: Australia
+ - value: AT
+ label:
+ en_US: Austria
+ zh_Hans: 奥地利
+ pt_BR: Austria
+ - value: BE
+ label:
+ en_US: Belgium
+ zh_Hans: 比利时
+ pt_BR: Belgium
+ - value: BR
+ label:
+ en_US: Brazil
+ zh_Hans: 巴西
+ pt_BR: Brazil
+ - value: CA
+ label:
+ en_US: Canada
+ zh_Hans: 加拿大
+ pt_BR: Canada
+ - value: CL
+ label:
+ en_US: Chile
+ zh_Hans: 智利
+ pt_BR: Chile
+ - value: CO
+ label:
+ en_US: Colombia
+ zh_Hans: 哥伦比亚
+ pt_BR: Colombia
+ - value: CN
+ label:
+ en_US: China
+ zh_Hans: 中国
+ pt_BR: China
+ - value: CZ
+ label:
+ en_US: Czech Republic
+ zh_Hans: 捷克共和国
+ pt_BR: Czech Republic
+ - value: DK
+ label:
+ en_US: Denmark
+ zh_Hans: 丹麦
+ pt_BR: Denmark
+ - value: FI
+ label:
+ en_US: Finland
+ zh_Hans: 芬兰
+ pt_BR: Finland
+ - value: FR
+ label:
+ en_US: France
+ zh_Hans: 法国
+ pt_BR: France
+ - value: DE
+ label:
+ en_US: Germany
+ zh_Hans: 德国
+ pt_BR: Germany
+ - value: HK
+ label:
+ en_US: Hong Kong
+ zh_Hans: 香港
+ pt_BR: Hong Kong
+ - value: IN
+ label:
+ en_US: India
+ zh_Hans: 印度
+ pt_BR: India
+ - value: ID
+ label:
+ en_US: Indonesia
+ zh_Hans: 印度尼西亚
+ pt_BR: Indonesia
+ - value: IT
+ label:
+ en_US: Italy
+ zh_Hans: 意大利
+ pt_BR: Italy
+ - value: JP
+ label:
+ en_US: Japan
+ zh_Hans: 日本
+ pt_BR: Japan
+ - value: KR
+ label:
+ en_US: Korea
+ zh_Hans: 韩国
+ pt_BR: Korea
+ - value: MY
+ label:
+ en_US: Malaysia
+ zh_Hans: 马来西亚
+ pt_BR: Malaysia
+ - value: MX
+ label:
+ en_US: Mexico
+ zh_Hans: 墨西哥
+ pt_BR: Mexico
+ - value: NL
+ label:
+ en_US: Netherlands
+ zh_Hans: 荷兰
+ pt_BR: Netherlands
+ - value: NZ
+ label:
+ en_US: New Zealand
+ zh_Hans: 新西兰
+ pt_BR: New Zealand
+ - value: NO
+ label:
+ en_US: Norway
+ zh_Hans: 挪威
+ pt_BR: Norway
+ - value: PH
+ label:
+ en_US: Philippines
+ zh_Hans: 菲律宾
+ pt_BR: Philippines
+ - value: PL
+ label:
+ en_US: Poland
+ zh_Hans: 波兰
+ pt_BR: Poland
+ - value: PT
+ label:
+ en_US: Portugal
+ zh_Hans: 葡萄牙
+ pt_BR: Portugal
+ - value: RU
+ label:
+ en_US: Russia
+ zh_Hans: 俄罗斯
+ pt_BR: Russia
+ - value: SA
+ label:
+ en_US: Saudi Arabia
+ zh_Hans: 沙特阿拉伯
+ pt_BR: Saudi Arabia
+ - value: SG
+ label:
+ en_US: Singapore
+ zh_Hans: 新加坡
+ pt_BR: Singapore
+ - value: ZA
+ label:
+ en_US: South Africa
+ zh_Hans: 南非
+ pt_BR: South Africa
+ - value: ES
+ label:
+ en_US: Spain
+ zh_Hans: 西班牙
+ pt_BR: Spain
+ - value: SE
+ label:
+ en_US: Sweden
+ zh_Hans: 瑞典
+ pt_BR: Sweden
+ - value: CH
+ label:
+ en_US: Switzerland
+ zh_Hans: 瑞士
+ pt_BR: Switzerland
+ - value: TW
+ label:
+ en_US: Taiwan
+ zh_Hans: 台湾
+ pt_BR: Taiwan
+ - value: TH
+ label:
+ en_US: Thailand
+ zh_Hans: 泰国
+ pt_BR: Thailand
+ - value: TR
+ label:
+ en_US: Turkey
+ zh_Hans: 土耳其
+ pt_BR: Turkey
+ - value: GB
+ label:
+ en_US: United Kingdom
+ zh_Hans: 英国
+ pt_BR: United Kingdom
+ - value: US
+ label:
+ en_US: United States
+ zh_Hans: 美国
+ pt_BR: United States
+ - name: hl
+ type: select
+ label:
+ en_US: Language
+ zh_Hans: 语言
+ human_description:
+ en_US: Defines the interface language of the search. Default is "en".
+ zh_Hans: 定义搜索的界面语言。默认为“en”。
+ required: false
+ default: en
+ form: form
+ options:
+ - value: ar
+ label:
+ en_US: Arabic
+ zh_Hans: 阿拉伯语
+ - value: bg
+ label:
+ en_US: Bulgarian
+ zh_Hans: 保加利亚语
+ - value: ca
+ label:
+ en_US: Catalan
+ zh_Hans: 加泰罗尼亚语
+ - value: zh-cn
+ label:
+ en_US: Chinese (Simplified)
+ zh_Hans: 中文(简体)
+ - value: zh-tw
+ label:
+ en_US: Chinese (Traditional)
+ zh_Hans: 中文(繁体)
+ - value: cs
+ label:
+ en_US: Czech
+ zh_Hans: 捷克语
+ - value: da
+ label:
+ en_US: Danish
+ zh_Hans: 丹麦语
+ - value: nl
+ label:
+ en_US: Dutch
+ zh_Hans: 荷兰语
+ - value: en
+ label:
+ en_US: English
+ zh_Hans: 英语
+ - value: et
+ label:
+ en_US: Estonian
+ zh_Hans: 爱沙尼亚语
+ - value: fi
+ label:
+ en_US: Finnish
+ zh_Hans: 芬兰语
+ - value: fr
+ label:
+ en_US: French
+ zh_Hans: 法语
+ - value: de
+ label:
+ en_US: German
+ zh_Hans: 德语
+ - value: el
+ label:
+ en_US: Greek
+ zh_Hans: 希腊语
+ - value: iw
+ label:
+ en_US: Hebrew
+ zh_Hans: 希伯来语
+ - value: hi
+ label:
+ en_US: Hindi
+ zh_Hans: 印地语
+ - value: hu
+ label:
+ en_US: Hungarian
+ zh_Hans: 匈牙利语
+ - value: id
+ label:
+ en_US: Indonesian
+ zh_Hans: 印尼语
+ - value: it
+ label:
+ en_US: Italian
+ zh_Hans: 意大利语
+ - value: jp
+ label:
+ en_US: Japanese
+ zh_Hans: 日语
+ - value: kn
+ label:
+ en_US: Kannada
+ zh_Hans: 卡纳达语
+ - value: ko
+ label:
+ en_US: Korean
+ zh_Hans: 韩语
+ - value: lv
+ label:
+ en_US: Latvian
+ zh_Hans: 拉脱维亚语
+ - value: lt
+ label:
+ en_US: Lithuanian
+ zh_Hans: 立陶宛语
+ - value: my
+ label:
+ en_US: Malay
+ zh_Hans: 马来语
+ - value: ml
+ label:
+ en_US: Malayalam
+ zh_Hans: 马拉雅拉姆语
+ - value: mr
+ label:
+ en_US: Marathi
+ zh_Hans: 马拉地语
+ - value: "no"
+ label:
+ en_US: Norwegian
+ zh_Hans: 挪威语
+ - value: pl
+ label:
+ en_US: Polish
+ zh_Hans: 波兰语
+ - value: pt-br
+ label:
+ en_US: Portuguese (Brazil)
+ zh_Hans: 葡萄牙语(巴西)
+ - value: pt-pt
+ label:
+ en_US: Portuguese (Portugal)
+ zh_Hans: 葡萄牙语(葡萄牙)
+ - value: pa
+ label:
+ en_US: Punjabi
+ zh_Hans: 旁遮普语
+ - value: ro
+ label:
+ en_US: Romanian
+ zh_Hans: 罗马尼亚语
+ - value: ru
+ label:
+ en_US: Russian
+ zh_Hans: 俄语
+ - value: sr
+ label:
+ en_US: Serbian
+ zh_Hans: 塞尔维亚语
+ - value: sk
+ label:
+ en_US: Slovak
+ zh_Hans: 斯洛伐克语
+ - value: sl
+ label:
+ en_US: Slovenian
+ zh_Hans: 斯洛文尼亚语
+ - value: es
+ label:
+ en_US: Spanish
+ zh_Hans: 西班牙语
+ - value: sv
+ label:
+ en_US: Swedish
+ zh_Hans: 瑞典语
+ - value: ta
+ label:
+ en_US: Tamil
+ zh_Hans: 泰米尔语
+ - value: te
+ label:
+ en_US: Telugu
+ zh_Hans: 泰卢固语
+ - value: th
+ label:
+ en_US: Thai
+ zh_Hans: 泰语
+ - value: tr
+ label:
+ en_US: Turkish
+ zh_Hans: 土耳其语
+ - value: uk
+ label:
+ en_US: Ukrainian
+ zh_Hans: 乌克兰语
+ - value: vi
+ label:
+ en_US: Vietnamese
+ zh_Hans: 越南语
+ - name: google_domain
+ type: string
+ required: false
+ label:
+ en_US: google_domain
+ zh_Hans: google_domain
+ human_description:
+ en_US: Defines the Google domain of the search. Default is "google.com".
+ zh_Hans: 定义搜索的 Google 域。默认为“google.com”。
+ llm_description: Defines Google domain in which you want to search.
+ form: llm
+ - name: num
+ type: number
+ required: false
+ label:
+ en_US: num
+ zh_Hans: num
+ human_description:
+ en_US: Specifies the number of results to display per page. Default is 10. Max number - 100, min - 1.
+ zh_Hans: 指定每页显示的结果数。默认值为 10。最大数量 - 100,最小数量 - 1。
+ llm_description: Specifies the num of results to display per page.
+ form: llm
diff --git a/api/core/tools/provider/builtin/searchapi/tools/google_jobs.py b/api/core/tools/provider/builtin/searchapi/tools/google_jobs.py
new file mode 100644
index 0000000000..1b8cfa7e30
--- /dev/null
+++ b/api/core/tools/provider/builtin/searchapi/tools/google_jobs.py
@@ -0,0 +1,88 @@
+from typing import Any, Union
+
+import requests
+
+from core.tools.entities.tool_entities import ToolInvokeMessage
+from core.tools.tool.builtin_tool import BuiltinTool
+
+SEARCH_API_URL = "https://www.searchapi.io/api/v1/search"
+
+class SearchAPI:
+ """
+ SearchAPI tool provider.
+ """
+
+ def __init__(self, api_key: str) -> None:
+ """Initialize SearchAPI tool provider."""
+ self.searchapi_api_key = api_key
+
+ def run(self, query: str, **kwargs: Any) -> str:
+ """Run query through SearchAPI and parse result."""
+ type = kwargs.get("result_type", "text")
+ return self._process_response(self.results(query, **kwargs), type=type)
+
+ def results(self, query: str, **kwargs: Any) -> dict:
+ """Run query through SearchAPI and return the raw result."""
+ params = self.get_params(query, **kwargs)
+ response = requests.get(
+ url=SEARCH_API_URL,
+ params=params,
+ headers={"Authorization": f"Bearer {self.searchapi_api_key}"},
+ )
+ response.raise_for_status()
+ return response.json()
+
+ def get_params(self, query: str, **kwargs: Any) -> dict[str, str]:
+ """Get parameters for SearchAPI."""
+ return {
+ "engine": "google_jobs",
+ "q": query,
+ **{key: value for key, value in kwargs.items() if value not in [None, ""]},
+ }
+
+ @staticmethod
+ def _process_response(res: dict, type: str) -> str:
+ """Process response from SearchAPI."""
+ if "error" in res.keys():
+ raise ValueError(f"Got error from SearchApi: {res['error']}")
+
+ toret = ""
+ if type == "text":
+ if "jobs" in res.keys() and "title" in res["jobs"][0].keys():
+ for item in res["jobs"]:
+ toret += "title: " + item["title"] + "\n" + "company_name: " + item["company_name"] + "content: " + item["description"] + "\n"
+ if toret == "":
+ toret = "No good search result found"
+
+ elif type == "link":
+ if "jobs" in res.keys() and "apply_link" in res["jobs"][0].keys():
+ for item in res["jobs"]:
+ toret += f"[{item['title']} - {item['company_name']}]({item['apply_link']})\n"
+ else:
+ toret = "No good search result found"
+ return toret
+
+class GoogleJobsTool(BuiltinTool):
+ def _invoke(self,
+ user_id: str,
+ tool_parameters: dict[str, Any],
+ ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
+ """
+ Invoke the SearchApi tool.
+ """
+ query = tool_parameters['query']
+ result_type = tool_parameters['result_type']
+ is_remote = tool_parameters.get("is_remote", None)
+ google_domain = tool_parameters.get("google_domain", "google.com")
+ gl = tool_parameters.get("gl", "us")
+ hl = tool_parameters.get("hl", "en")
+ location = tool_parameters.get("location", None)
+
+ ltype = 1 if is_remote else None
+
+ api_key = self.runtime.credentials['searchapi_api_key']
+ result = SearchAPI(api_key).run(query, result_type=result_type, google_domain=google_domain, gl=gl, hl=hl, location=location, ltype=ltype)
+
+ if result_type == 'text':
+ return self.create_text_message(text=result)
+ return self.create_link_message(link=result)
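`GoogleJobsTool` maps the `is_remote` form option onto SearchApi's `ltype` parameter: 1 when truthy, otherwise None, which `get_params` then strips so the API never receives an empty filter:

```python
def remote_filter(is_remote) -> int | None:
    """Mirror of `ltype = 1 if is_remote else None` above."""
    return 1 if is_remote else None

assert remote_filter(True) == 1
assert remote_filter(None) is None  # later dropped by the params filter
```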
diff --git a/api/core/tools/provider/builtin/searchapi/tools/google_jobs.yaml b/api/core/tools/provider/builtin/searchapi/tools/google_jobs.yaml
new file mode 100644
index 0000000000..486b193efa
--- /dev/null
+++ b/api/core/tools/provider/builtin/searchapi/tools/google_jobs.yaml
@@ -0,0 +1,478 @@
+identity:
+ name: google_jobs_api
+ author: SearchApi
+ label:
+ en_US: Google Jobs API
+ zh_Hans: Google Jobs API
+description:
+ human:
+ en_US: A tool to retrieve job titles, company names and description from Google Jobs engine.
+ zh_Hans: 一个从 Google 招聘引擎检索职位名称、公司名称和描述的工具。
+ llm: A tool to retrieve job titles, company names and description from Google Jobs engine.
+parameters:
+ - name: query
+ type: string
+ required: true
+ label:
+ en_US: Query
+ zh_Hans: 询问
+ human_description:
+ en_US: Defines the query you want to search.
+ zh_Hans: 定义您要搜索的查询。
+ llm_description: Defines the search query you want to search.
+ form: llm
+ - name: result_type
+ type: select
+ required: true
+ options:
+ - value: text
+ label:
+ en_US: text
+ zh_Hans: 文本
+ - value: link
+ label:
+ en_US: link
+ zh_Hans: 链接
+ default: text
+ label:
+ en_US: Result type
+ zh_Hans: 结果类型
+ human_description:
+ en_US: used for selecting the result type, text or link
+ zh_Hans: 用于选择结果类型,使用文本还是链接进行展示
+ form: form
+ - name: location
+ type: string
+ required: false
+ label:
+ en_US: Location
+ zh_Hans: 位置
+ human_description:
+ en_US: Defines from where you want the search to originate. (For example - New York)
+ zh_Hans: 定义您想要搜索的起始位置。 (例如 - 纽约)
+ llm_description: Defines from where you want the search to originate. (For example - New York)
+ form: llm
+ - name: gl
+ type: select
+ label:
+ en_US: Country
+ zh_Hans: 国家
+ required: false
+ human_description:
+ en_US: Defines the country of the search. Default is "US".
+ zh_Hans: 定义搜索的国家/地区。默认为“美国”。
+ llm_description: Defines the gl parameter of the Google search.
+ form: form
+ default: US
+ options:
+ - value: AR
+ label:
+ en_US: Argentina
+ zh_Hans: 阿根廷
+ pt_BR: Argentina
+ - value: AU
+ label:
+ en_US: Australia
+ zh_Hans: 澳大利亚
+ pt_BR: Australia
+ - value: AT
+ label:
+ en_US: Austria
+ zh_Hans: 奥地利
+ pt_BR: Austria
+ - value: BE
+ label:
+ en_US: Belgium
+ zh_Hans: 比利时
+ pt_BR: Belgium
+ - value: BR
+ label:
+ en_US: Brazil
+ zh_Hans: 巴西
+ pt_BR: Brazil
+ - value: CA
+ label:
+ en_US: Canada
+ zh_Hans: 加拿大
+ pt_BR: Canada
+ - value: CL
+ label:
+ en_US: Chile
+ zh_Hans: 智利
+ pt_BR: Chile
+ - value: CO
+ label:
+ en_US: Colombia
+ zh_Hans: 哥伦比亚
+ pt_BR: Colombia
+ - value: CN
+ label:
+ en_US: China
+ zh_Hans: 中国
+ pt_BR: China
+ - value: CZ
+ label:
+ en_US: Czech Republic
+ zh_Hans: 捷克共和国
+ pt_BR: Czech Republic
+ - value: DK
+ label:
+ en_US: Denmark
+ zh_Hans: 丹麦
+ pt_BR: Denmark
+ - value: FI
+ label:
+ en_US: Finland
+ zh_Hans: 芬兰
+ pt_BR: Finland
+ - value: FR
+ label:
+ en_US: France
+ zh_Hans: 法国
+ pt_BR: France
+ - value: DE
+ label:
+ en_US: Germany
+ zh_Hans: 德国
+ pt_BR: Germany
+ - value: HK
+ label:
+ en_US: Hong Kong
+ zh_Hans: 香港
+ pt_BR: Hong Kong
+ - value: IN
+ label:
+ en_US: India
+ zh_Hans: 印度
+ pt_BR: India
+ - value: ID
+ label:
+ en_US: Indonesia
+ zh_Hans: 印度尼西亚
+ pt_BR: Indonesia
+ - value: IT
+ label:
+ en_US: Italy
+ zh_Hans: 意大利
+ pt_BR: Italy
+ - value: JP
+ label:
+ en_US: Japan
+ zh_Hans: 日本
+ pt_BR: Japan
+ - value: KR
+ label:
+ en_US: Korea
+ zh_Hans: 韩国
+ pt_BR: Korea
+ - value: MY
+ label:
+ en_US: Malaysia
+ zh_Hans: 马来西亚
+ pt_BR: Malaysia
+ - value: MX
+ label:
+ en_US: Mexico
+ zh_Hans: 墨西哥
+ pt_BR: Mexico
+ - value: NL
+ label:
+ en_US: Netherlands
+ zh_Hans: 荷兰
+ pt_BR: Netherlands
+ - value: NZ
+ label:
+ en_US: New Zealand
+ zh_Hans: 新西兰
+ pt_BR: New Zealand
+ - value: NO
+ label:
+ en_US: Norway
+ zh_Hans: 挪威
+ pt_BR: Norway
+ - value: PH
+ label:
+ en_US: Philippines
+ zh_Hans: 菲律宾
+ pt_BR: Philippines
+ - value: PL
+ label:
+ en_US: Poland
+ zh_Hans: 波兰
+ pt_BR: Poland
+ - value: PT
+ label:
+ en_US: Portugal
+ zh_Hans: 葡萄牙
+ pt_BR: Portugal
+ - value: RU
+ label:
+ en_US: Russia
+ zh_Hans: 俄罗斯
+ pt_BR: Russia
+ - value: SA
+ label:
+ en_US: Saudi Arabia
+ zh_Hans: 沙特阿拉伯
+ pt_BR: Saudi Arabia
+ - value: SG
+ label:
+ en_US: Singapore
+ zh_Hans: 新加坡
+ pt_BR: Singapore
+ - value: ZA
+ label:
+ en_US: South Africa
+ zh_Hans: 南非
+ pt_BR: South Africa
+ - value: ES
+ label:
+ en_US: Spain
+ zh_Hans: 西班牙
+ pt_BR: Spain
+ - value: SE
+ label:
+ en_US: Sweden
+ zh_Hans: 瑞典
+ pt_BR: Sweden
+ - value: CH
+ label:
+ en_US: Switzerland
+ zh_Hans: 瑞士
+ pt_BR: Switzerland
+ - value: TW
+ label:
+ en_US: Taiwan
+ zh_Hans: 台湾
+ pt_BR: Taiwan
+ - value: TH
+ label:
+ en_US: Thailand
+ zh_Hans: 泰国
+ pt_BR: Thailand
+ - value: TR
+ label:
+ en_US: Turkey
+ zh_Hans: 土耳其
+ pt_BR: Turkey
+ - value: GB
+ label:
+ en_US: United Kingdom
+ zh_Hans: 英国
+ pt_BR: United Kingdom
+ - value: US
+ label:
+ en_US: United States
+ zh_Hans: 美国
+ pt_BR: United States
+ - name: hl
+ type: select
+ label:
+ en_US: Language
+ zh_Hans: 语言
+ human_description:
+ en_US: Defines the interface language of the search. Default is "en".
+ zh_Hans: 定义搜索的界面语言。默认为“en”。
+ required: false
+ default: en
+ form: form
+ options:
+ - value: ar
+ label:
+ en_US: Arabic
+ zh_Hans: 阿拉伯语
+ - value: bg
+ label:
+ en_US: Bulgarian
+ zh_Hans: 保加利亚语
+ - value: ca
+ label:
+ en_US: Catalan
+ zh_Hans: 加泰罗尼亚语
+ - value: zh-cn
+ label:
+ en_US: Chinese (Simplified)
+ zh_Hans: 中文(简体)
+ - value: zh-tw
+ label:
+ en_US: Chinese (Traditional)
+ zh_Hans: 中文(繁体)
+ - value: cs
+ label:
+ en_US: Czech
+ zh_Hans: 捷克语
+ - value: da
+ label:
+ en_US: Danish
+ zh_Hans: 丹麦语
+ - value: nl
+ label:
+ en_US: Dutch
+ zh_Hans: 荷兰语
+ - value: en
+ label:
+ en_US: English
+ zh_Hans: 英语
+ - value: et
+ label:
+ en_US: Estonian
+ zh_Hans: 爱沙尼亚语
+ - value: fi
+ label:
+ en_US: Finnish
+ zh_Hans: 芬兰语
+ - value: fr
+ label:
+ en_US: French
+ zh_Hans: 法语
+ - value: de
+ label:
+ en_US: German
+ zh_Hans: 德语
+ - value: el
+ label:
+ en_US: Greek
+ zh_Hans: 希腊语
+ - value: iw
+ label:
+ en_US: Hebrew
+ zh_Hans: 希伯来语
+ - value: hi
+ label:
+ en_US: Hindi
+ zh_Hans: 印地语
+ - value: hu
+ label:
+ en_US: Hungarian
+ zh_Hans: 匈牙利语
+ - value: id
+ label:
+ en_US: Indonesian
+ zh_Hans: 印尼语
+ - value: it
+ label:
+ en_US: Italian
+ zh_Hans: 意大利语
+ - value: jp
+ label:
+ en_US: Japanese
+ zh_Hans: 日语
+ - value: kn
+ label:
+ en_US: Kannada
+ zh_Hans: 卡纳达语
+ - value: ko
+ label:
+ en_US: Korean
+ zh_Hans: 韩语
+ - value: lv
+ label:
+ en_US: Latvian
+ zh_Hans: 拉脱维亚语
+ - value: lt
+ label:
+ en_US: Lithuanian
+ zh_Hans: 立陶宛语
+ - value: my
+ label:
+ en_US: Malay
+ zh_Hans: 马来语
+ - value: ml
+ label:
+ en_US: Malayalam
+ zh_Hans: 马拉雅拉姆语
+ - value: mr
+ label:
+ en_US: Marathi
+ zh_Hans: 马拉地语
+ - value: "no"
+ label:
+ en_US: Norwegian
+ zh_Hans: 挪威语
+ - value: pl
+ label:
+ en_US: Polish
+ zh_Hans: 波兰语
+ - value: pt-br
+ label:
+ en_US: Portuguese (Brazil)
+ zh_Hans: 葡萄牙语(巴西)
+ - value: pt-pt
+ label:
+ en_US: Portuguese (Portugal)
+ zh_Hans: 葡萄牙语(葡萄牙)
+ - value: pa
+ label:
+ en_US: Punjabi
+ zh_Hans: 旁遮普语
+ - value: ro
+ label:
+ en_US: Romanian
+ zh_Hans: 罗马尼亚语
+ - value: ru
+ label:
+ en_US: Russian
+ zh_Hans: 俄语
+ - value: sr
+ label:
+ en_US: Serbian
+ zh_Hans: 塞尔维亚语
+ - value: sk
+ label:
+ en_US: Slovak
+ zh_Hans: 斯洛伐克语
+ - value: sl
+ label:
+ en_US: Slovenian
+ zh_Hans: 斯洛文尼亚语
+ - value: es
+ label:
+ en_US: Spanish
+ zh_Hans: 西班牙语
+ - value: sv
+ label:
+ en_US: Swedish
+ zh_Hans: 瑞典语
+ - value: ta
+ label:
+ en_US: Tamil
+ zh_Hans: 泰米尔语
+ - value: te
+ label:
+ en_US: Telugu
+ zh_Hans: 泰卢固语
+ - value: th
+ label:
+ en_US: Thai
+ zh_Hans: 泰语
+ - value: tr
+ label:
+ en_US: Turkish
+ zh_Hans: 土耳其语
+ - value: uk
+ label:
+ en_US: Ukrainian
+ zh_Hans: 乌克兰语
+ - value: vi
+ label:
+ en_US: Vietnamese
+ zh_Hans: 越南语
+ - name: is_remote
+ type: select
+ label:
+ en_US: is_remote
+ zh_Hans: 是否远程
+ human_description:
+ en_US: Filter results based on the work arrangement. Set it to true to find jobs that offer work from home or remote work opportunities.
+ zh_Hans: 根据工作安排过滤结果。将其设置为 true 可查找提供在家工作或远程工作机会的工作。
+ required: false
+ form: form
+ options:
+ - value: true
+ label:
+ en_US: "true"
+ zh_Hans: "true"
+ - value: false
+ label:
+ en_US: "false"
+ zh_Hans: "false"
diff --git a/api/core/tools/provider/builtin/searchapi/tools/google_news.py b/api/core/tools/provider/builtin/searchapi/tools/google_news.py
new file mode 100644
index 0000000000..d592dc25aa
--- /dev/null
+++ b/api/core/tools/provider/builtin/searchapi/tools/google_news.py
@@ -0,0 +1,92 @@
+from typing import Any, Union
+
+import requests
+
+from core.tools.entities.tool_entities import ToolInvokeMessage
+from core.tools.tool.builtin_tool import BuiltinTool
+
+SEARCH_API_URL = "https://www.searchapi.io/api/v1/search"
+
+class SearchAPI:
+ """
+ SearchAPI tool provider.
+ """
+
+ def __init__(self, api_key: str) -> None:
+ """Initialize SearchAPI tool provider."""
+ self.searchapi_api_key = api_key
+
+ def run(self, query: str, **kwargs: Any) -> str:
+ """Run query through SearchAPI and parse result."""
+ type = kwargs.get("result_type", "text")
+ return self._process_response(self.results(query, **kwargs), type=type)
+
+ def results(self, query: str, **kwargs: Any) -> dict:
+ """Run query through SearchAPI and return the raw result."""
+ params = self.get_params(query, **kwargs)
+ response = requests.get(
+ url=SEARCH_API_URL,
+ params=params,
+ headers={"Authorization": f"Bearer {self.searchapi_api_key}"},
+ )
+ response.raise_for_status()
+ return response.json()
+
+ def get_params(self, query: str, **kwargs: Any) -> dict[str, str]:
+ """Get parameters for SearchAPI."""
+ return {
+ "engine": "google_news",
+ "q": query,
+ **{key: value for key, value in kwargs.items() if value not in [None, ""]},
+ }
+
+ @staticmethod
+ def _process_response(res: dict, type: str) -> str:
+ """Process response from SearchAPI."""
+ if "error" in res.keys():
+ raise ValueError(f"Got error from SearchApi: {res['error']}")
+
+ toret = ""
+ if type == "text":
+ if "organic_results" in res.keys() and "snippet" in res["organic_results"][0].keys():
+ for item in res["organic_results"]:
+ toret += "content: " + item["snippet"] + "\n" + "link: " + item["link"] + "\n"
+ if "top_stories" in res.keys() and "title" in res["top_stories"][0].keys():
+ for item in res["top_stories"]:
+ toret += "title: " + item["title"] + "\n" + "link: " + item["link"] + "\n"
+ if toret == "":
+ toret = "No good search result found"
+
+ elif type == "link":
+ if "organic_results" in res.keys() and "title" in res["organic_results"][0].keys():
+ for item in res["organic_results"]:
+ toret += f"[{item['title']}]({item['link']})\n"
+ elif "top_stories" in res.keys() and "title" in res["top_stories"][0].keys():
+ for item in res["top_stories"]:
+ toret += f"[{item['title']}]({item['link']})\n"
+ else:
+ toret = "No good search result found"
+ return toret
+
+class GoogleNewsTool(BuiltinTool):
+ def _invoke(self,
+ user_id: str,
+ tool_parameters: dict[str, Any],
+ ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
+ """
+ Invoke the SearchApi tool.
+ """
+ query = tool_parameters['query']
+ result_type = tool_parameters['result_type']
+ num = tool_parameters.get("num", 10)
+ google_domain = tool_parameters.get("google_domain", "google.com")
+ gl = tool_parameters.get("gl", "us")
+ hl = tool_parameters.get("hl", "en")
+ location = tool_parameters.get("location", None)
+
+ api_key = self.runtime.credentials['searchapi_api_key']
+ result = SearchAPI(api_key).run(query, result_type=result_type, num=num, google_domain=google_domain, gl=gl, hl=hl, location=location)
+
+ if result_type == 'text':
+ return self.create_text_message(text=result)
+ return self.create_link_message(link=result)
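A minimal sketch (not part of the patch), assuming a valid SearchApi key, of how the `SearchAPI` wrapper above can be exercised outside the Dify tool runtime; the key and query are placeholders:

```python
# Exercise the SearchAPI wrapper from google_news.py directly.
from core.tools.provider.builtin.searchapi.tools.google_news import SearchAPI

client = SearchAPI(api_key="YOUR_SEARCHAPI_KEY")  # placeholder credential
# result_type picks between the "text" and "link" renderings in _process_response
print(client.run("open source LLM tooling", result_type="link", num=5, gl="us", hl="en"))
```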
diff --git a/api/core/tools/provider/builtin/searchapi/tools/google_news.yaml b/api/core/tools/provider/builtin/searchapi/tools/google_news.yaml
new file mode 100644
index 0000000000..b0212952e6
--- /dev/null
+++ b/api/core/tools/provider/builtin/searchapi/tools/google_news.yaml
@@ -0,0 +1,482 @@
+identity:
+ name: google_news_api
+ author: SearchApi
+ label:
+ en_US: Google News API
+ zh_Hans: Google News API
+description:
+ human:
+    en_US: A tool to retrieve organic search result snippets and links from the Google News engine.
+    zh_Hans: 一种从 Google 新闻引擎检索有机搜索结果片段和链接的工具。
+  llm: A tool to retrieve organic search result snippets and links from the Google News engine.
+parameters:
+ - name: query
+ type: string
+ required: true
+ label:
+ en_US: Query
+ zh_Hans: 询问
+ human_description:
+ en_US: Defines the query you want to search.
+ zh_Hans: 定义您要搜索的查询。
+ llm_description: Defines the search query you want to search.
+ form: llm
+ - name: result_type
+ type: select
+ required: true
+ options:
+ - value: text
+ label:
+ en_US: text
+ zh_Hans: 文本
+ - value: link
+ label:
+ en_US: link
+ zh_Hans: 链接
+ default: text
+ label:
+ en_US: Result type
+ zh_Hans: 结果类型
+ human_description:
+      en_US: Used for selecting the result type, text or link.
+ zh_Hans: 用于选择结果类型,使用文本还是链接进行展示。
+ form: form
+ - name: location
+ type: string
+ required: false
+ label:
+ en_US: Location
+      zh_Hans: 位置
+ human_description:
+ en_US: Defines from where you want the search to originate. (For example - New York)
+ zh_Hans: 定义您想要搜索的起始位置。 (例如 - 纽约)
+ llm_description: Defines from where you want the search to originate. (For example - New York)
+ form: llm
+ - name: gl
+ type: select
+ label:
+ en_US: Country
+ zh_Hans: 国家
+ required: false
+ human_description:
+ en_US: Defines the country of the search. Default is "US".
+ zh_Hans: 定义搜索的国家/地区。默认为“美国”。
+ llm_description: Defines the gl parameter of the Google search.
+ form: form
+ default: US
+ options:
+ - value: AR
+ label:
+ en_US: Argentina
+ zh_Hans: 阿根廷
+ pt_BR: Argentina
+ - value: AU
+ label:
+ en_US: Australia
+ zh_Hans: 澳大利亚
+ pt_BR: Australia
+ - value: AT
+ label:
+ en_US: Austria
+ zh_Hans: 奥地利
+ pt_BR: Austria
+ - value: BE
+ label:
+ en_US: Belgium
+ zh_Hans: 比利时
+ pt_BR: Belgium
+ - value: BR
+ label:
+ en_US: Brazil
+ zh_Hans: 巴西
+ pt_BR: Brazil
+ - value: CA
+ label:
+ en_US: Canada
+ zh_Hans: 加拿大
+ pt_BR: Canada
+ - value: CL
+ label:
+ en_US: Chile
+ zh_Hans: 智利
+ pt_BR: Chile
+ - value: CO
+ label:
+ en_US: Colombia
+ zh_Hans: 哥伦比亚
+ pt_BR: Colombia
+ - value: CN
+ label:
+ en_US: China
+ zh_Hans: 中国
+ pt_BR: China
+ - value: CZ
+ label:
+ en_US: Czech Republic
+ zh_Hans: 捷克共和国
+ pt_BR: Czech Republic
+ - value: DK
+ label:
+ en_US: Denmark
+ zh_Hans: 丹麦
+ pt_BR: Denmark
+ - value: FI
+ label:
+ en_US: Finland
+ zh_Hans: 芬兰
+ pt_BR: Finland
+ - value: FR
+ label:
+ en_US: France
+ zh_Hans: 法国
+ pt_BR: France
+ - value: DE
+ label:
+ en_US: Germany
+ zh_Hans: 德国
+ pt_BR: Germany
+ - value: HK
+ label:
+ en_US: Hong Kong
+ zh_Hans: 香港
+ pt_BR: Hong Kong
+ - value: IN
+ label:
+ en_US: India
+ zh_Hans: 印度
+ pt_BR: India
+ - value: ID
+ label:
+ en_US: Indonesia
+ zh_Hans: 印度尼西亚
+ pt_BR: Indonesia
+ - value: IT
+ label:
+ en_US: Italy
+ zh_Hans: 意大利
+ pt_BR: Italy
+ - value: JP
+ label:
+ en_US: Japan
+ zh_Hans: 日本
+ pt_BR: Japan
+ - value: KR
+ label:
+ en_US: Korea
+ zh_Hans: 韩国
+ pt_BR: Korea
+ - value: MY
+ label:
+ en_US: Malaysia
+ zh_Hans: 马来西亚
+ pt_BR: Malaysia
+ - value: MX
+ label:
+ en_US: Mexico
+ zh_Hans: 墨西哥
+ pt_BR: Mexico
+ - value: NL
+ label:
+ en_US: Netherlands
+ zh_Hans: 荷兰
+ pt_BR: Netherlands
+ - value: NZ
+ label:
+ en_US: New Zealand
+ zh_Hans: 新西兰
+ pt_BR: New Zealand
+ - value: NO
+ label:
+ en_US: Norway
+ zh_Hans: 挪威
+ pt_BR: Norway
+ - value: PH
+ label:
+ en_US: Philippines
+ zh_Hans: 菲律宾
+ pt_BR: Philippines
+ - value: PL
+ label:
+ en_US: Poland
+ zh_Hans: 波兰
+ pt_BR: Poland
+ - value: PT
+ label:
+ en_US: Portugal
+ zh_Hans: 葡萄牙
+ pt_BR: Portugal
+ - value: RU
+ label:
+ en_US: Russia
+ zh_Hans: 俄罗斯
+ pt_BR: Russia
+ - value: SA
+ label:
+ en_US: Saudi Arabia
+ zh_Hans: 沙特阿拉伯
+ pt_BR: Saudi Arabia
+ - value: SG
+ label:
+ en_US: Singapore
+ zh_Hans: 新加坡
+ pt_BR: Singapore
+ - value: ZA
+ label:
+ en_US: South Africa
+ zh_Hans: 南非
+ pt_BR: South Africa
+ - value: ES
+ label:
+ en_US: Spain
+ zh_Hans: 西班牙
+ pt_BR: Spain
+ - value: SE
+ label:
+ en_US: Sweden
+ zh_Hans: 瑞典
+ pt_BR: Sweden
+ - value: CH
+ label:
+ en_US: Switzerland
+ zh_Hans: 瑞士
+ pt_BR: Switzerland
+ - value: TW
+ label:
+ en_US: Taiwan
+ zh_Hans: 台湾
+ pt_BR: Taiwan
+ - value: TH
+ label:
+ en_US: Thailand
+ zh_Hans: 泰国
+ pt_BR: Thailand
+ - value: TR
+ label:
+ en_US: Turkey
+ zh_Hans: 土耳其
+ pt_BR: Turkey
+ - value: GB
+ label:
+ en_US: United Kingdom
+ zh_Hans: 英国
+ pt_BR: United Kingdom
+ - value: US
+ label:
+ en_US: United States
+ zh_Hans: 美国
+ pt_BR: United States
+ - name: hl
+ type: select
+ label:
+ en_US: Language
+ zh_Hans: 语言
+ human_description:
+ en_US: Defines the interface language of the search. Default is "en".
+ zh_Hans: 定义搜索的界面语言。默认为“en”。
+ required: false
+ default: en
+ form: form
+ options:
+ - value: ar
+ label:
+ en_US: Arabic
+ zh_Hans: 阿拉伯语
+ - value: bg
+ label:
+ en_US: Bulgarian
+ zh_Hans: 保加利亚语
+ - value: ca
+ label:
+ en_US: Catalan
+ zh_Hans: 加泰罗尼亚语
+ - value: zh-cn
+ label:
+ en_US: Chinese (Simplified)
+ zh_Hans: 中文(简体)
+ - value: zh-tw
+ label:
+ en_US: Chinese (Traditional)
+ zh_Hans: 中文(繁体)
+ - value: cs
+ label:
+ en_US: Czech
+ zh_Hans: 捷克语
+ - value: da
+ label:
+ en_US: Danish
+ zh_Hans: 丹麦语
+ - value: nl
+ label:
+ en_US: Dutch
+ zh_Hans: 荷兰语
+ - value: en
+ label:
+ en_US: English
+ zh_Hans: 英语
+ - value: et
+ label:
+ en_US: Estonian
+ zh_Hans: 爱沙尼亚语
+ - value: fi
+ label:
+ en_US: Finnish
+ zh_Hans: 芬兰语
+ - value: fr
+ label:
+ en_US: French
+ zh_Hans: 法语
+ - value: de
+ label:
+ en_US: German
+ zh_Hans: 德语
+ - value: el
+ label:
+ en_US: Greek
+ zh_Hans: 希腊语
+ - value: iw
+ label:
+ en_US: Hebrew
+ zh_Hans: 希伯来语
+ - value: hi
+ label:
+ en_US: Hindi
+ zh_Hans: 印地语
+ - value: hu
+ label:
+ en_US: Hungarian
+ zh_Hans: 匈牙利语
+ - value: id
+ label:
+ en_US: Indonesian
+ zh_Hans: 印尼语
+ - value: it
+ label:
+ en_US: Italian
+ zh_Hans: 意大利语
+ - value: jp
+ label:
+ en_US: Japanese
+ zh_Hans: 日语
+ - value: kn
+ label:
+ en_US: Kannada
+ zh_Hans: 卡纳达语
+ - value: ko
+ label:
+ en_US: Korean
+ zh_Hans: 韩语
+ - value: lv
+ label:
+ en_US: Latvian
+ zh_Hans: 拉脱维亚语
+ - value: lt
+ label:
+ en_US: Lithuanian
+ zh_Hans: 立陶宛语
+ - value: my
+ label:
+ en_US: Malay
+ zh_Hans: 马来语
+ - value: ml
+ label:
+ en_US: Malayalam
+ zh_Hans: 马拉雅拉姆语
+ - value: mr
+ label:
+ en_US: Marathi
+ zh_Hans: 马拉地语
+ - value: "no"
+ label:
+ en_US: Norwegian
+ zh_Hans: 挪威语
+ - value: pl
+ label:
+ en_US: Polish
+ zh_Hans: 波兰语
+ - value: pt-br
+ label:
+ en_US: Portuguese (Brazil)
+ zh_Hans: 葡萄牙语(巴西)
+ - value: pt-pt
+ label:
+ en_US: Portuguese (Portugal)
+ zh_Hans: 葡萄牙语(葡萄牙)
+ - value: pa
+ label:
+ en_US: Punjabi
+ zh_Hans: 旁遮普语
+ - value: ro
+ label:
+ en_US: Romanian
+ zh_Hans: 罗马尼亚语
+ - value: ru
+ label:
+ en_US: Russian
+ zh_Hans: 俄语
+ - value: sr
+ label:
+ en_US: Serbian
+ zh_Hans: 塞尔维亚语
+ - value: sk
+ label:
+ en_US: Slovak
+ zh_Hans: 斯洛伐克语
+ - value: sl
+ label:
+ en_US: Slovenian
+ zh_Hans: 斯洛文尼亚语
+ - value: es
+ label:
+ en_US: Spanish
+ zh_Hans: 西班牙语
+ - value: sv
+ label:
+ en_US: Swedish
+ zh_Hans: 瑞典语
+ - value: ta
+ label:
+ en_US: Tamil
+ zh_Hans: 泰米尔语
+ - value: te
+ label:
+ en_US: Telugu
+ zh_Hans: 泰卢固语
+ - value: th
+ label:
+ en_US: Thai
+ zh_Hans: 泰语
+ - value: tr
+ label:
+ en_US: Turkish
+ zh_Hans: 土耳其语
+ - value: uk
+ label:
+ en_US: Ukrainian
+ zh_Hans: 乌克兰语
+ - value: vi
+ label:
+ en_US: Vietnamese
+ zh_Hans: 越南语
+ - name: google_domain
+ type: string
+ required: false
+ label:
+ en_US: google_domain
+ zh_Hans: google_domain
+ human_description:
+ en_US: Defines the Google domain of the search. Default is "google.com".
+ zh_Hans: 定义搜索的 Google 域。默认为“google.com”。
+ llm_description: Defines Google domain in which you want to search.
+ form: llm
+ - name: num
+ type: number
+ required: false
+ label:
+ en_US: num
+ zh_Hans: num
+ human_description:
+      en_US: Specifies the number of results to display per page. Default is 10; maximum is 100, minimum is 1.
+      zh_Hans: 指定每页显示的结果数。默认值为 10,最大 100,最小 1。
+      pt_BR: Specifies the number of results to display per page. Default is 10. Max number - 100, min - 1.
+    llm_description: Specifies the number of results to display per page.
+ form: llm
diff --git a/api/core/tools/provider/builtin/searchapi/tools/youtube_transcripts.py b/api/core/tools/provider/builtin/searchapi/tools/youtube_transcripts.py
new file mode 100644
index 0000000000..6345b33801
--- /dev/null
+++ b/api/core/tools/provider/builtin/searchapi/tools/youtube_transcripts.py
@@ -0,0 +1,72 @@
+from typing import Any, Union
+
+import requests
+
+from core.tools.entities.tool_entities import ToolInvokeMessage
+from core.tools.tool.builtin_tool import BuiltinTool
+
+SEARCH_API_URL = "https://www.searchapi.io/api/v1/search"
+
+class SearchAPI:
+ """
+ SearchAPI tool provider.
+ """
+
+ def __init__(self, api_key: str) -> None:
+ """Initialize SearchAPI tool provider."""
+ self.searchapi_api_key = api_key
+
+ def run(self, video_id: str, language: str, **kwargs: Any) -> str:
+ """Run video_id through SearchAPI and parse result."""
+ return self._process_response(self.results(video_id, language, **kwargs))
+
+ def results(self, video_id: str, language: str, **kwargs: Any) -> dict:
+ """Run video_id through SearchAPI and return the raw result."""
+ params = self.get_params(video_id, language, **kwargs)
+ response = requests.get(
+ url=SEARCH_API_URL,
+ params=params,
+ headers={"Authorization": f"Bearer {self.searchapi_api_key}"},
+ )
+ response.raise_for_status()
+ return response.json()
+
+ def get_params(self, video_id: str, language: str, **kwargs: Any) -> dict[str, str]:
+ """Get parameters for SearchAPI."""
+ return {
+ "engine": "youtube_transcripts",
+ "video_id": video_id,
+ "lang": language if language else "en",
+ **{key: value for key, value in kwargs.items() if value not in [None, ""]},
+ }
+
+ @staticmethod
+ def _process_response(res: dict) -> str:
+ """Process response from SearchAPI."""
+ if "error" in res.keys():
+ raise ValueError(f"Got error from SearchApi: {res['error']}")
+
+ toret = ""
+ if "transcripts" in res.keys() and "text" in res["transcripts"][0].keys():
+ for item in res["transcripts"]:
+ toret += item["text"] + " "
+ if toret == "":
+ toret = "No good search result found"
+
+ return toret
+
+class YoutubeTranscriptsTool(BuiltinTool):
+ def _invoke(self,
+ user_id: str,
+ tool_parameters: dict[str, Any],
+ ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
+ """
+ Invoke the SearchApi tool.
+ """
+ video_id = tool_parameters['video_id']
+ language = tool_parameters.get('language', "en")
+
+ api_key = self.runtime.credentials['searchapi_api_key']
+ result = SearchAPI(api_key).run(video_id, language=language)
+
+ return self.create_text_message(text=result)
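A similar hedged sketch for the transcripts wrapper above; the video ID and key are placeholders:

```python
# Fetch an English transcript through the youtube_transcripts SearchAPI wrapper.
from core.tools.provider.builtin.searchapi.tools.youtube_transcripts import SearchAPI

client = SearchAPI(api_key="YOUR_SEARCHAPI_KEY")  # placeholder credential
transcript = client.run(video_id="VIDEO_ID", language="en")  # "en" is the default
print(transcript[:200])
```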
diff --git a/api/core/tools/provider/builtin/searchapi/tools/youtube_transcripts.yaml b/api/core/tools/provider/builtin/searchapi/tools/youtube_transcripts.yaml
new file mode 100644
index 0000000000..8bdcd6bb93
--- /dev/null
+++ b/api/core/tools/provider/builtin/searchapi/tools/youtube_transcripts.yaml
@@ -0,0 +1,34 @@
+identity:
+ name: youtube_transcripts_api
+ author: SearchApi
+ label:
+ en_US: YouTube Transcripts API
+    zh_Hans: YouTube 文字记录 API
+description:
+ human:
+    en_US: A tool to retrieve transcripts from a specific YouTube video.
+    zh_Hans: 一种从特定 YouTube 视频检索文字记录的工具。
+  llm: A tool to retrieve transcripts from a specific YouTube video.
+parameters:
+ - name: video_id
+ type: string
+ required: true
+ label:
+ en_US: video_id
+ zh_Hans: 视频ID
+ human_description:
+      en_US: Used to define the video you want to search. You can find the video ID in the YouTube page URL. For example - https://www.youtube.com/watch?v=video_id.
+ zh_Hans: 用于定义要搜索的视频。您可以在 URL 中显示的 YouTube 页面中找到视频 ID。例如 - https://www.youtube.com/watch?v=video_id。
+ llm_description: Used to define the video you want to search.
+ form: llm
+ - name: language
+ type: string
+ required: false
+ label:
+ en_US: language
+ zh_Hans: 语言
+ human_description:
+ en_US: Used to set the language for transcripts. The default value is "en". You can find all supported languages in SearchApi documentation.
+      zh_Hans: 用于设置文字记录的语言。默认值为“en”。您可以在 SearchApi 文档中找到所有支持的语言。
+ llm_description: Used to set the language for transcripts.
+ form: llm
diff --git a/api/core/tools/provider/builtin/searxng/searxng.py b/api/core/tools/provider/builtin/searxng/searxng.py
index 8046056093..24b94b5ca4 100644
--- a/api/core/tools/provider/builtin/searxng/searxng.py
+++ b/api/core/tools/provider/builtin/searxng/searxng.py
@@ -9,7 +9,7 @@ class SearXNGProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
try:
SearXNGSearchTool().fork_tool_runtime(
- meta={
+ runtime={
"credentials": credentials,
}
).invoke(
@@ -22,4 +22,4 @@ class SearXNGProvider(BuiltinToolProviderController):
},
)
except Exception as e:
- raise ToolProviderCredentialValidationError(str(e))
\ No newline at end of file
+ raise ToolProviderCredentialValidationError(str(e))
diff --git a/api/core/tools/provider/builtin/searxng/searxng.yaml b/api/core/tools/provider/builtin/searxng/searxng.yaml
index c8c713cf04..64bd428280 100644
--- a/api/core/tools/provider/builtin/searxng/searxng.yaml
+++ b/api/core/tools/provider/builtin/searxng/searxng.yaml
@@ -8,6 +8,9 @@ identity:
en_US: A free internet metasearch engine.
zh_Hans: 开源互联网元搜索引擎
icon: icon.svg
+ tags:
+ - search
+ - productivity
credentials_for_provider:
searxng_base_url:
type: secret-input
diff --git a/api/core/tools/provider/builtin/searxng/tools/searxng_search.py b/api/core/tools/provider/builtin/searxng/tools/searxng_search.py
index cbc5ab435a..50f04760a7 100644
--- a/api/core/tools/provider/builtin/searxng/tools/searxng_search.py
+++ b/api/core/tools/provider/builtin/searxng/tools/searxng_search.py
@@ -121,4 +121,5 @@ class SearXNGSearchTool(BuiltinTool):
query=query,
search_type=search_type,
result_type=result_type,
- topK=num_results)
\ No newline at end of file
+ topK=num_results
+ )
diff --git a/api/core/tools/provider/builtin/slack/slack.yaml b/api/core/tools/provider/builtin/slack/slack.yaml
index 7278793900..1070ffbf03 100644
--- a/api/core/tools/provider/builtin/slack/slack.yaml
+++ b/api/core/tools/provider/builtin/slack/slack.yaml
@@ -10,4 +10,7 @@ identity:
zh_Hans: Slack Webhook
pt_BR: Slack Webhook
icon: icon.svg
+ tags:
+ - social
+ - productivity
credentials_for_provider:
diff --git a/api/core/tools/provider/builtin/spark/spark.yaml b/api/core/tools/provider/builtin/spark/spark.yaml
index f2b9c89e96..fa1543443a 100644
--- a/api/core/tools/provider/builtin/spark/spark.yaml
+++ b/api/core/tools/provider/builtin/spark/spark.yaml
@@ -10,6 +10,8 @@ identity:
zh_Hans: 讯飞星火平台工具
pt_BR: Pacote de Ferramentas da Plataforma Spark
icon: icon.svg
+ tags:
+ - image
credentials_for_provider:
APPID:
type: secret-input
diff --git a/api/core/tools/provider/builtin/stability/stability.py b/api/core/tools/provider/builtin/stability/stability.py
index d00c3ecf00..b31d786178 100644
--- a/api/core/tools/provider/builtin/stability/stability.py
+++ b/api/core/tools/provider/builtin/stability/stability.py
@@ -12,4 +12,4 @@ class StabilityToolProvider(BuiltinToolProviderController, BaseStabilityAuthoriz
"""
This method is responsible for validating the credentials.
"""
- self.sd_validate_credentials(credentials)
\ No newline at end of file
+ self.sd_validate_credentials(credentials)
diff --git a/api/core/tools/provider/builtin/stability/stability.yaml b/api/core/tools/provider/builtin/stability/stability.yaml
index d8369a4c03..c3e01c1e31 100644
--- a/api/core/tools/provider/builtin/stability/stability.yaml
+++ b/api/core/tools/provider/builtin/stability/stability.yaml
@@ -10,6 +10,8 @@ identity:
zh_Hans: 通过生成式 AI 激活人类的潜力
pt_BR: Activating humanity's potential through generative AI
icon: icon.svg
+ tags:
+ - image
credentials_for_provider:
api_key:
type: secret-input
diff --git a/api/core/tools/provider/builtin/stablediffusion/stablediffusion.py b/api/core/tools/provider/builtin/stablediffusion/stablediffusion.py
index 5748e8d4e2..317d705f7c 100644
--- a/api/core/tools/provider/builtin/stablediffusion/stablediffusion.py
+++ b/api/core/tools/provider/builtin/stablediffusion/stablediffusion.py
@@ -9,9 +9,10 @@ class StableDiffusionProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
try:
StableDiffusionTool().fork_tool_runtime(
- meta={
+ runtime={
"credentials": credentials,
}
).validate_models()
except Exception as e:
- raise ToolProviderCredentialValidationError(str(e))
\ No newline at end of file
+ raise ToolProviderCredentialValidationError(str(e))
+
\ No newline at end of file
diff --git a/api/core/tools/provider/builtin/stablediffusion/stablediffusion.yaml b/api/core/tools/provider/builtin/stablediffusion/stablediffusion.yaml
index e1161da5bb..9b3c804f72 100644
--- a/api/core/tools/provider/builtin/stablediffusion/stablediffusion.yaml
+++ b/api/core/tools/provider/builtin/stablediffusion/stablediffusion.yaml
@@ -10,6 +10,8 @@ identity:
zh_Hans: Stable Diffusion 是一个可以在本地部署的图片生成的工具。
pt_BR: Stable Diffusion is a tool for generating images which can be deployed locally.
icon: icon.png
+ tags:
+ - image
credentials_for_provider:
base_url:
type: secret-input
diff --git a/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py b/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py
index 25206da0bc..0c5ebc23ac 100644
--- a/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py
+++ b/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py
@@ -13,49 +13,76 @@ from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter,
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.tool.builtin_tool import BuiltinTool
+# All commented out parameters default to null
DRAW_TEXT_OPTIONS = {
+ # Prompts
"prompt": "",
"negative_prompt": "",
+ # "styles": [],
+ # Seeds
"seed": -1,
"subseed": -1,
"subseed_strength": 0,
"seed_resize_from_h": -1,
- 'sampler_index': 'DPM++ SDE Karras',
"seed_resize_from_w": -1,
+
+ # Samplers
+ # "sampler_name": "DPM++ 2M",
+ # "scheduler": "",
+ # "sampler_index": "Automatic",
+
+ # Latent Space Options
"batch_size": 1,
"n_iter": 1,
"steps": 10,
"cfg_scale": 7,
- "width": 1024,
- "height": 1024,
- "restore_faces": False,
+ "width": 512,
+ "height": 512,
+ # "restore_faces": True,
+ # "tiling": True,
"do_not_save_samples": False,
"do_not_save_grid": False,
- "eta": 0,
- "denoising_strength": 0,
- "s_min_uncond": 0,
- "s_churn": 0,
- "s_tmax": 0,
- "s_tmin": 0,
- "s_noise": 0,
+ # "eta": 0,
+ # "denoising_strength": 0.75,
+ # "s_min_uncond": 0,
+ # "s_churn": 0,
+ # "s_tmax": 0,
+ # "s_tmin": 0,
+ # "s_noise": 0,
"override_settings": {},
"override_settings_restore_afterwards": True,
+ # Refinement Options
+ "refiner_checkpoint": "",
"refiner_switch_at": 0,
"disable_extra_networks": False,
- "comments": {},
+ # "firstpass_image": "",
+ # "comments": "",
+ # High-Resolution Options
"enable_hr": False,
"firstphase_width": 0,
"firstphase_height": 0,
"hr_scale": 2,
+ # "hr_upscaler": "",
"hr_second_pass_steps": 0,
"hr_resize_x": 0,
"hr_resize_y": 0,
+ # "hr_checkpoint_name": "",
+ # "hr_sampler_name": "",
+ # "hr_scheduler": "",
"hr_prompt": "",
"hr_negative_prompt": "",
+ # Task Options
+ # "force_task_id": "",
+
+ # Script Options
+ # "script_name": "",
"script_args": [],
+ # Output Options
"send_images": True,
"save_images": False,
- "alwayson_scripts": {}
+ "alwayson_scripts": {},
+ # "infotext": "",
+
}
@@ -70,7 +97,7 @@ class StableDiffusionTool(BuiltinTool):
if not base_url:
return self.create_text_message('Please input base_url')
- if 'model' in tool_parameters and tool_parameters['model']:
+ if tool_parameters.get('model'):
self.runtime.credentials['model'] = tool_parameters['model']
model = self.runtime.credentials.get('model', None)
@@ -88,60 +115,15 @@ class StableDiffusionTool(BuiltinTool):
except Exception as e:
raise ToolProviderCredentialValidationError('Failed to set model, please tell user to set model')
-
- # prompt
- prompt = tool_parameters.get('prompt', '')
- if not prompt:
- return self.create_text_message('Please input prompt')
-
- # get negative prompt
- negative_prompt = tool_parameters.get('negative_prompt', '')
-
- # get size
- width = tool_parameters.get('width', 1024)
- height = tool_parameters.get('height', 1024)
-
- # get steps
- steps = tool_parameters.get('steps', 1)
-
- # get lora
- lora = tool_parameters.get('lora', '')
-
- # get image id
+ # get image id and image variable
image_id = tool_parameters.get('image_id', '')
- if image_id.strip():
- image_variable = self.get_default_image_variable()
- if image_variable:
- image_binary = self.get_variable_file(image_variable.name)
- if not image_binary:
- return self.create_text_message('Image not found, please request user to generate image firstly.')
-
- # convert image to RGB
- image = Image.open(io.BytesIO(image_binary))
- image = image.convert("RGB")
- buffer = io.BytesIO()
- image.save(buffer, format="PNG")
- image_binary = buffer.getvalue()
- image.close()
+ image_variable = self.get_default_image_variable()
+ # Return text2img if there's no image ID or no image variable
+ if not image_id or not image_variable:
+            return self.text2img(base_url=base_url, tool_parameters=tool_parameters)
- return self.img2img(base_url=base_url,
- lora=lora,
- image_binary=image_binary,
- prompt=prompt,
- negative_prompt=negative_prompt,
- width=width,
- height=height,
- steps=steps,
- model=model)
-
- return self.text2img(base_url=base_url,
- lora=lora,
- prompt=prompt,
- negative_prompt=negative_prompt,
- width=width,
- height=height,
- steps=steps,
- model=model)
+ # Proceed with image-to-image generation
+        return self.img2img(base_url=base_url, tool_parameters=tool_parameters)
def validate_models(self) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
"""
@@ -197,35 +179,67 @@ class StableDiffusionTool(BuiltinTool):
except Exception as e:
return []
- def img2img(self, base_url: str, lora: str, image_binary: bytes,
- prompt: str, negative_prompt: str,
- width: int, height: int, steps: int, model: str) \
+ def img2img(self, base_url: str, tool_parameters: dict[str, Any]) \
-> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
"""
generate image
"""
- draw_options = {
- "init_images": [b64encode(image_binary).decode('utf-8')],
- "prompt": "",
- "negative_prompt": negative_prompt,
- "denoising_strength": 0.9,
- "width": width,
- "height": height,
- "cfg_scale": 7,
- "sampler_name": "Euler a",
- "restore_faces": False,
- "steps": steps,
- "script_args": ["outpainting mk2"],
- "override_settings": {"sd_model_checkpoint": model}
- }
+ # Fetch the binary data of the image
+ image_variable = self.get_default_image_variable()
+ image_binary = self.get_variable_file(image_variable.name)
+ if not image_binary:
+            return self.create_text_message('Image not found, please ask the user to generate an image first.')
+
+ # Convert image to RGB and save as PNG
+ try:
+ with Image.open(io.BytesIO(image_binary)) as image:
+ with io.BytesIO() as buffer:
+ image.convert("RGB").save(buffer, format="PNG")
+ image_binary = buffer.getvalue()
+ except Exception as e:
+ return self.create_text_message(f"Failed to process the image: {str(e)}")
+
+ # copy draw options
+ draw_options = deepcopy(DRAW_TEXT_OPTIONS)
+ # set image options
+ model = tool_parameters.get('model', '')
+ draw_options_image = {
+ "init_images": [b64encode(image_binary).decode('utf-8')],
+ "denoising_strength": 0.9,
+ "restore_faces": False,
+ "script_args": [],
+ "override_settings": {"sd_model_checkpoint": model},
+ "resize_mode":0,
+ "image_cfg_scale": 0,
+ # "mask": None,
+ "mask_blur_x": 4,
+ "mask_blur_y": 4,
+ "mask_blur": 0,
+ "mask_round": True,
+ "inpainting_fill": 0,
+ "inpaint_full_res": True,
+ "inpaint_full_res_padding": 0,
+ "inpainting_mask_invert": 0,
+ "initial_noise_multiplier": 0,
+ # "latent_mask": None,
+ "include_init_images": True,
+ }
+ # update key and values
+ draw_options.update(draw_options_image)
+ draw_options.update(tool_parameters)
+
+ # get prompt lora model
+ prompt = tool_parameters.get('prompt', '')
+ lora = tool_parameters.get('lora', '')
+ model = tool_parameters.get('model', '')
if lora:
draw_options['prompt'] = f'{lora},{prompt}'
else:
draw_options['prompt'] = prompt
try:
- url = str(URL(base_url) / 'sdapi' / 'v1' / 'img2img')
+ url = str(URL(base_url) / 'sdapi' / 'v1' / 'img2img')
response = post(url, data=json.dumps(draw_options), timeout=120)
if response.status_code != 200:
return self.create_text_message('Failed to generate image')
@@ -239,24 +253,24 @@ class StableDiffusionTool(BuiltinTool):
except Exception as e:
return self.create_text_message('Failed to generate image')
- def text2img(self, base_url: str, lora: str, prompt: str, negative_prompt: str, width: int, height: int, steps: int, model: str) \
+ def text2img(self, base_url: str, tool_parameters: dict[str, Any]) \
-> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
"""
generate image
"""
# copy draw options
draw_options = deepcopy(DRAW_TEXT_OPTIONS)
-
+ draw_options.update(tool_parameters)
+ # get prompt lora model
+ prompt = tool_parameters.get('prompt', '')
+ lora = tool_parameters.get('lora', '')
+ model = tool_parameters.get('model', '')
if lora:
draw_options['prompt'] = f'{lora},{prompt}'
else:
draw_options['prompt'] = prompt
-
- draw_options['width'] = width
- draw_options['height'] = height
- draw_options['steps'] = steps
- draw_options['negative_prompt'] = negative_prompt
draw_options['override_settings']['sd_model_checkpoint'] = model
+
try:
url = str(URL(base_url) / 'sdapi' / 'v1' / 'txt2img')
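The refactor above replaces the long explicit argument lists with a single merge of `DRAW_TEXT_OPTIONS` and the caller's `tool_parameters`. A standalone sketch of that pattern (names here are illustrative):

```python
# Option-merging pattern used by text2img/img2img: deep-copy the defaults,
# then overlay per-call values so the shared dict (and its nested
# override_settings) is never mutated between invocations.
from copy import deepcopy

DEFAULTS = {"steps": 10, "width": 512, "height": 512, "override_settings": {}}

def build_options(tool_parameters: dict) -> dict:
    opts = deepcopy(DEFAULTS)      # a plain dict.copy() would share override_settings
    opts.update(tool_parameters)   # caller-supplied values win
    return opts

print(build_options({"steps": 30, "width": 768}))
```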
diff --git a/api/core/tools/provider/builtin/stackexchange/stackexchange.py b/api/core/tools/provider/builtin/stackexchange/stackexchange.py
index fab543c580..de64c84997 100644
--- a/api/core/tools/provider/builtin/stackexchange/stackexchange.py
+++ b/api/core/tools/provider/builtin/stackexchange/stackexchange.py
@@ -7,7 +7,7 @@ class StackExchangeProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict) -> None:
try:
SearchStackExQuestionsTool().fork_tool_runtime(
- meta={
+ runtime={
"credentials": credentials,
}
).invoke(
@@ -22,4 +22,5 @@ class StackExchangeProvider(BuiltinToolProviderController):
},
)
except Exception as e:
- raise ToolProviderCredentialValidationError(str(e))
\ No newline at end of file
+ raise ToolProviderCredentialValidationError(str(e))
+
\ No newline at end of file
diff --git a/api/core/tools/provider/builtin/stackexchange/stackexchange.yaml b/api/core/tools/provider/builtin/stackexchange/stackexchange.yaml
index b2c8fe296f..ad64b9595a 100644
--- a/api/core/tools/provider/builtin/stackexchange/stackexchange.yaml
+++ b/api/core/tools/provider/builtin/stackexchange/stackexchange.yaml
@@ -8,3 +8,6 @@ identity:
en_US: Access questions and answers from the Stack Exchange and its sub-sites.
zh_Hans: 从Stack Exchange和其子论坛获取问题和答案。
icon: icon.svg
+ tags:
+ - search
+ - utilities
diff --git a/api/core/tools/provider/builtin/tavily/tavily.py b/api/core/tools/provider/builtin/tavily/tavily.py
index 575d9268b9..e376d99d6b 100644
--- a/api/core/tools/provider/builtin/tavily/tavily.py
+++ b/api/core/tools/provider/builtin/tavily/tavily.py
@@ -9,7 +9,7 @@ class TavilyProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
try:
TavilySearchTool().fork_tool_runtime(
- meta={
+ runtime={
"credentials": credentials,
}
).invoke(
@@ -26,4 +26,5 @@ class TavilyProvider(BuiltinToolProviderController):
},
)
except Exception as e:
- raise ToolProviderCredentialValidationError(str(e))
\ No newline at end of file
+ raise ToolProviderCredentialValidationError(str(e))
+
\ No newline at end of file
diff --git a/api/core/tools/provider/builtin/tavily/tavily.yaml b/api/core/tools/provider/builtin/tavily/tavily.yaml
index 50826e37b3..7b25a81848 100644
--- a/api/core/tools/provider/builtin/tavily/tavily.yaml
+++ b/api/core/tools/provider/builtin/tavily/tavily.yaml
@@ -10,6 +10,8 @@ identity:
zh_Hans: Tavily
pt_BR: Tavily
icon: icon.png
+ tags:
+ - search
credentials_for_provider:
tavily_api_key:
type: secret-input
diff --git a/api/core/tools/provider/builtin/time/time.py b/api/core/tools/provider/builtin/time/time.py
index 0d3285f495..833ae194ef 100644
--- a/api/core/tools/provider/builtin/time/time.py
+++ b/api/core/tools/provider/builtin/time/time.py
@@ -13,4 +13,5 @@ class WikiPediaProvider(BuiltinToolProviderController):
tool_parameters={},
)
except Exception as e:
- raise ToolProviderCredentialValidationError(str(e))
\ No newline at end of file
+ raise ToolProviderCredentialValidationError(str(e))
+
\ No newline at end of file
diff --git a/api/core/tools/provider/builtin/time/time.yaml b/api/core/tools/provider/builtin/time/time.yaml
index 8ea67cded5..1278939df5 100644
--- a/api/core/tools/provider/builtin/time/time.yaml
+++ b/api/core/tools/provider/builtin/time/time.yaml
@@ -10,4 +10,6 @@ identity:
zh_Hans: 一个用于获取当前时间的工具。
pt_BR: A tool for getting the current time.
icon: icon.svg
+ tags:
+ - utilities
credentials_for_provider:
diff --git a/api/core/tools/provider/builtin/time/tools/current_time.py b/api/core/tools/provider/builtin/time/tools/current_time.py
index 8722274565..90c01665e6 100644
--- a/api/core/tools/provider/builtin/time/tools/current_time.py
+++ b/api/core/tools/provider/builtin/time/tools/current_time.py
@@ -17,11 +17,12 @@ class CurrentTimeTool(BuiltinTool):
"""
# get timezone
tz = tool_parameters.get('timezone', 'UTC')
+ fm = tool_parameters.get('format') or '%Y-%m-%d %H:%M:%S %Z'
if tz == 'UTC':
- return self.create_text_message(f'{datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S %Z")}')
-
+ return self.create_text_message(f'{datetime.now(timezone.utc).strftime(fm)}')
+
try:
tz = pytz_timezone(tz)
except:
return self.create_text_message(f'Invalid timezone: {tz}')
- return self.create_text_message(f'{datetime.now(tz).strftime("%Y-%m-%d %H:%M:%S %Z")}')
\ No newline at end of file
+ return self.create_text_message(f'{datetime.now(tz).strftime(fm)}')
\ No newline at end of file
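The new optional `format` parameter accepts any strftime pattern and falls back to `%Y-%m-%d %H:%M:%S %Z` when unset; a quick standalone illustration:

```python
# Any strftime pattern works as the new `format` value.
from datetime import datetime, timezone

fm = "%d/%m/%Y %H:%M"  # example user-supplied format
print(datetime.now(timezone.utc).strftime(fm))
```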
diff --git a/api/core/tools/provider/builtin/time/tools/current_time.yaml b/api/core/tools/provider/builtin/time/tools/current_time.yaml
index f0b5f53bd8..dbdd39e223 100644
--- a/api/core/tools/provider/builtin/time/tools/current_time.yaml
+++ b/api/core/tools/provider/builtin/time/tools/current_time.yaml
@@ -12,6 +12,19 @@ description:
pt_BR: A tool for getting the current time.
llm: A tool for getting the current time.
parameters:
+ - name: format
+ type: string
+ required: false
+ label:
+ en_US: Format
+ zh_Hans: 格式
+ pt_BR: Format
+ human_description:
+ en_US: Time format in strftime standard.
+ zh_Hans: strftime 标准的时间格式。
+ pt_BR: Time format in strftime standard.
+ form: form
+ default: "%Y-%m-%d %H:%M:%S"
- name: timezone
type: select
required: false
@@ -46,6 +59,11 @@ parameters:
en_US: America/Chicago
zh_Hans: 美洲/芝加哥
pt_BR: America/Chicago
+ - value: America/Sao_Paulo
+ label:
+ en_US: America/Sao_Paulo
+ zh_Hans: 美洲/圣保罗
+ pt_BR: América/São Paulo
- value: Asia/Shanghai
label:
en_US: Asia/Shanghai
diff --git a/api/core/tools/provider/builtin/trello/trello.py b/api/core/tools/provider/builtin/trello/trello.py
index d27115d246..84ecd20803 100644
--- a/api/core/tools/provider/builtin/trello/trello.py
+++ b/api/core/tools/provider/builtin/trello/trello.py
@@ -31,4 +31,5 @@ class TrelloProvider(BuiltinToolProviderController):
raise ToolProviderCredentialValidationError("Error validating Trello credentials")
except requests.exceptions.RequestException as e:
# Handle other exceptions, such as connection errors
- raise ToolProviderCredentialValidationError("Error validating Trello credentials")
\ No newline at end of file
+ raise ToolProviderCredentialValidationError("Error validating Trello credentials")
+
\ No newline at end of file
diff --git a/api/core/tools/provider/builtin/trello/trello.yaml b/api/core/tools/provider/builtin/trello/trello.yaml
index e1228c16be..49c9f4f9a1 100644
--- a/api/core/tools/provider/builtin/trello/trello.yaml
+++ b/api/core/tools/provider/builtin/trello/trello.yaml
@@ -10,6 +10,8 @@ identity:
zh_Hans: "Trello: 一个用于组织工作和生活的视觉工具。"
pt_BR: "Trello: Uma ferramenta visual para organizar seu trabalho e vida."
icon: icon.svg
+ tags:
+ - productivity
credentials_for_provider:
trello_api_key:
type: secret-input
diff --git a/api/core/tools/provider/builtin/twilio/tools/send_message.py b/api/core/tools/provider/builtin/twilio/tools/send_message.py
index 24502a3565..48e3876a4a 100644
--- a/api/core/tools/provider/builtin/twilio/tools/send_message.py
+++ b/api/core/tools/provider/builtin/twilio/tools/send_message.py
@@ -30,7 +30,7 @@ class TwilioAPIWrapper(BaseModel):
Twilio also work here. You cannot, for example, spoof messages from a private
cell phone number. If you are using `messaging_service_sid`, this parameter
must be empty.
- """ # noqa: E501
+ """
@validator("client", pre=True, always=True)
def set_validator(cls, values: dict) -> dict:
@@ -60,7 +60,7 @@ class TwilioAPIWrapper(BaseModel):
SMS/MMS or
[Channel user address](https://www.twilio.com/docs/sms/channels#channel-addresses)
for other 3rd-party channels.
- """ # noqa: E501
+ """
message = self.client.messages.create(to, from_=self.from_number, body=body)
return message.sid
diff --git a/api/core/tools/provider/builtin/twilio/twilio.py b/api/core/tools/provider/builtin/twilio/twilio.py
index 7984d7b3b1..06f276053a 100644
--- a/api/core/tools/provider/builtin/twilio/twilio.py
+++ b/api/core/tools/provider/builtin/twilio/twilio.py
@@ -26,4 +26,5 @@ class TwilioProvider(BuiltinToolProviderController):
except KeyError as e:
raise ToolProviderCredentialValidationError(f"Missing required credential: {e}") from e
except Exception as e:
- raise ToolProviderCredentialValidationError(str(e))
\ No newline at end of file
+ raise ToolProviderCredentialValidationError(str(e))
+
\ No newline at end of file
diff --git a/api/core/tools/provider/builtin/twilio/twilio.yaml b/api/core/tools/provider/builtin/twilio/twilio.yaml
index b5143c8736..21867c1da5 100644
--- a/api/core/tools/provider/builtin/twilio/twilio.yaml
+++ b/api/core/tools/provider/builtin/twilio/twilio.yaml
@@ -10,6 +10,8 @@ identity:
zh_Hans: 通过SMS或Twilio消息通道发送消息。
pt_BR: Send messages through SMS or Twilio Messaging Channels.
icon: icon.svg
+ tags:
+ - social
credentials_for_provider:
account_sid:
type: secret-input
diff --git a/api/core/tools/provider/builtin/vanna/_assets/icon.png b/api/core/tools/provider/builtin/vanna/_assets/icon.png
new file mode 100644
index 0000000000..3a9011b54d
Binary files /dev/null and b/api/core/tools/provider/builtin/vanna/_assets/icon.png differ
diff --git a/api/core/tools/provider/builtin/vanna/tools/vanna.py b/api/core/tools/provider/builtin/vanna/tools/vanna.py
new file mode 100644
index 0000000000..a6efb0f79a
--- /dev/null
+++ b/api/core/tools/provider/builtin/vanna/tools/vanna.py
@@ -0,0 +1,129 @@
+from typing import Any, Union
+
+from vanna.remote import VannaDefault
+
+from core.tools.entities.tool_entities import ToolInvokeMessage
+from core.tools.errors import ToolProviderCredentialValidationError
+from core.tools.tool.builtin_tool import BuiltinTool
+
+
+class VannaTool(BuiltinTool):
+ def _invoke(
+ self, user_id: str, tool_parameters: dict[str, Any]
+ ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
+ """
+ invoke tools
+ """
+ api_key = self.runtime.credentials.get("api_key", None)
+ if not api_key:
+ raise ToolProviderCredentialValidationError("Please input api key")
+
+ model = tool_parameters.get("model", "")
+ if not model:
+ return self.create_text_message("Please input RAG model")
+
+ prompt = tool_parameters.get("prompt", "")
+ if not prompt:
+ return self.create_text_message("Please input prompt")
+
+ url = tool_parameters.get("url", "")
+ if not url:
+ return self.create_text_message("Please input URL/Host/DSN")
+
+ db_name = tool_parameters.get("db_name", "")
+ username = tool_parameters.get("username", "")
+ password = tool_parameters.get("password", "")
+ port = tool_parameters.get("port", 0)
+
+ vn = VannaDefault(model=model, api_key=api_key)
+
+ db_type = tool_parameters.get("db_type", "")
+ if db_type in ["Postgres", "MySQL", "Hive", "ClickHouse"]:
+ if not db_name:
+ return self.create_text_message("Please input database name")
+ if not username:
+ return self.create_text_message("Please input username")
+ if port < 1:
+ return self.create_text_message("Please input port")
+
+ schema_sql = "SELECT * FROM INFORMATION_SCHEMA.COLUMNS"
+ match db_type:
+ case "SQLite":
+ schema_sql = "SELECT type, sql FROM sqlite_master WHERE sql is not null"
+ vn.connect_to_sqlite(url)
+ case "Postgres":
+ vn.connect_to_postgres(host=url, dbname=db_name, user=username, password=password, port=port)
+ case "DuckDB":
+ vn.connect_to_duckdb(url=url)
+ case "SQLServer":
+ vn.connect_to_mssql(url)
+ case "MySQL":
+ vn.connect_to_mysql(host=url, dbname=db_name, user=username, password=password, port=port)
+ case "Oracle":
+ vn.connect_to_oracle(user=username, password=password, dsn=url)
+ case "Hive":
+ vn.connect_to_hive(host=url, dbname=db_name, user=username, password=password, port=port)
+ case "ClickHouse":
+ vn.connect_to_clickhouse(host=url, dbname=db_name, user=username, password=password, port=port)
+
+ enable_training = tool_parameters.get("enable_training", False)
+ reset_training_data = tool_parameters.get("reset_training_data", False)
+ if enable_training:
+ if reset_training_data:
+ existing_training_data = vn.get_training_data()
+ if len(existing_training_data) > 0:
+ for _, training_data in existing_training_data.iterrows():
+ vn.remove_training_data(training_data["id"])
+
+ ddl = tool_parameters.get("ddl", "")
+ question = tool_parameters.get("question", "")
+ sql = tool_parameters.get("sql", "")
+ memos = tool_parameters.get("memos", "")
+ training_metadata = tool_parameters.get("training_metadata", False)
+
+ if training_metadata:
+ if db_type == "SQLite":
+ df_ddl = vn.run_sql(schema_sql)
+ for ddl in df_ddl["sql"].to_list():
+ vn.train(ddl=ddl)
+ else:
+ df_information_schema = vn.run_sql(schema_sql)
+ plan = vn.get_training_plan_generic(df_information_schema)
+ vn.train(plan=plan)
+
+ if ddl:
+ vn.train(ddl=ddl)
+
+ if sql:
+ if question:
+ vn.train(question=question, sql=sql)
+ else:
+ vn.train(sql=sql)
+ if memos:
+ vn.train(documentation=memos)
+
+ #########################################################################################
+ # Due to CVE-2024-5565, we have to disable the chart generation feature
+        # The Vanna library uses a prompt function to present the user with visualized results;
+        # it is possible to alter the prompt via prompt injection and run arbitrary Python code
+        # instead of the intended visualization code.
+        # Specifically, allowing external input to reach the library's "ask" method
+        # with "visualize" set to True (the default behavior) leads to remote code execution.
+ # Affected versions: <= 0.5.5
+ #########################################################################################
+ generate_chart = False
+ # generate_chart = tool_parameters.get("generate_chart", True)
+ res = vn.ask(prompt, False, True, generate_chart)
+
+ result = []
+
+ if res is not None:
+ result.append(self.create_text_message(res[0]))
+ if len(res) > 1 and res[1] is not None:
+ result.append(self.create_text_message(res[1].to_markdown()))
+ if len(res) > 2 and res[2] is not None:
+ result.append(
+ self.create_blob_message(blob=res[2].to_image(format="svg"), meta={"mime_type": "image/svg+xml"})
+ )
+
+ return result
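A hedged sketch of the `ask()` call as the tool uses it, with chart generation forced off per the CVE-2024-5565 note above; the model and sample database mirror the provider's validator, and the key is a placeholder:

```python
# Positional args follow the patch: (question, print_results, auto_train, visualize);
# visualize stays False to avoid the CVE-2024-5565 code-execution path (vanna <= 0.5.5).
from vanna.remote import VannaDefault

vn = VannaDefault(model="chinook", api_key="YOUR_VANNA_KEY")  # placeholder key
vn.connect_to_sqlite("https://vanna.ai/Chinook.sqlite")       # sample DB from the validator
res = vn.ask("What are the top 5 artists by track count?", False, True, False)
```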
diff --git a/api/core/tools/provider/builtin/vanna/tools/vanna.yaml b/api/core/tools/provider/builtin/vanna/tools/vanna.yaml
new file mode 100644
index 0000000000..ae2eae94c4
--- /dev/null
+++ b/api/core/tools/provider/builtin/vanna/tools/vanna.yaml
@@ -0,0 +1,213 @@
+identity:
+ name: vanna
+ author: QCTC
+ label:
+ en_US: Vanna.AI
+ zh_Hans: Vanna.AI
+description:
+ human:
+ en_US: The fastest way to get actionable insights from your database just by asking questions.
+ zh_Hans: 一个基于大模型和RAG的Text2SQL工具。
+ llm: A tool for converting text to SQL.
+parameters:
+ - name: prompt
+ type: string
+ required: true
+ label:
+ en_US: Prompt
+ zh_Hans: 提示词
+ pt_BR: Prompt
+ human_description:
+ en_US: used for generating SQL
+ zh_Hans: 用于生成SQL
+ llm_description: key words for generating SQL
+ form: llm
+ - name: model
+ type: string
+ required: true
+ label:
+ en_US: RAG Model
+ zh_Hans: RAG模型
+ human_description:
+ en_US: RAG Model for your database DDL
+ zh_Hans: 存储数据库训练数据的RAG模型
+ llm_description: RAG Model for generating SQL
+ form: form
+ - name: db_type
+ type: select
+ required: true
+ options:
+ - value: SQLite
+ label:
+ en_US: SQLite
+ zh_Hans: SQLite
+ - value: Postgres
+ label:
+ en_US: Postgres
+ zh_Hans: Postgres
+ - value: DuckDB
+ label:
+ en_US: DuckDB
+ zh_Hans: DuckDB
+ - value: SQLServer
+ label:
+ en_US: Microsoft SQL Server
+ zh_Hans: 微软 SQL Server
+ - value: MySQL
+ label:
+ en_US: MySQL
+ zh_Hans: MySQL
+ - value: Oracle
+ label:
+ en_US: Oracle
+ zh_Hans: Oracle
+ - value: Hive
+ label:
+ en_US: Hive
+ zh_Hans: Hive
+ - value: ClickHouse
+ label:
+ en_US: ClickHouse
+ zh_Hans: ClickHouse
+ default: SQLite
+ label:
+ en_US: DB Type
+ zh_Hans: 数据库类型
+ human_description:
+ en_US: Database type.
+      zh_Hans: 选择要连接的数据库类型。
+ form: form
+ - name: url
+ type: string
+ required: true
+ label:
+ en_US: URL/Host/DSN
+ zh_Hans: URL/Host/DSN
+ human_description:
+      en_US: Enter the value appropriate for your DB type; visit https://vanna.ai/docs/ for details.
+      zh_Hans: 请根据数据库类型填入对应值,详情参考 https://vanna.ai/docs/
+ form: form
+ - name: db_name
+ type: string
+ required: false
+ label:
+ en_US: DB name
+ zh_Hans: 数据库名
+ human_description:
+ en_US: Database name
+ zh_Hans: 数据库名
+ form: form
+ - name: username
+ type: string
+ required: false
+ label:
+ en_US: Username
+ zh_Hans: 用户名
+ human_description:
+ en_US: Username
+ zh_Hans: 用户名
+ form: form
+ - name: password
+ type: secret-input
+ required: false
+ label:
+ en_US: Password
+ zh_Hans: 密码
+ human_description:
+ en_US: Password
+ zh_Hans: 密码
+ form: form
+ - name: port
+ type: number
+ required: false
+ label:
+ en_US: Port
+ zh_Hans: 端口
+ human_description:
+ en_US: Port
+ zh_Hans: 端口
+ form: form
+ - name: ddl
+ type: string
+ required: false
+ label:
+ en_US: Training DDL
+ zh_Hans: 训练DDL
+ human_description:
+ en_US: DDL statements for training data
+ zh_Hans: 用于训练RAG Model的建表语句
+ form: form
+ - name: question
+ type: string
+ required: false
+ label:
+ en_US: Training Question
+ zh_Hans: 训练问题
+ human_description:
+ en_US: Question-SQL Pairs
+ zh_Hans: Question-SQL中的问题
+ form: form
+ - name: sql
+ type: string
+ required: false
+ label:
+ en_US: Training SQL
+ zh_Hans: 训练SQL
+ human_description:
+ en_US: SQL queries to your training data
+ zh_Hans: 用于训练RAG Model的SQL语句
+ form: form
+ - name: memos
+ type: string
+ required: false
+ label:
+ en_US: Training Memos
+ zh_Hans: 训练说明
+ human_description:
+ en_US: Sometimes you may want to add documentation about your business terminology or definitions
+ zh_Hans: 添加更多关于数据库的业务说明
+ form: form
+ - name: enable_training
+ type: boolean
+ required: false
+ default: false
+ label:
+ en_US: Training Data
+ zh_Hans: 训练数据
+ human_description:
+      en_US: You only need to train once. Do not train again unless you want to add more training data.
+ zh_Hans: 训练数据无更新时,训练一次即可
+ form: form
+ - name: reset_training_data
+ type: boolean
+ required: false
+ default: false
+ label:
+ en_US: Reset Training Data
+ zh_Hans: 重置训练数据
+ human_description:
+ en_US: Remove all training data in the current RAG Model
+ zh_Hans: 删除当前RAG Model中的所有训练数据
+ form: form
+ - name: training_metadata
+ type: boolean
+ required: false
+ default: false
+ label:
+ en_US: Training Metadata
+ zh_Hans: 训练元数据
+ human_description:
+ en_US: If enabled, it will attempt to train on the metadata of that database
+ zh_Hans: 是否自动从数据库获取元数据来训练
+ form: form
+ - name: generate_chart
+ type: boolean
+ required: false
+ default: True
+ label:
+ en_US: Generate Charts
+ zh_Hans: 生成图表
+ human_description:
+ en_US: Generate Charts
+ zh_Hans: 是否生成图表
+ form: form
diff --git a/api/core/tools/provider/builtin/vanna/vanna.py b/api/core/tools/provider/builtin/vanna/vanna.py
new file mode 100644
index 0000000000..ab1fd71df5
--- /dev/null
+++ b/api/core/tools/provider/builtin/vanna/vanna.py
@@ -0,0 +1,25 @@
+from typing import Any
+
+from core.tools.errors import ToolProviderCredentialValidationError
+from core.tools.provider.builtin.vanna.tools.vanna import VannaTool
+from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
+
+
+class VannaProvider(BuiltinToolProviderController):
+ def _validate_credentials(self, credentials: dict[str, Any]) -> None:
+ try:
+ VannaTool().fork_tool_runtime(
+ runtime={
+ "credentials": credentials,
+ }
+ ).invoke(
+ user_id='',
+ tool_parameters={
+ "model": "chinook",
+ "db_type": "SQLite",
+ "url": "https://vanna.ai/Chinook.sqlite",
+ "query": "What are the top 10 customers by sales?"
+ },
+ )
+ except Exception as e:
+ raise ToolProviderCredentialValidationError(str(e))
\ No newline at end of file
diff --git a/api/core/tools/provider/builtin/vanna/vanna.yaml b/api/core/tools/provider/builtin/vanna/vanna.yaml
new file mode 100644
index 0000000000..b29fa103e1
--- /dev/null
+++ b/api/core/tools/provider/builtin/vanna/vanna.yaml
@@ -0,0 +1,25 @@
+identity:
+ author: QCTC
+ name: vanna
+ label:
+ en_US: Vanna.AI
+ zh_Hans: Vanna.AI
+ description:
+ en_US: The fastest way to get actionable insights from your database just by asking questions.
+ zh_Hans: 一个基于大模型和RAG的Text2SQL工具。
+ icon: icon.png
+credentials_for_provider:
+ api_key:
+ type: secret-input
+ required: true
+ label:
+ en_US: API key
+ zh_Hans: API key
+ placeholder:
+ en_US: Please input your API key
+ zh_Hans: 请输入你的 API key
+ pt_BR: Please input your API key
+ help:
+ en_US: Get your API key from Vanna.AI
+ zh_Hans: 从 Vanna.AI 获取你的 API key
+ url: https://vanna.ai/account/profile
diff --git a/api/core/tools/provider/builtin/vectorizer/tools/vectorizer.py b/api/core/tools/provider/builtin/vectorizer/tools/vectorizer.py
index df996b5283..c6ec198034 100644
--- a/api/core/tools/provider/builtin/vectorizer/tools/vectorizer.py
+++ b/api/core/tools/provider/builtin/vectorizer/tools/vectorizer.py
@@ -71,6 +71,4 @@ class VectorizerTool(BuiltinTool):
options=[i.name for i in self.list_default_image_variables()]
)
]
-
- def is_tool_available(self) -> bool:
- return len(self.list_default_image_variables()) > 0
\ No newline at end of file
+
\ No newline at end of file
diff --git a/api/core/tools/provider/builtin/vectorizer/vectorizer.py b/api/core/tools/provider/builtin/vectorizer/vectorizer.py
index 2b4d71e058..3f89a83500 100644
--- a/api/core/tools/provider/builtin/vectorizer/vectorizer.py
+++ b/api/core/tools/provider/builtin/vectorizer/vectorizer.py
@@ -9,7 +9,7 @@ class VectorizerProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
try:
VectorizerTool().fork_tool_runtime(
- meta={
+ runtime={
"credentials": credentials,
}
).invoke(
@@ -20,4 +20,5 @@ class VectorizerProvider(BuiltinToolProviderController):
},
)
except Exception as e:
- raise ToolProviderCredentialValidationError(str(e))
\ No newline at end of file
+ raise ToolProviderCredentialValidationError(str(e))
+
\ No newline at end of file
diff --git a/api/core/tools/provider/builtin/vectorizer/vectorizer.yaml b/api/core/tools/provider/builtin/vectorizer/vectorizer.yaml
index 07a20380e9..1257f8d285 100644
--- a/api/core/tools/provider/builtin/vectorizer/vectorizer.yaml
+++ b/api/core/tools/provider/builtin/vectorizer/vectorizer.yaml
@@ -10,6 +10,9 @@ identity:
zh_Hans: 一个将 PNG 和 JPG 图像快速轻松地转换为 SVG 矢量图的工具。
pt_BR: Convert your PNG and JPG images to SVG vectors quickly and easily. Fully automatically. Using AI.
icon: icon.png
+ tags:
+ - productivity
+ - image
credentials_for_provider:
api_key_name:
type: secret-input
diff --git a/api/core/tools/provider/builtin/webscraper/tools/webscraper.py b/api/core/tools/provider/builtin/webscraper/tools/webscraper.py
index 5e8c405b47..3d098e6768 100644
--- a/api/core/tools/provider/builtin/webscraper/tools/webscraper.py
+++ b/api/core/tools/provider/builtin/webscraper/tools/webscraper.py
@@ -7,9 +7,9 @@ from core.tools.tool.builtin_tool import BuiltinTool
class WebscraperTool(BuiltinTool):
def _invoke(self,
- user_id: str,
- tool_parameters: dict[str, Any],
- ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
+ user_id: str,
+ tool_parameters: dict[str, Any],
+ ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
"""
invoke tools
"""
@@ -18,12 +18,15 @@ class WebscraperTool(BuiltinTool):
user_agent = tool_parameters.get('user_agent', '')
if not url:
return self.create_text_message('Please input url')
-
+
# get webpage
result = self.get_url(url, user_agent=user_agent)
- # summarize and return
- return self.create_text_message(self.summary(user_id=user_id, content=result))
+ if tool_parameters.get('generate_summary'):
+ # summarize and return
+ return self.create_text_message(self.summary(user_id=user_id, content=result))
+ else:
+ # return full webpage
+ return self.create_text_message(result)
except Exception as e:
raise ToolInvokeError(str(e))
-
\ No newline at end of file
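The new `generate_summary` flag switches between the summarized and raw-page paths; a minimal sketch of that branch, where `summarize` stands in for the tool's LLM-backed `self.summary` helper (an assumption, not the real signature):

```python
# Branch introduced in WebscraperTool._invoke, reduced to plain functions.
def scrape(page_text: str, generate_summary: bool, summarize=lambda s: s[:100] + "...") -> str:
    return summarize(page_text) if generate_summary else page_text

print(scrape("full page text " * 20, generate_summary=False))
```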
diff --git a/api/core/tools/provider/builtin/webscraper/tools/webscraper.yaml b/api/core/tools/provider/builtin/webscraper/tools/webscraper.yaml
index 5782dbb0c7..180cfec6fc 100644
--- a/api/core/tools/provider/builtin/webscraper/tools/webscraper.yaml
+++ b/api/core/tools/provider/builtin/webscraper/tools/webscraper.yaml
@@ -38,3 +38,23 @@ parameters:
pt_BR: used for identifying the browser.
form: form
default: Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/100.0.1000.0 Safari/537.36
+ - name: generate_summary
+ type: boolean
+ required: false
+ label:
+ en_US: Whether to generate summary
+ zh_Hans: 是否生成摘要
+ human_description:
+    en_US: If enabled, the crawler returns only a summary of the page content.
+ zh_Hans: 如果启用,爬虫将仅返回页面摘要内容。
+ form: form
+ options:
+ - value: true
+ label:
+ en_US: Yes
+ zh_Hans: 是
+ - value: false
+ label:
+ en_US: No
+ zh_Hans: 否
+ default: false
diff --git a/api/core/tools/provider/builtin/webscraper/webscraper.py b/api/core/tools/provider/builtin/webscraper/webscraper.py
index 8761493e3b..1e60fdb293 100644
--- a/api/core/tools/provider/builtin/webscraper/webscraper.py
+++ b/api/core/tools/provider/builtin/webscraper/webscraper.py
@@ -9,7 +9,7 @@ class WebscraperProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
try:
WebscraperTool().fork_tool_runtime(
- meta={
+ runtime={
"credentials": credentials,
}
).invoke(
@@ -20,4 +20,5 @@ class WebscraperProvider(BuiltinToolProviderController):
},
)
except Exception as e:
- raise ToolProviderCredentialValidationError(str(e))
\ No newline at end of file
+ raise ToolProviderCredentialValidationError(str(e))
+
\ No newline at end of file
diff --git a/api/core/tools/provider/builtin/webscraper/webscraper.yaml b/api/core/tools/provider/builtin/webscraper/webscraper.yaml
index 056108f036..6c2eb97784 100644
--- a/api/core/tools/provider/builtin/webscraper/webscraper.yaml
+++ b/api/core/tools/provider/builtin/webscraper/webscraper.yaml
@@ -10,4 +10,6 @@ identity:
zh_Hans: 一个用于抓取网页的工具。
pt_BR: Web Scrapper tool kit is used to scrape web
icon: icon.svg
+ tags:
+ - productivity
credentials_for_provider:
diff --git a/api/core/tools/provider/builtin/wecom/wecom.py b/api/core/tools/provider/builtin/wecom/wecom.py
index 7a2576b668..573f76ee56 100644
--- a/api/core/tools/provider/builtin/wecom/wecom.py
+++ b/api/core/tools/provider/builtin/wecom/wecom.py
@@ -5,4 +5,3 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl
class WecomProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict) -> None:
WecomGroupBotTool()
- pass
diff --git a/api/core/tools/provider/builtin/wecom/wecom.yaml b/api/core/tools/provider/builtin/wecom/wecom.yaml
index 39d00032a0..a544055ba4 100644
--- a/api/core/tools/provider/builtin/wecom/wecom.yaml
+++ b/api/core/tools/provider/builtin/wecom/wecom.yaml
@@ -10,4 +10,6 @@ identity:
zh_Hans: 企业微信群机器人
pt_BR: Wecom group bot
icon: icon.png
+ tags:
+ - social
credentials_for_provider:
diff --git a/api/core/tools/provider/builtin/wikipedia/wikipedia.py b/api/core/tools/provider/builtin/wikipedia/wikipedia.py
index 8d53852255..f8038714a5 100644
--- a/api/core/tools/provider/builtin/wikipedia/wikipedia.py
+++ b/api/core/tools/provider/builtin/wikipedia/wikipedia.py
@@ -7,7 +7,7 @@ class WikiPediaProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict) -> None:
try:
WikiPediaSearchTool().fork_tool_runtime(
- meta={
+ runtime={
"credentials": credentials,
}
).invoke(
@@ -17,4 +17,5 @@ class WikiPediaProvider(BuiltinToolProviderController):
},
)
except Exception as e:
- raise ToolProviderCredentialValidationError(str(e))
\ No newline at end of file
+ raise ToolProviderCredentialValidationError(str(e))
+
\ No newline at end of file
diff --git a/api/core/tools/provider/builtin/wikipedia/wikipedia.yaml b/api/core/tools/provider/builtin/wikipedia/wikipedia.yaml
index f4b5da8947..c582824022 100644
--- a/api/core/tools/provider/builtin/wikipedia/wikipedia.yaml
+++ b/api/core/tools/provider/builtin/wikipedia/wikipedia.yaml
@@ -10,4 +10,6 @@ identity:
zh_Hans: 维基百科是一个由全世界的志愿者创建和编辑的免费在线百科全书。
pt_BR: Wikipedia is a free online encyclopedia, created and edited by volunteers around the world.
icon: icon.svg
+ tags:
+ - social
credentials_for_provider:
diff --git a/api/core/tools/provider/builtin/wolframalpha/tools/wolframalpha.py b/api/core/tools/provider/builtin/wolframalpha/tools/wolframalpha.py
index 7512710515..8cb9c10ddf 100644
--- a/api/core/tools/provider/builtin/wolframalpha/tools/wolframalpha.py
+++ b/api/core/tools/provider/builtin/wolframalpha/tools/wolframalpha.py
@@ -48,7 +48,7 @@ class WolframAlphaTool(BuiltinTool):
if 'success' not in response_data['queryresult'] or response_data['queryresult']['success'] != True:
query_result = response_data.get('queryresult', {})
- if 'error' in query_result and query_result['error']:
+ if query_result.get('error'):
if 'msg' in query_result['error']:
if query_result['error']['msg'] == 'Invalid appid':
raise ToolProviderCredentialValidationError('Invalid appid')
diff --git a/api/core/tools/provider/builtin/wolframalpha/wolframalpha.py b/api/core/tools/provider/builtin/wolframalpha/wolframalpha.py
index 4e8213d90c..ef1aac7ff2 100644
--- a/api/core/tools/provider/builtin/wolframalpha/wolframalpha.py
+++ b/api/core/tools/provider/builtin/wolframalpha/wolframalpha.py
@@ -9,7 +9,7 @@ class GoogleProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
try:
WolframAlphaTool().fork_tool_runtime(
- meta={
+ runtime={
"credentials": credentials,
}
).invoke(
@@ -19,4 +19,5 @@ class GoogleProvider(BuiltinToolProviderController):
},
)
except Exception as e:
- raise ToolProviderCredentialValidationError(str(e))
\ No newline at end of file
+ raise ToolProviderCredentialValidationError(str(e))
+
\ No newline at end of file
diff --git a/api/core/tools/provider/builtin/wolframalpha/wolframalpha.yaml b/api/core/tools/provider/builtin/wolframalpha/wolframalpha.yaml
index e4cc465712..91265eb3c0 100644
--- a/api/core/tools/provider/builtin/wolframalpha/wolframalpha.yaml
+++ b/api/core/tools/provider/builtin/wolframalpha/wolframalpha.yaml
@@ -10,6 +10,9 @@ identity:
zh_Hans: WolframAlpha 是一个强大的计算知识引擎。
pt_BR: WolframAlpha is a powerful computational knowledge engine.
icon: icon.svg
+ tags:
+ - productivity
+ - utilities
credentials_for_provider:
appid:
type: secret-input
diff --git a/api/core/tools/provider/builtin/yahoo/yahoo.py b/api/core/tools/provider/builtin/yahoo/yahoo.py
index ade33ffb63..96dbc6c3d0 100644
--- a/api/core/tools/provider/builtin/yahoo/yahoo.py
+++ b/api/core/tools/provider/builtin/yahoo/yahoo.py
@@ -7,7 +7,7 @@ class YahooFinanceProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict) -> None:
try:
YahooFinanceSearchTickerTool().fork_tool_runtime(
- meta={
+ runtime={
"credentials": credentials,
}
).invoke(
@@ -17,4 +17,5 @@ class YahooFinanceProvider(BuiltinToolProviderController):
},
)
except Exception as e:
- raise ToolProviderCredentialValidationError(str(e))
\ No newline at end of file
+ raise ToolProviderCredentialValidationError(str(e))
+
\ No newline at end of file
diff --git a/api/core/tools/provider/builtin/yahoo/yahoo.yaml b/api/core/tools/provider/builtin/yahoo/yahoo.yaml
index b16517eaf9..f1e82952c0 100644
--- a/api/core/tools/provider/builtin/yahoo/yahoo.yaml
+++ b/api/core/tools/provider/builtin/yahoo/yahoo.yaml
@@ -10,4 +10,7 @@ identity:
zh_Hans: 雅虎财经,获取并整理出最新的新闻、股票报价等一切你想要的财经信息。
pt_BR: Finance, and Yahoo! get the latest news, stock quotes, and interactive chart with Yahoo!
icon: icon.png
+ tags:
+ - business
+ - finance
credentials_for_provider:
diff --git a/api/core/tools/provider/builtin/youtube/tools/videos.py b/api/core/tools/provider/builtin/youtube/tools/videos.py
index 86160dfa6c..7a9b9fce4a 100644
--- a/api/core/tools/provider/builtin/youtube/tools/videos.py
+++ b/api/core/tools/provider/builtin/youtube/tools/videos.py
@@ -36,7 +36,7 @@ class YoutubeVideosAnalyticsTool(BuiltinTool):
youtube = build('youtube', 'v3', developerKey=self.runtime.credentials['google_api_key'])
# try to get channel id
- search_results = youtube.search().list(q='mrbeast', type='channel', order='relevance', part='id').execute()
+ search_results = youtube.search().list(q=channel, type='channel', order='relevance', part='id').execute()
channel_id = search_results['items'][0]['id']['channelId']
start_date, end_date = time_range
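Note: the previous code searched for the hard-coded channel 'mrbeast' regardless of input; the fix passes the user-supplied `channel` through. A hedged sketch of the corrected lookup using google-api-python-client (the API key and channel name are placeholders; running it requires `pip install google-api-python-client` and a valid YouTube Data API v3 key):

    from googleapiclient.discovery import build

    def find_channel_id(api_key: str, channel: str) -> str:
        youtube = build('youtube', 'v3', developerKey=api_key)
        # was: q='mrbeast' (hard-coded); now the requested channel is searched
        results = youtube.search().list(q=channel, type='channel',
                                        order='relevance', part='id').execute()
        return results['items'][0]['id']['channelId']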
diff --git a/api/core/tools/provider/builtin/youtube/youtube.py b/api/core/tools/provider/builtin/youtube/youtube.py
index 8cca578c46..83a4fccb32 100644
--- a/api/core/tools/provider/builtin/youtube/youtube.py
+++ b/api/core/tools/provider/builtin/youtube/youtube.py
@@ -7,7 +7,7 @@ class YahooFinanceProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict) -> None:
try:
YoutubeVideosAnalyticsTool().fork_tool_runtime(
- meta={
+ runtime={
"credentials": credentials,
}
).invoke(
@@ -19,4 +19,5 @@ class YahooFinanceProvider(BuiltinToolProviderController):
},
)
except Exception as e:
- raise ToolProviderCredentialValidationError(str(e))
\ No newline at end of file
+ raise ToolProviderCredentialValidationError(str(e))
+
\ No newline at end of file
diff --git a/api/core/tools/provider/builtin/youtube/youtube.yaml b/api/core/tools/provider/builtin/youtube/youtube.yaml
index 2f83ae43ee..d6915b9a32 100644
--- a/api/core/tools/provider/builtin/youtube/youtube.yaml
+++ b/api/core/tools/provider/builtin/youtube/youtube.yaml
@@ -10,6 +10,8 @@ identity:
zh_Hans: YouTube(油管)是全球最大的视频分享网站,用户可以在上面上传、观看和分享视频。
pt_BR: YouTube é o maior site de compartilhamento de vídeos do mundo, onde os usuários podem fazer upload, assistir e compartilhar vídeos.
icon: icon.svg
+ tags:
+ - videos
credentials_for_provider:
google_api_key:
type: secret-input
diff --git a/api/core/tools/provider/builtin_tool_provider.py b/api/core/tools/provider/builtin_tool_provider.py
index c2178cdd40..d076cb384f 100644
--- a/api/core/tools/provider/builtin_tool_provider.py
+++ b/api/core/tools/provider/builtin_tool_provider.py
@@ -2,25 +2,24 @@ from abc import abstractmethod
from os import listdir, path
from typing import Any
-from yaml import FullLoader, load
-
+from core.helper.module_import_helper import load_single_subclass_from_source
from core.tools.entities.tool_entities import ToolParameter, ToolProviderCredentials, ToolProviderType
-from core.tools.entities.user_entities import UserToolProviderCredentials
+from core.tools.entities.values import ToolLabelEnum, default_tool_label_dict
from core.tools.errors import (
ToolNotFoundError,
ToolParameterValidationError,
- ToolProviderCredentialValidationError,
ToolProviderNotFoundError,
)
from core.tools.provider.tool_provider import ToolProviderController
from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.tool.tool import Tool
-from core.utils.module_import_helper import load_single_subclass_from_source
+from core.tools.utils.tool_parameter_converter import ToolParameterConverter
+from core.tools.utils.yaml_utils import load_yaml_file
class BuiltinToolProviderController(ToolProviderController):
def __init__(self, **data: Any) -> None:
- if self.app_type == ToolProviderType.API_BASED or self.app_type == ToolProviderType.APP_BASED:
+ if self.provider_type == ToolProviderType.API or self.provider_type == ToolProviderType.APP:
super().__init__(**data)
return
@@ -28,10 +27,9 @@ class BuiltinToolProviderController(ToolProviderController):
provider = self.__class__.__module__.split('.')[-1]
yaml_path = path.join(path.dirname(path.realpath(__file__)), 'builtin', provider, f'{provider}.yaml')
try:
- with open(yaml_path, 'rb') as f:
- provider_yaml = load(f.read(), FullLoader)
- except:
- raise ToolProviderNotFoundError(f'can not load provider yaml for {provider}')
+ provider_yaml = load_yaml_file(yaml_path)
+ except Exception as e:
+            raise ToolProviderNotFoundError(f'cannot load provider yaml for {provider}: {e}')
if 'credentials_for_provider' in provider_yaml and provider_yaml['credentials_for_provider'] is not None:
# set credentials name
@@ -58,18 +56,18 @@ class BuiltinToolProviderController(ToolProviderController):
tool_files = list(filter(lambda x: x.endswith(".yaml") and not x.startswith("__"), listdir(tool_path)))
tools = []
for tool_file in tool_files:
- with open(path.join(tool_path, tool_file), encoding='utf-8') as f:
- # get tool name
- tool_name = tool_file.split(".")[0]
- tool = load(f.read(), FullLoader)
- # get tool class, import the module
- assistant_tool_class = load_single_subclass_from_source(
- module_name=f'core.tools.provider.builtin.{provider}.tools.{tool_name}',
- script_path=path.join(path.dirname(path.realpath(__file__)),
- 'builtin', provider, 'tools', f'{tool_name}.py'),
- parent_type=BuiltinTool)
- tool["identity"]["provider"] = provider
- tools.append(assistant_tool_class(**tool))
+ # get tool name
+ tool_name = tool_file.split(".")[0]
+ tool = load_yaml_file(path.join(tool_path, tool_file))
+
+ # get tool class, import the module
+ assistant_tool_class = load_single_subclass_from_source(
+ module_name=f'core.tools.provider.builtin.{provider}.tools.{tool_name}',
+ script_path=path.join(path.dirname(path.realpath(__file__)),
+ 'builtin', provider, 'tools', f'{tool_name}.py'),
+ parent_type=BuiltinTool)
+ tool["identity"]["provider"] = provider
+ tools.append(assistant_tool_class(**tool))
self.tools = tools
return tools
@@ -84,15 +82,6 @@ class BuiltinToolProviderController(ToolProviderController):
return {}
return self.credentials_schema.copy()
-
- def user_get_credentials_schema(self) -> UserToolProviderCredentials:
- """
- returns the credentials schema of the provider, this method is used for user
-
- :return: the credentials schema
- """
- credentials = self.credentials_schema.copy()
- return UserToolProviderCredentials(credentials=credentials)
def get_tools(self) -> list[Tool]:
"""
@@ -131,7 +120,7 @@ class BuiltinToolProviderController(ToolProviderController):
len(self.credentials_schema) != 0
@property
- def app_type(self) -> ToolProviderType:
+ def provider_type(self) -> ToolProviderType:
"""
returns the type of the provider
@@ -139,6 +128,22 @@ class BuiltinToolProviderController(ToolProviderController):
"""
return ToolProviderType.BUILT_IN
+ @property
+ def tool_labels(self) -> list[str]:
+ """
+ returns the labels of the provider
+
+ :return: labels of the provider
+ """
+ label_enums = self._get_tool_labels()
+ return [default_tool_label_dict[label].name for label in label_enums]
+
+ def _get_tool_labels(self) -> list[ToolLabelEnum]:
+ """
+ returns the labels of the provider
+ """
+ return self.identity.tags or []
+
def validate_parameters(self, tool_id: int, tool_name: str, tool_parameters: dict[str, Any]) -> None:
"""
validate the parameters of the tool and set the default value if needed
@@ -196,91 +201,9 @@ class BuiltinToolProviderController(ToolProviderController):
# the parameter is not set currently, set the default value if needed
if parameter_schema.default is not None:
- default_value = parameter_schema.default
- # parse default value into the correct type
- if parameter_schema.type == ToolParameter.ToolParameterType.STRING or \
- parameter_schema.type == ToolParameter.ToolParameterType.SELECT:
- default_value = str(default_value)
- elif parameter_schema.type == ToolParameter.ToolParameterType.NUMBER:
- default_value = float(default_value)
- elif parameter_schema.type == ToolParameter.ToolParameterType.BOOLEAN:
- default_value = bool(default_value)
-
+ default_value = ToolParameterConverter.cast_parameter_by_type(parameter_schema.default,
+ parameter_schema.type)
tool_parameters[parameter] = default_value
-
- def validate_credentials_format(self, credentials: dict[str, Any]) -> None:
- """
- validate the format of the credentials of the provider and set the default value if needed
-
- :param credentials: the credentials of the tool
- """
- credentials_schema = self.credentials_schema
- if credentials_schema is None:
- return
-
- credentials_need_to_validate: dict[str, ToolProviderCredentials] = {}
- for credential_name in credentials_schema:
- credentials_need_to_validate[credential_name] = credentials_schema[credential_name]
-
- for credential_name in credentials:
- if credential_name not in credentials_need_to_validate:
- raise ToolProviderCredentialValidationError(f'credential {credential_name} not found in provider {self.identity.name}')
-
- # check type
- credential_schema = credentials_need_to_validate[credential_name]
- if credential_schema == ToolProviderCredentials.CredentialsType.SECRET_INPUT or \
- credential_schema == ToolProviderCredentials.CredentialsType.TEXT_INPUT:
- if not isinstance(credentials[credential_name], str):
- raise ToolProviderCredentialValidationError(f'credential {credential_schema.label.en_US} should be string')
-
- elif credential_schema.type == ToolProviderCredentials.CredentialsType.SELECT:
- if not isinstance(credentials[credential_name], str):
- raise ToolProviderCredentialValidationError(f'credential {credential_schema.label.en_US} should be string')
-
- options = credential_schema.options
- if not isinstance(options, list):
- raise ToolProviderCredentialValidationError(f'credential {credential_schema.label.en_US} options should be list')
-
- if credentials[credential_name] not in [x.value for x in options]:
- raise ToolProviderCredentialValidationError(f'credential {credential_schema.label.en_US} should be one of {options}')
- elif credential_schema.type == ToolProviderCredentials.CredentialsType.BOOLEAN:
- if isinstance(credentials[credential_name], bool):
- pass
- elif isinstance(credentials[credential_name], str):
- if credentials[credential_name].lower() == 'true':
- credentials[credential_name] = True
- elif credentials[credential_name].lower() == 'false':
- credentials[credential_name] = False
- else:
- raise ToolProviderCredentialValidationError(f'credential {credential_schema.label.en_US} should be boolean')
- elif isinstance(credentials[credential_name], int):
- if credentials[credential_name] == 1:
- credentials[credential_name] = True
- elif credentials[credential_name] == 0:
- credentials[credential_name] = False
- else:
- raise ToolProviderCredentialValidationError(f'credential {credential_schema.label.en_US} should be boolean')
- else:
- raise ToolProviderCredentialValidationError(f'credential {credential_schema.label.en_US} should be boolean')
-
- if credentials[credential_name] or credentials[credential_name] == False:
- credentials_need_to_validate.pop(credential_name)
-
- for credential_name in credentials_need_to_validate:
- credential_schema = credentials_need_to_validate[credential_name]
- if credential_schema.required:
- raise ToolProviderCredentialValidationError(f'credential {credential_schema.label.en_US} is required')
-
- # the credential is not set currently, set the default value if needed
- if credential_schema.default is not None:
- default_value = credential_schema.default
- # parse default value into the correct type
- if credential_schema.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT or \
- credential_schema.type == ToolProviderCredentials.CredentialsType.TEXT_INPUT or \
- credential_schema.type == ToolProviderCredentials.CredentialsType.SELECT:
- default_value = str(default_value)
-
- credentials[credential_name] = default_value
def validate_credentials(self, credentials: dict[str, Any]) -> None:
"""
diff --git a/api/core/tools/provider/model_tool_provider.py b/api/core/tools/provider/model_tool_provider.py
deleted file mode 100644
index ef47e9aae9..0000000000
--- a/api/core/tools/provider/model_tool_provider.py
+++ /dev/null
@@ -1,244 +0,0 @@
-from copy import deepcopy
-from typing import Any
-
-from core.entities.model_entities import ModelStatus
-from core.errors.error import ProviderTokenNotInitError
-from core.model_manager import ModelInstance
-from core.model_runtime.entities.model_entities import ModelFeature, ModelType
-from core.provider_manager import ProviderConfiguration, ProviderManager, ProviderModelBundle
-from core.tools.entities.common_entities import I18nObject
-from core.tools.entities.tool_entities import (
- ModelToolPropertyKey,
- ToolDescription,
- ToolIdentity,
- ToolParameter,
- ToolProviderCredentials,
- ToolProviderIdentity,
- ToolProviderType,
-)
-from core.tools.errors import ToolNotFoundError
-from core.tools.provider.tool_provider import ToolProviderController
-from core.tools.tool.model_tool import ModelTool
-from core.tools.tool.tool import Tool
-from core.tools.utils.configuration import ModelToolConfigurationManager
-
-
-class ModelToolProviderController(ToolProviderController):
- configuration: ProviderConfiguration = None
- is_active: bool = False
-
- def __init__(self, configuration: ProviderConfiguration = None, **kwargs):
- """
- init the provider
-
- :param data: the data of the provider
- """
- super().__init__(**kwargs)
- self.configuration = configuration
-
- @staticmethod
- def from_db(configuration: ProviderConfiguration = None) -> 'ModelToolProviderController':
- """
- init the provider from db
-
- :param configuration: the configuration of the provider
- """
- # check if all models are active
- if configuration is None:
- return None
- is_active = True
- models = configuration.get_provider_models()
- for model in models:
- if model.status != ModelStatus.ACTIVE:
- is_active = False
- break
-
- # get the provider configuration
- model_tool_configuration = ModelToolConfigurationManager.get_configuration(configuration.provider.provider)
- if model_tool_configuration is None:
- raise RuntimeError(f'no configuration found for provider {configuration.provider.provider}')
-
- # override the configuration
- if model_tool_configuration.label:
- label = deepcopy(model_tool_configuration.label)
- if label.en_US:
- label.en_US = model_tool_configuration.label.en_US
- if label.zh_Hans:
- label.zh_Hans = model_tool_configuration.label.zh_Hans
- else:
- label = I18nObject(
- en_US=configuration.provider.label.en_US,
- zh_Hans=configuration.provider.label.zh_Hans
- )
-
- return ModelToolProviderController(
- is_active=is_active,
- identity=ToolProviderIdentity(
- author='Dify',
- name=configuration.provider.provider,
- description=I18nObject(
- zh_Hans=f'{label.zh_Hans} 模型能力提供商',
- en_US=f'{label.en_US} model capability provider'
- ),
- label=I18nObject(
- zh_Hans=label.zh_Hans,
- en_US=label.en_US
- ),
- icon=configuration.provider.icon_small.en_US,
- ),
- configuration=configuration,
- credentials_schema={},
- )
-
- @staticmethod
- def is_configuration_valid(configuration: ProviderConfiguration) -> bool:
- """
- check if the configuration has a model can be used as a tool
- """
- models = configuration.get_provider_models()
- for model in models:
- if model.model_type == ModelType.LLM and ModelFeature.VISION in (model.features or []):
- return True
- return False
-
- def _get_model_tools(self, tenant_id: str = None) -> list[ModelTool]:
- """
- returns a list of tools that the provider can provide
-
- :return: list of tools
- """
- tenant_id = tenant_id or 'ffffffff-ffff-ffff-ffff-ffffffffffff'
- provider_manager = ProviderManager()
- if self.configuration is None:
- configurations = provider_manager.get_configurations(tenant_id=tenant_id).values()
- self.configuration = next(filter(lambda x: x.provider == self.identity.name, configurations), None)
- # get all tools
- tools: list[ModelTool] = []
- # get all models
- if not self.configuration:
- return tools
- configuration = self.configuration
-
- provider_configuration = ModelToolConfigurationManager.get_configuration(configuration.provider.provider)
- if provider_configuration is None:
- raise RuntimeError(f'no configuration found for provider {configuration.provider.provider}')
-
- for model in configuration.get_provider_models():
- model_configuration = ModelToolConfigurationManager.get_model_configuration(self.configuration.provider.provider, model.model)
- if model_configuration is None:
- continue
-
- if model.model_type == ModelType.LLM and ModelFeature.VISION in (model.features or []):
- provider_instance = configuration.get_provider_instance()
- model_type_instance = provider_instance.get_model_instance(model.model_type)
- provider_model_bundle = ProviderModelBundle(
- configuration=configuration,
- provider_instance=provider_instance,
- model_type_instance=model_type_instance
- )
-
- try:
- model_instance = ModelInstance(provider_model_bundle, model.model)
- except ProviderTokenNotInitError:
- model_instance = None
-
- tools.append(ModelTool(
- identity=ToolIdentity(
- author='Dify',
- name=model.model,
- label=model_configuration.label,
- ),
- parameters=[
- ToolParameter(
- name=ModelToolPropertyKey.IMAGE_PARAMETER_NAME.value,
- label=I18nObject(zh_Hans='图片ID', en_US='Image ID'),
- human_description=I18nObject(zh_Hans='图片ID', en_US='Image ID'),
- type=ToolParameter.ToolParameterType.STRING,
- form=ToolParameter.ToolParameterForm.LLM,
- required=True,
- default=Tool.VARIABLE_KEY.IMAGE.value
- )
- ],
- description=ToolDescription(
- human=I18nObject(zh_Hans='图生文工具', en_US='Convert image to text'),
- llm='Vision tool used to extract text and other visual information from images, can be used for OCR, image captioning, etc.',
- ),
- is_team_authorization=model.status == ModelStatus.ACTIVE,
- tool_type=ModelTool.ModelToolType.VISION,
- model_instance=model_instance,
- model=model.model,
- ))
-
- self.tools = tools
- return tools
-
- def get_credentials_schema(self) -> dict[str, ToolProviderCredentials]:
- """
- returns the credentials schema of the provider
-
- :return: the credentials schema
- """
- return {}
-
- def get_tools(self, user_id: str, tenant_id: str) -> list[ModelTool]:
- """
- returns a list of tools that the provider can provide
-
- :return: list of tools
- """
- return self._get_model_tools(tenant_id=tenant_id)
-
- def get_tool(self, tool_name: str) -> ModelTool:
- """
- get tool by name
-
- :param tool_name: the name of the tool
- :return: the tool
- """
- if self.tools is None:
- self.get_tools(user_id='', tenant_id=self.configuration.tenant_id)
-
- for tool in self.tools:
- if tool.identity.name == tool_name:
- return tool
-
- raise ValueError(f'tool {tool_name} not found')
-
- def get_parameters(self, tool_name: str) -> list[ToolParameter]:
- """
- returns the parameters of the tool
-
- :param tool_name: the name of the tool, defined in `get_tools`
- :return: list of parameters
- """
- tool = next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None)
- if tool is None:
- raise ToolNotFoundError(f'tool {tool_name} not found')
- return tool.parameters
-
- @property
- def app_type(self) -> ToolProviderType:
- """
- returns the type of the provider
-
- :return: type of the provider
- """
- return ToolProviderType.MODEL
-
- def validate_credentials(self, credentials: dict[str, Any]) -> None:
- """
- validate the credentials of the provider
-
- :param tool_name: the name of the tool, defined in `get_tools`
- :param credentials: the credentials of the tool
- """
- pass
-
- def _validate_credentials(self, credentials: dict[str, Any]) -> None:
- """
- validate the credentials of the provider
-
- :param tool_name: the name of the tool, defined in `get_tools`
- :param credentials: the credentials of the tool
- """
- pass
\ No newline at end of file
diff --git a/api/core/tools/provider/tool_provider.py b/api/core/tools/provider/tool_provider.py
index b527f2b274..ef1ace9c7c 100644
--- a/api/core/tools/provider/tool_provider.py
+++ b/api/core/tools/provider/tool_provider.py
@@ -9,9 +9,9 @@ from core.tools.entities.tool_entities import (
ToolProviderIdentity,
ToolProviderType,
)
-from core.tools.entities.user_entities import UserToolProviderCredentials
from core.tools.errors import ToolNotFoundError, ToolParameterValidationError, ToolProviderCredentialValidationError
from core.tools.tool.tool import Tool
+from core.tools.utils.tool_parameter_converter import ToolParameterConverter
class ToolProviderController(BaseModel, ABC):
@@ -27,15 +27,6 @@ class ToolProviderController(BaseModel, ABC):
"""
return self.credentials_schema.copy()
- def user_get_credentials_schema(self) -> UserToolProviderCredentials:
- """
- returns the credentials schema of the provider, this method is used for user
-
- :return: the credentials schema
- """
- credentials = self.credentials_schema.copy()
- return UserToolProviderCredentials(credentials=credentials)
-
@abstractmethod
def get_tools(self) -> list[Tool]:
"""
@@ -67,7 +58,7 @@ class ToolProviderController(BaseModel, ABC):
return tool.parameters
@property
- def app_type(self) -> ToolProviderType:
+ def provider_type(self) -> ToolProviderType:
"""
returns the type of the provider
@@ -132,17 +123,8 @@ class ToolProviderController(BaseModel, ABC):
# the parameter is not set currently, set the default value if needed
if parameter_schema.default is not None:
- default_value = parameter_schema.default
- # parse default value into the correct type
- if parameter_schema.type == ToolParameter.ToolParameterType.STRING or \
- parameter_schema.type == ToolParameter.ToolParameterType.SELECT:
- default_value = str(default_value)
- elif parameter_schema.type == ToolParameter.ToolParameterType.NUMBER:
- default_value = float(default_value)
- elif parameter_schema.type == ToolParameter.ToolParameterType.BOOLEAN:
- default_value = bool(default_value)
-
- tool_parameters[parameter] = default_value
+ tool_parameters[parameter] = ToolParameterConverter.cast_parameter_by_type(parameter_schema.default,
+ parameter_schema.type)
def validate_credentials_format(self, credentials: dict[str, Any]) -> None:
"""
@@ -197,26 +179,4 @@ class ToolProviderController(BaseModel, ABC):
default_value = str(default_value)
credentials[credential_name] = default_value
-
- def validate_credentials(self, credentials: dict[str, Any]) -> None:
- """
- validate the credentials of the provider
-
- :param tool_name: the name of the tool, defined in `get_tools`
- :param credentials: the credentials of the tool
- """
- # validate credentials format
- self.validate_credentials_format(credentials)
-
- # validate credentials
- self._validate_credentials(credentials)
-
- @abstractmethod
- def _validate_credentials(self, credentials: dict[str, Any]) -> None:
- """
- validate the credentials of the provider
-
- :param tool_name: the name of the tool, defined in `get_tools`
- :param credentials: the credentials of the tool
- """
- pass
\ No newline at end of file
+
\ No newline at end of file
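Note: the per-type default-casting branches deleted above are consolidated into `ToolParameterConverter.cast_parameter_by_type`. A toy equivalent inferred from the removed code (the real converter lives in core.tools.utils.tool_parameter_converter, and the type-name strings here are assumptions standing in for the ToolParameterType enum):

    def cast_parameter_by_type(value, type_: str):
        # mirrors the deleted branches: string-like types -> str,
        # number -> float, boolean -> bool; anything else passes through
        if type_ in ('string', 'secret-input', 'select'):
            return str(value)
        if type_ == 'number':
            return float(value)
        if type_ == 'boolean':
            return bool(value)
        return value

    assert cast_parameter_by_type('3', 'number') == 3.0
    assert cast_parameter_by_type(1, 'boolean') is True
    assert cast_parameter_by_type(42, 'string') == '42'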
diff --git a/api/core/tools/provider/workflow_tool_provider.py b/api/core/tools/provider/workflow_tool_provider.py
new file mode 100644
index 0000000000..f98ad0f26a
--- /dev/null
+++ b/api/core/tools/provider/workflow_tool_provider.py
@@ -0,0 +1,230 @@
+from typing import Optional
+
+from core.app.app_config.entities import VariableEntity
+from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
+from core.model_runtime.entities.common_entities import I18nObject
+from core.tools.entities.tool_entities import (
+ ToolDescription,
+ ToolIdentity,
+ ToolParameter,
+ ToolParameterOption,
+ ToolProviderType,
+)
+from core.tools.provider.tool_provider import ToolProviderController
+from core.tools.tool.workflow_tool import WorkflowTool
+from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
+from extensions.ext_database import db
+from models.model import App, AppMode
+from models.tools import WorkflowToolProvider
+from models.workflow import Workflow
+
+
+class WorkflowToolProviderController(ToolProviderController):
+ provider_id: str
+
+ @classmethod
+ def from_db(cls, db_provider: WorkflowToolProvider) -> 'WorkflowToolProviderController':
+ app = db_provider.app
+
+ if not app:
+ raise ValueError('app not found')
+
+ controller = WorkflowToolProviderController(**{
+ 'identity': {
+ 'author': db_provider.user.name if db_provider.user_id and db_provider.user else '',
+ 'name': db_provider.label,
+ 'label': {
+ 'en_US': db_provider.label,
+ 'zh_Hans': db_provider.label
+ },
+ 'description': {
+ 'en_US': db_provider.description,
+ 'zh_Hans': db_provider.description
+ },
+ 'icon': db_provider.icon,
+ },
+ 'credentials_schema': {},
+ 'provider_id': db_provider.id or '',
+ })
+
+ # init tools
+
+ controller.tools = [controller._get_db_provider_tool(db_provider, app)]
+
+ return controller
+
+ @property
+ def provider_type(self) -> ToolProviderType:
+ return ToolProviderType.WORKFLOW
+
+ def _get_db_provider_tool(self, db_provider: WorkflowToolProvider, app: App) -> WorkflowTool:
+ """
+ get db provider tool
+ :param db_provider: the db provider
+ :param app: the app
+ :return: the tool
+ """
+ workflow: Workflow = db.session.query(Workflow).filter(
+ Workflow.app_id == db_provider.app_id,
+ Workflow.version == db_provider.version
+ ).first()
+ if not workflow:
+ raise ValueError('workflow not found')
+
+ # fetch start node
+ graph: dict = workflow.graph_dict
+ features_dict: dict = workflow.features_dict
+ features = WorkflowAppConfigManager.convert_features(
+ config_dict=features_dict,
+ app_mode=AppMode.WORKFLOW
+ )
+
+ parameters = db_provider.parameter_configurations
+ variables = WorkflowToolConfigurationUtils.get_workflow_graph_variables(graph)
+
+        def fetch_workflow_variable(variable_name: str) -> Optional[VariableEntity]:
+ return next(filter(lambda x: x.variable == variable_name, variables), None)
+
+ user = db_provider.user
+
+ workflow_tool_parameters = []
+ for parameter in parameters:
+ variable = fetch_workflow_variable(parameter.name)
+ if variable:
+ parameter_type = None
+ options = None
+ if variable.type in [
+ VariableEntity.Type.TEXT_INPUT,
+ VariableEntity.Type.PARAGRAPH,
+ ]:
+ parameter_type = ToolParameter.ToolParameterType.STRING
+ elif variable.type in [
+ VariableEntity.Type.SELECT
+ ]:
+ parameter_type = ToolParameter.ToolParameterType.SELECT
+ elif variable.type in [
+ VariableEntity.Type.NUMBER
+ ]:
+ parameter_type = ToolParameter.ToolParameterType.NUMBER
+ else:
+ raise ValueError(f'unsupported variable type {variable.type}')
+
+ if variable.type == VariableEntity.Type.SELECT and variable.options:
+ options = [
+ ToolParameterOption(
+ value=option,
+ label=I18nObject(
+ en_US=option,
+ zh_Hans=option
+ )
+ ) for option in variable.options
+ ]
+
+ workflow_tool_parameters.append(
+ ToolParameter(
+ name=parameter.name,
+ label=I18nObject(
+ en_US=variable.label,
+ zh_Hans=variable.label
+ ),
+ human_description=I18nObject(
+ en_US=parameter.description,
+ zh_Hans=parameter.description
+ ),
+ type=parameter_type,
+ form=parameter.form,
+ llm_description=parameter.description,
+ required=variable.required,
+ options=options,
+ default=variable.default
+ )
+ )
+ elif features.file_upload:
+ workflow_tool_parameters.append(
+ ToolParameter(
+ name=parameter.name,
+ label=I18nObject(
+ en_US=parameter.name,
+ zh_Hans=parameter.name
+ ),
+ human_description=I18nObject(
+ en_US=parameter.description,
+ zh_Hans=parameter.description
+ ),
+ type=ToolParameter.ToolParameterType.FILE,
+ llm_description=parameter.description,
+ required=False,
+ form=parameter.form,
+ )
+ )
+ else:
+ raise ValueError('variable not found')
+
+ return WorkflowTool(
+ identity=ToolIdentity(
+ author=user.name if user else '',
+ name=db_provider.name,
+ label=I18nObject(
+ en_US=db_provider.label,
+ zh_Hans=db_provider.label
+ ),
+ provider=self.provider_id,
+ icon=db_provider.icon,
+ ),
+ description=ToolDescription(
+ human=I18nObject(
+ en_US=db_provider.description,
+ zh_Hans=db_provider.description
+ ),
+ llm=db_provider.description,
+ ),
+ parameters=workflow_tool_parameters,
+ is_team_authorization=True,
+ workflow_app_id=app.id,
+ workflow_entities={
+ 'app': app,
+ 'workflow': workflow,
+ },
+ version=db_provider.version,
+ workflow_call_depth=0,
+ label=db_provider.label
+ )
+
+ def get_tools(self, user_id: str, tenant_id: str) -> list[WorkflowTool]:
+ """
+ fetch tools from database
+
+ :param user_id: the user id
+ :param tenant_id: the tenant id
+ :return: the tools
+ """
+ if self.tools is not None:
+ return self.tools
+
+ db_providers: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter(
+ WorkflowToolProvider.tenant_id == tenant_id,
+ WorkflowToolProvider.app_id == self.provider_id,
+ ).first()
+
+ if not db_providers:
+ return []
+
+ self.tools = [self._get_db_provider_tool(db_providers, db_providers.app)]
+
+ return self.tools
+
+ def get_tool(self, tool_name: str) -> Optional[WorkflowTool]:
+ """
+ get tool by name
+
+ :param tool_name: the name of the tool
+ :return: the tool
+ """
+ if self.tools is None:
+ return None
+
+ for tool in self.tools:
+ if tool.identity.name == tool_name:
+ return tool
+
+ return None
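Note: `_get_db_provider_tool` maps workflow form variable types onto tool parameter types and rejects anything it does not recognize. The mapping in isolation, with string stand-ins for the `VariableEntity.Type` and `ToolParameter.ToolParameterType` enum members (the string values are assumptions):

    VARIABLE_TO_PARAMETER_TYPE = {
        'text-input': 'string',
        'paragraph': 'string',
        'select': 'select',
        'number': 'number',
    }

    def to_parameter_type(variable_type: str) -> str:
        try:
            return VARIABLE_TO_PARAMETER_TYPE[variable_type]
        except KeyError:
            # matches the ValueError raised in the provider code above
            raise ValueError(f'unsupported variable type {variable_type}')

    print(to_parameter_type('paragraph'))  # string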
diff --git a/api/core/tools/tool/api_tool.py b/api/core/tools/tool/api_tool.py
index f7b963a92e..ff7d4015ab 100644
--- a/api/core/tools/tool/api_tool.py
+++ b/api/core/tools/tool/api_tool.py
@@ -8,9 +8,8 @@ import httpx
import requests
import core.helper.ssrf_proxy as ssrf_proxy
-from core.tools.entities.tool_bundle import ApiBasedToolBundle
+from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
-from core.tools.entities.user_entities import UserToolProvider
from core.tools.errors import ToolInvokeError, ToolParameterValidationError, ToolProviderCredentialValidationError
from core.tools.tool.tool import Tool
@@ -20,12 +19,12 @@ API_TOOL_DEFAULT_TIMEOUT = (
)
class ApiTool(Tool):
- api_bundle: ApiBasedToolBundle
+ api_bundle: ApiToolBundle
"""
Api tool
"""
- def fork_tool_runtime(self, meta: dict[str, Any]) -> 'Tool':
+ def fork_tool_runtime(self, runtime: dict[str, Any]) -> 'Tool':
"""
fork a new tool with meta data
@@ -37,7 +36,7 @@ class ApiTool(Tool):
parameters=self.parameters.copy() if self.parameters else None,
description=self.description.copy() if self.description else None,
api_bundle=self.api_bundle.copy() if self.api_bundle else None,
- runtime=Tool.Runtime(**meta)
+ runtime=Tool.Runtime(**runtime)
)
def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any], format_only: bool = False) -> str:
@@ -55,7 +54,7 @@ class ApiTool(Tool):
return self.validate_and_parse_response(response)
def tool_provider_type(self) -> ToolProviderType:
- return UserToolProvider.ProviderType.API
+ return ToolProviderType.API
def assembling_request(self, parameters: dict[str, Any]) -> dict[str, Any]:
headers = {}
diff --git a/api/core/tools/tool/builtin_tool.py b/api/core/tools/tool/builtin_tool.py
index 68193e5f69..1ed05c112f 100644
--- a/api/core/tools/tool/builtin_tool.py
+++ b/api/core/tools/tool/builtin_tool.py
@@ -2,9 +2,8 @@
from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage
from core.tools.entities.tool_entities import ToolProviderType
-from core.tools.entities.user_entities import UserToolProvider
-from core.tools.model.tool_model_manager import ToolModelManager
from core.tools.tool.tool import Tool
+from core.tools.utils.model_invocation_utils import ModelInvocationUtils
from core.tools.utils.web_reader_tool import get_url
_SUMMARY_PROMPT = """You are a professional language researcher, you are interested in the language
@@ -34,7 +33,7 @@ class BuiltinTool(Tool):
:return: the model result
"""
# invoke model
- return ToolModelManager.invoke(
+ return ModelInvocationUtils.invoke(
user_id=user_id,
tenant_id=self.runtime.tenant_id,
tool_type='builtin',
@@ -43,7 +42,7 @@ class BuiltinTool(Tool):
)
def tool_provider_type(self) -> ToolProviderType:
- return UserToolProvider.ProviderType.BUILTIN
+ return ToolProviderType.BUILT_IN
def get_max_tokens(self) -> int:
"""
@@ -52,7 +51,7 @@ class BuiltinTool(Tool):
:param model_config: the model config
:return: the max tokens
"""
- return ToolModelManager.get_max_llm_context_tokens(
+ return ModelInvocationUtils.get_max_llm_context_tokens(
tenant_id=self.runtime.tenant_id,
)
@@ -63,7 +62,7 @@ class BuiltinTool(Tool):
:param prompt_messages: the prompt messages
:return: the tokens
"""
- return ToolModelManager.calculate_tokens(
+ return ModelInvocationUtils.calculate_tokens(
tenant_id=self.runtime.tenant_id,
prompt_messages=prompt_messages
)
diff --git a/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py b/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py
index 6e11427d58..18cf780668 100644
--- a/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py
+++ b/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py
@@ -7,7 +7,7 @@ from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCa
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.datasource.retrieval_service import RetrievalService
-from core.rerank.rerank import RerankRunner
+from core.rag.rerank.rerank import RerankRunner
from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment
@@ -79,7 +79,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
document_score_list = {}
for item in all_documents:
- if 'score' in item.metadata and item.metadata['score']:
+ if item.metadata.get('score'):
document_score_list[item.metadata['doc_id']] = item.metadata['score']
document_context_list = []
@@ -99,9 +99,9 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
float('inf')))
for segment in sorted_segments:
if segment.answer:
- document_context_list.append(f'question:{segment.content} answer:{segment.answer}')
+ document_context_list.append(f'question:{segment.get_sign_content()} answer:{segment.answer}')
else:
- document_context_list.append(segment.content)
+ document_context_list.append(segment.get_sign_content())
if self.return_resource:
context_list = []
resource_number = 1
diff --git a/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py
index 552174e0ba..af45fc66f2 100644
--- a/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py
+++ b/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py
@@ -87,7 +87,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
document_score_list = {}
if dataset.indexing_technique != "economy":
for item in documents:
- if 'score' in item.metadata and item.metadata['score']:
+ if item.metadata.get('score'):
document_score_list[item.metadata['doc_id']] = item.metadata['score']
document_context_list = []
index_node_ids = [document.metadata['doc_id'] for document in documents]
@@ -105,9 +105,9 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
float('inf')))
for segment in sorted_segments:
if segment.answer:
- document_context_list.append(f'question:{segment.content} answer:{segment.answer}')
+ document_context_list.append(f'question:{segment.get_sign_content()} answer:{segment.answer}')
else:
- document_context_list.append(segment.content)
+ document_context_list.append(segment.get_sign_content())
if self.return_resource:
context_list = []
resource_number = 1
diff --git a/api/core/tools/tool/model_tool.py b/api/core/tools/tool/model_tool.py
deleted file mode 100644
index b87e85f89c..0000000000
--- a/api/core/tools/tool/model_tool.py
+++ /dev/null
@@ -1,159 +0,0 @@
-from base64 import b64encode
-from enum import Enum
-from typing import Any, cast
-
-from core.model_manager import ModelInstance
-from core.model_runtime.entities.llm_entities import LLMResult
-from core.model_runtime.entities.message_entities import (
- PromptMessageContent,
- PromptMessageContentType,
- SystemPromptMessage,
- UserPromptMessage,
-)
-from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-from core.tools.entities.tool_entities import ModelToolPropertyKey, ToolInvokeMessage, ToolProviderType
-from core.tools.tool.tool import Tool
-
-VISION_PROMPT = """## Image Recognition Task
-### Task Description
-I require a powerful vision language model for an image recognition task. The model should be capable of extracting various details from the images, including but not limited to text content, layout distribution, color distribution, main subjects, and emotional expressions.
-### Specific Requirements
-1. **Text Content Extraction:** Ensure that the model accurately recognizes and extracts text content from the images, regardless of text size, font, or color.
-2. **Layout Distribution Analysis:** The model should analyze the layout structure of the images, capturing the relationships between various elements and providing detailed information about the image layout.
-3. **Color Distribution Analysis:** Extract information about color distribution in the images, including primary colors, color combinations, and other relevant details.
-4. **Main Subject Recognition:** The model should accurately identify the main subjects in the images and provide detailed descriptions of these subjects.
-5. **Emotional Expression Analysis:** Analyze and describe the emotions or expressions conveyed in the images based on facial expressions, postures, and other relevant features.
-### Additional Considerations
-- Ensure that the extracted information is as comprehensive and accurate as possible.
-- For each task, provide confidence scores or relevance scores for the model outputs to assess the reliability of the results.
-- If necessary, pose specific questions for different tasks to guide the model in better understanding the images and providing relevant information."""
-
-class ModelTool(Tool):
- class ModelToolType(Enum):
- """
- the type of the model tool
- """
- VISION = 'vision'
-
- model_configuration: dict[str, Any] = None
- tool_type: ModelToolType
-
- def __init__(self, model_instance: ModelInstance = None, model: str = None,
- tool_type: ModelToolType = ModelToolType.VISION,
- properties: dict[ModelToolPropertyKey, Any] = None,
- **kwargs):
- """
- init the tool
- """
- kwargs['model_configuration'] = {
- 'model_instance': model_instance,
- 'model': model,
- 'properties': properties
- }
- kwargs['tool_type'] = tool_type
- super().__init__(**kwargs)
-
- """
- Model tool
- """
- def fork_tool_runtime(self, meta: dict[str, Any]) -> 'Tool':
- """
- fork a new tool with meta data
-
- :param meta: the meta data of a tool call processing, tenant_id is required
- :return: the new tool
- """
- return self.__class__(
- identity=self.identity.copy() if self.identity else None,
- parameters=self.parameters.copy() if self.parameters else None,
- description=self.description.copy() if self.description else None,
- model_instance=self.model_configuration['model_instance'],
- model=self.model_configuration['model'],
- tool_type=self.tool_type,
- runtime=Tool.Runtime(**meta)
- )
-
- def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any], format_only: bool = False) -> None:
- """
- validate the credentials for Model tool
- """
- pass
-
- def tool_provider_type(self) -> ToolProviderType:
- return ToolProviderType.BUILT_IN
-
- def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
- """
- """
- model_instance = self.model_configuration['model_instance']
- if not model_instance:
- return self.create_text_message('the tool is not configured correctly')
-
- if self.tool_type == ModelTool.ModelToolType.VISION:
- return self._invoke_llm_vision(user_id, tool_parameters)
- else:
- return self.create_text_message('the tool is not configured correctly')
-
- def _invoke_llm_vision(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
- # get image
- image_parameter_name = self.model_configuration['properties'].get(ModelToolPropertyKey.IMAGE_PARAMETER_NAME, 'image_id')
- image_id = tool_parameters.pop(image_parameter_name, '')
- if not image_id:
- image = self.get_default_image_variable()
- if not image:
- return self.create_text_message('Please upload an image or input image_id')
- else:
- image = self.get_variable(image_id)
- if not image:
- image = self.get_default_image_variable()
- if not image:
- return self.create_text_message('Please upload an image or input image_id')
-
- if not image:
- return self.create_text_message('Please upload an image or input image_id')
-
- # get image
- image = self.get_variable_file(image.name)
- if not image:
- return self.create_text_message('Failed to get image')
-
- # organize prompt messages
- prompt_messages = [
- SystemPromptMessage(
- content=VISION_PROMPT
- ),
- UserPromptMessage(
- content=[
- PromptMessageContent(
- type=PromptMessageContentType.TEXT,
- data='Recognize the image and extract the information from the image.'
- ),
- PromptMessageContent(
- type=PromptMessageContentType.IMAGE,
- data=f'data:image/png;base64,{b64encode(image).decode("utf-8")}'
- )
- ]
- )
- ]
-
- llm_instance = cast(LargeLanguageModel, self.model_configuration['model_instance'])
- result: LLMResult = llm_instance.invoke(
- model=self.model_configuration['model'],
- credentials=self.runtime.credentials,
- prompt_messages=prompt_messages,
- model_parameters=tool_parameters,
- tools=[],
- stop=[],
- stream=False,
- user=user_id,
- )
-
- if not result:
- return self.create_text_message('Failed to extract information from the image')
-
- # get result
- content = result.message.content
- if not content:
- return self.create_text_message('Failed to extract information from the image')
-
- return self.create_text_message(content)
\ No newline at end of file
diff --git a/api/core/tools/tool/tool.py b/api/core/tools/tool/tool.py
index 03aa0623fe..5ae54e4728 100644
--- a/api/core/tools/tool/tool.py
+++ b/api/core/tools/tool/tool.py
@@ -1,12 +1,16 @@
from abc import ABC, abstractmethod
+from copy import deepcopy
from enum import Enum
from typing import Any, Optional, Union
from pydantic import BaseModel, validator
+from core.app.entities.app_invoke_entities import InvokeFrom
+from core.file.file_obj import FileVar
from core.tools.entities.tool_entities import (
ToolDescription,
ToolIdentity,
+ ToolInvokeFrom,
ToolInvokeMessage,
ToolParameter,
ToolProviderType,
@@ -15,6 +19,7 @@ from core.tools.entities.tool_entities import (
ToolRuntimeVariablePool,
)
from core.tools.tool_file_manager import ToolFileManager
+from core.tools.utils.tool_parameter_converter import ToolParameterConverter
class Tool(BaseModel, ABC):
@@ -25,10 +30,7 @@ class Tool(BaseModel, ABC):
@validator('parameters', pre=True, always=True)
def set_parameters(cls, v, values):
- if not v:
- return []
-
- return v
+ return v or []
class Runtime(BaseModel):
"""
@@ -41,6 +43,8 @@ class Tool(BaseModel, ABC):
tenant_id: str = None
tool_id: str = None
+ invoke_from: InvokeFrom = None
+ tool_invoke_from: ToolInvokeFrom = None
credentials: dict[str, Any] = None
runtime_parameters: dict[str, Any] = None
@@ -53,7 +57,7 @@ class Tool(BaseModel, ABC):
class VARIABLE_KEY(Enum):
IMAGE = 'image'
- def fork_tool_runtime(self, meta: dict[str, Any]) -> 'Tool':
+ def fork_tool_runtime(self, runtime: dict[str, Any]) -> 'Tool':
"""
fork a new tool with meta data
@@ -64,7 +68,7 @@ class Tool(BaseModel, ABC):
identity=self.identity.copy() if self.identity else None,
parameters=self.parameters.copy() if self.parameters else None,
description=self.description.copy() if self.description else None,
- runtime=Tool.Runtime(**meta),
+ runtime=Tool.Runtime(**runtime),
)
@abstractmethod
@@ -208,17 +212,17 @@ class Tool(BaseModel, ABC):
if response.type == ToolInvokeMessage.MessageType.TEXT:
result += response.message
elif response.type == ToolInvokeMessage.MessageType.LINK:
- result += f"result link: {response.message}. please tell user to check it."
+ result += f"result link: {response.message}. please tell user to check it. \n"
elif response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
response.type == ToolInvokeMessage.MessageType.IMAGE:
- result += "image has been created and sent to user already, you do not need to create it, just tell the user to check it now."
+ result += "image has been created and sent to user already, you do not need to create it, just tell the user to check it now. \n"
elif response.type == ToolInvokeMessage.MessageType.BLOB:
if len(response.message) > 114:
result += str(response.message[:114]) + '...'
else:
result += str(response.message)
else:
- result += f"tool response: {response.message}."
+ result += f"tool response: {response.message}. \n"
return result
@@ -226,46 +230,13 @@ class Tool(BaseModel, ABC):
"""
Transform tool parameters type
"""
+ # Temp fix for the issue that the tool parameters will be converted to empty while validating the credentials
+ result = deepcopy(tool_parameters)
for parameter in self.parameters:
if parameter.name in tool_parameters:
- if parameter.type in [
- ToolParameter.ToolParameterType.SECRET_INPUT,
- ToolParameter.ToolParameterType.STRING,
- ToolParameter.ToolParameterType.SELECT,
- ] and not isinstance(tool_parameters[parameter.name], str):
- if tool_parameters[parameter.name] is None:
- tool_parameters[parameter.name] = ''
- else:
- tool_parameters[parameter.name] = str(tool_parameters[parameter.name])
- elif parameter.type == ToolParameter.ToolParameterType.NUMBER \
- and not isinstance(tool_parameters[parameter.name], int | float):
- if isinstance(tool_parameters[parameter.name], str):
- try:
- tool_parameters[parameter.name] = int(tool_parameters[parameter.name])
- except ValueError:
- tool_parameters[parameter.name] = float(tool_parameters[parameter.name])
- elif isinstance(tool_parameters[parameter.name], bool):
- tool_parameters[parameter.name] = int(tool_parameters[parameter.name])
- elif tool_parameters[parameter.name] is None:
- tool_parameters[parameter.name] = 0
- elif parameter.type == ToolParameter.ToolParameterType.BOOLEAN:
- if not isinstance(tool_parameters[parameter.name], bool):
- # check if it is a string
- if isinstance(tool_parameters[parameter.name], str):
- # check true false
- if tool_parameters[parameter.name].lower() in ['true', 'false']:
- tool_parameters[parameter.name] = tool_parameters[parameter.name].lower() == 'true'
- # check 1 0
- elif tool_parameters[parameter.name] in ['1', '0']:
- tool_parameters[parameter.name] = tool_parameters[parameter.name] == '1'
- else:
- tool_parameters[parameter.name] = bool(tool_parameters[parameter.name])
- elif isinstance(tool_parameters[parameter.name], int | float):
- tool_parameters[parameter.name] = tool_parameters[parameter.name] != 0
- else:
- tool_parameters[parameter.name] = bool(tool_parameters[parameter.name])
-
- return tool_parameters
+ result[parameter.name] = ToolParameterConverter.cast_parameter_by_type(tool_parameters[parameter.name], parameter.type)
+
+ return result
@abstractmethod
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
@@ -324,14 +295,6 @@ class Tool(BaseModel, ABC):
return parameters
- def is_tool_available(self) -> bool:
- """
- check if the tool is available
-
- :return: if the tool is available
- """
- return True
-
def create_image_message(self, image: str, save_as: str = '') -> ToolInvokeMessage:
"""
create an image message
@@ -343,6 +306,14 @@ class Tool(BaseModel, ABC):
message=image,
save_as=save_as)
+ def create_file_var_message(self, file_var: FileVar) -> ToolInvokeMessage:
+ return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.FILE_VAR,
+ message='',
+ meta={
+ 'file_var': file_var
+ },
+ save_as='')
+
def create_link_message(self, link: str, save_as: str = '') -> ToolInvokeMessage:
"""
create a link message
@@ -361,10 +332,11 @@ class Tool(BaseModel, ABC):
:param text: the text
:return: the text message
"""
- return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.TEXT,
- message=text,
- save_as=save_as
- )
+ return ToolInvokeMessage(
+ type=ToolInvokeMessage.MessageType.TEXT,
+ message=text,
+ save_as=save_as
+ )
def create_blob_message(self, blob: bytes, meta: dict = None, save_as: str = '') -> ToolInvokeMessage:
"""
@@ -373,7 +345,8 @@ class Tool(BaseModel, ABC):
:param blob: the blob
:return: the blob message
"""
- return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.BLOB,
- message=blob, meta=meta,
- save_as=save_as
- )
\ No newline at end of file
+ return ToolInvokeMessage(
+ type=ToolInvokeMessage.MessageType.BLOB,
+ message=blob, meta=meta,
+ save_as=save_as
+ )
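Note: `transform_tool_parameters_type` now deep-copies the input before casting, per the "Temp fix" comment above: casting in place mutated the caller's dict, which could blank out parameters during credential validation. A minimal reproduction of the difference (the function names are illustrative):

    from copy import deepcopy

    def transform_in_place(params):          # old behavior
        for k in params:
            params[k] = str(params[k])
        return params

    def transform_copy(params):              # new behavior
        result = deepcopy(params)
        for k in result:
            result[k] = str(result[k])
        return result

    original = {'count': 3}
    transform_in_place(original)
    print(original)                          # {'count': '3'} -- caller mutated
    original = {'count': 3}
    transform_copy(original)
    print(original)                          # {'count': 3}   -- caller intact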
diff --git a/api/core/tools/tool/workflow_tool.py b/api/core/tools/tool/workflow_tool.py
new file mode 100644
index 0000000000..122b663f94
--- /dev/null
+++ b/api/core/tools/tool/workflow_tool.py
@@ -0,0 +1,200 @@
+import json
+import logging
+from copy import deepcopy
+from typing import Any, Union
+
+from core.file.file_obj import FileTransferMethod, FileVar
+from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolProviderType
+from core.tools.tool.tool import Tool
+from extensions.ext_database import db
+from models.account import Account
+from models.model import App, EndUser
+from models.workflow import Workflow
+
+logger = logging.getLogger(__name__)
+
+class WorkflowTool(Tool):
+ workflow_app_id: str
+ version: str
+ workflow_entities: dict[str, Any]
+ workflow_call_depth: int
+
+ label: str
+
+ """
+ Workflow tool.
+ """
+ def tool_provider_type(self) -> ToolProviderType:
+ """
+ get the tool provider type
+
+ :return: the tool provider type
+ """
+ return ToolProviderType.WORKFLOW
+
+ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) \
+ -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
+ """
+ invoke the tool
+ """
+ app = self._get_app(app_id=self.workflow_app_id)
+ workflow = self._get_workflow(app_id=self.workflow_app_id, version=self.version)
+
+ # transform the tool parameters
+ tool_parameters, files = self._transform_args(tool_parameters)
+
+ from core.app.apps.workflow.app_generator import WorkflowAppGenerator
+ generator = WorkflowAppGenerator()
+ result = generator.generate(
+ app_model=app,
+ workflow=workflow,
+ user=self._get_user(user_id),
+ args={
+ 'inputs': tool_parameters,
+ 'files': files
+ },
+ invoke_from=self.runtime.invoke_from,
+ stream=False,
+ call_depth=self.workflow_call_depth + 1,
+ )
+
+ data = result.get('data', {})
+
+ if data.get('error'):
+ raise Exception(data.get('error'))
+
+ result = []
+
+ outputs = data.get('outputs', {})
+ outputs, files = self._extract_files(outputs)
+ for file in files:
+ result.append(self.create_file_var_message(file))
+
+ result.append(self.create_text_message(json.dumps(outputs, ensure_ascii=False)))
+
+ return result
+
+ def _get_user(self, user_id: str) -> Union[EndUser, Account]:
+ """
+ get the user by user id
+ """
+
+ user = db.session.query(EndUser).filter(EndUser.id == user_id).first()
+ if not user:
+ user = db.session.query(Account).filter(Account.id == user_id).first()
+
+ if not user:
+ raise ValueError('user not found')
+
+ return user
+
+ def fork_tool_runtime(self, runtime: dict[str, Any]) -> 'WorkflowTool':
+ """
+        fork a new tool with runtime data
+
+        :param runtime: the runtime of a tool call processing, tenant_id is required
+ :return: the new tool
+ """
+ return self.__class__(
+ identity=deepcopy(self.identity),
+ parameters=deepcopy(self.parameters),
+ description=deepcopy(self.description),
+ runtime=Tool.Runtime(**runtime),
+ workflow_app_id=self.workflow_app_id,
+ workflow_entities=self.workflow_entities,
+ workflow_call_depth=self.workflow_call_depth,
+ version=self.version,
+ label=self.label
+ )
+
+ def _get_workflow(self, app_id: str, version: str) -> Workflow:
+ """
+ get the workflow by app id and version
+ """
+ if not version:
+ workflow = db.session.query(Workflow).filter(
+ Workflow.app_id == app_id,
+ Workflow.version != 'draft'
+ ).order_by(Workflow.created_at.desc()).first()
+ else:
+ workflow = db.session.query(Workflow).filter(
+ Workflow.app_id == app_id,
+ Workflow.version == version
+ ).first()
+
+ if not workflow:
+ raise ValueError('workflow not found or not published')
+
+ return workflow
+
+ def _get_app(self, app_id: str) -> App:
+ """
+ get the app by app id
+ """
+ app = db.session.query(App).filter(App.id == app_id).first()
+ if not app:
+ raise ValueError('app not found')
+
+ return app
+
+ def _transform_args(self, tool_parameters: dict) -> tuple[dict, list[dict]]:
+ """
+ transform the tool parameters
+
+ :param tool_parameters: the tool parameters
+ :return: tool_parameters, files
+ """
+ parameter_rules = self.get_all_runtime_parameters()
+ parameters_result = {}
+ files = []
+ for parameter in parameter_rules:
+ if parameter.type == ToolParameter.ToolParameterType.FILE:
+ file = tool_parameters.get(parameter.name)
+ if file:
+ try:
+ file_var_list = [FileVar(**f) for f in file]
+ for file_var in file_var_list:
+ file_dict = {
+ 'transfer_method': file_var.transfer_method.value,
+ 'type': file_var.type.value,
+ }
+ if file_var.transfer_method == FileTransferMethod.TOOL_FILE:
+ file_dict['tool_file_id'] = file_var.related_id
+ elif file_var.transfer_method == FileTransferMethod.LOCAL_FILE:
+ file_dict['upload_file_id'] = file_var.related_id
+ elif file_var.transfer_method == FileTransferMethod.REMOTE_URL:
+ file_dict['url'] = file_var.preview_url
+
+ files.append(file_dict)
+ except Exception as e:
+ logger.exception(e)
+ else:
+ parameters_result[parameter.name] = tool_parameters.get(parameter.name)
+
+ return parameters_result, files
+
+ def _extract_files(self, outputs: dict) -> tuple[dict, list[FileVar]]:
+ """
+ extract files from the result
+
+ :param result: the result
+ :return: the result, files
+ """
+ files = []
+ result = {}
+ for key, value in outputs.items():
+ if isinstance(value, list):
+ has_file = False
+ for item in value:
+ if isinstance(item, dict) and item.get('__variant') == 'FileVar':
+ try:
+ files.append(FileVar(**item))
+ has_file = True
+                        except Exception:
+ pass
+ if has_file:
+ continue
+
+ result[key] = value
+
+ return result, files
\ No newline at end of file
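Note: `_extract_files` splits workflow outputs into plain values and file items, detecting files by the `__variant == 'FileVar'` marker. A simplified runnable stand-in using plain dicts instead of `FileVar` (like the original, a list that contains any file item is routed entirely to `files`):

    def extract_files(outputs: dict) -> tuple[dict, list[dict]]:
        files, result = [], {}
        for key, value in outputs.items():
            if isinstance(value, list):
                file_items = [i for i in value
                              if isinstance(i, dict) and i.get('__variant') == 'FileVar']
                if file_items:
                    files.extend(file_items)
                    continue
            result[key] = value
        return result, files

    outputs = {'text': 'done',
               'images': [{'__variant': 'FileVar', 'url': 'a.png'}]}
    print(extract_files(outputs))
    # ({'text': 'done'}, [{'__variant': 'FileVar', 'url': 'a.png'}])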
diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py
index f96d7940bd..16fe9051e3 100644
--- a/api/core/tools/tool_engine.py
+++ b/api/core/tools/tool_engine.py
@@ -1,7 +1,10 @@
from copy import deepcopy
from datetime import datetime, timezone
+from mimetypes import guess_type
from typing import Union
+from yarl import URL
+
from core.app.entities.app_invoke_entities import InvokeFrom
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
@@ -17,6 +20,7 @@ from core.tools.errors import (
ToolProviderNotFoundError,
)
from core.tools.tool.tool import Tool
+from core.tools.tool.workflow_tool import WorkflowTool
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from extensions.ext_database import db
from models.model import Message, MessageFile
@@ -115,7 +119,8 @@ class ToolEngine:
@staticmethod
def workflow_invoke(tool: Tool, tool_parameters: dict,
user_id: str, workflow_id: str,
- workflow_tool_callback: DifyWorkflowCallbackHandler) \
+ workflow_tool_callback: DifyWorkflowCallbackHandler,
+ workflow_call_depth: int) \
-> list[ToolInvokeMessage]:
"""
Workflow invokes the tool with the given arguments.
@@ -127,6 +132,9 @@ class ToolEngine:
tool_inputs=tool_parameters
)
+ if isinstance(tool, WorkflowTool):
+ tool.workflow_call_depth = workflow_call_depth + 1
+
response = tool.invoke(user_id, tool_parameters)
# hit the callback handler
@@ -195,8 +203,24 @@ class ToolEngine:
for response in tool_response:
if response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
response.type == ToolInvokeMessage.MessageType.IMAGE:
+ mimetype = None
+ if response.meta.get('mime_type'):
+ mimetype = response.meta.get('mime_type')
+ else:
+ try:
+ url = URL(response.message)
+ extension = url.suffix
+ guess_type_result, _ = guess_type(f'a{extension}')
+ if guess_type_result:
+ mimetype = guess_type_result
+ except Exception:
+ pass
+
+ if not mimetype:
+ mimetype = 'image/jpeg'
+
result.append(ToolInvokeMessageBinary(
- mimetype=response.meta.get('mime_type', 'octet/stream'),
+                    mimetype=mimetype,
url=response.message,
save_as=response.save_as,
))
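
When a tool response carries no `mime_type` meta, the engine now derives one from the link itself: it takes the URL's file suffix, feeds a dummy filename to `mimetypes.guess_type`, and falls back to `image/jpeg`. A runnable sketch of the same trick (it assumes `yarl` is installed, matching the import added above):

```python
from mimetypes import guess_type

from yarl import URL

def guess_image_mimetype(url_str: str, default: str = 'image/jpeg') -> str:
    try:
        extension = URL(url_str).suffix           # e.g. '.png'
        guessed, _ = guess_type(f'a{extension}')  # dummy name; only the suffix matters
        if guessed:
            return guessed
    except Exception:
        pass
    return default

assert guess_image_mimetype('https://example.com/files/pic.png') == 'image/png'
assert guess_image_mimetype('https://example.com/files/pic') == 'image/jpeg'
```
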
diff --git a/api/core/tools/tool_file_manager.py b/api/core/tools/tool_file_manager.py
index e21a2efedd..207f009eed 100644
--- a/api/core/tools/tool_file_manager.py
+++ b/api/core/tools/tool_file_manager.py
@@ -53,7 +53,7 @@ class ToolFileManager:
return False
current_time = int(time.time())
- return current_time - int(timestamp) <= 300 # expired after 5 minutes
+ return current_time - int(timestamp) <= current_app.config.get('FILES_ACCESS_TIMEOUT')
@staticmethod
def create_file_by_raw(user_id: str, tenant_id: str,
@@ -65,7 +65,7 @@ class ToolFileManager:
"""
extension = guess_extension(mimetype) or '.bin'
unique_name = uuid4().hex
- filename = f"/tools/{tenant_id}/{unique_name}{extension}"
+ filename = f"tools/{tenant_id}/{unique_name}{extension}"
storage.save(filename, file_binary)
tool_file = ToolFile(user_id=user_id, tenant_id=tenant_id,
@@ -90,7 +90,7 @@ class ToolFileManager:
mimetype = guess_type(file_url)[0] or 'octet/stream'
extension = guess_extension(mimetype) or '.bin'
unique_name = uuid4().hex
- filename = f"/tools/{tenant_id}/{unique_name}{extension}"
+ filename = f"tools/{tenant_id}/{unique_name}{extension}"
storage.save(filename, blob)
tool_file = ToolFile(user_id=user_id, tenant_id=tenant_id,
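
Two small behavior changes here: the signed-URL lifetime moves from a hardcoded 300 seconds to the `FILES_ACCESS_TIMEOUT` config value, and storage keys lose their leading slash so they resolve consistently across storage backends. The expiry check itself is a plain timestamp comparison; a standalone sketch with the config value stubbed in:

```python
import time

FILES_ACCESS_TIMEOUT = 300  # stand-in for current_app.config.get('FILES_ACCESS_TIMEOUT'), in seconds

def is_signature_fresh(timestamp: int) -> bool:
    # a signed file URL stays valid for FILES_ACCESS_TIMEOUT seconds after signing
    return int(time.time()) - int(timestamp) <= FILES_ACCESS_TIMEOUT

assert is_signature_fresh(int(time.time()) - 10)
assert not is_signature_fresh(int(time.time()) - FILES_ACCESS_TIMEOUT - 1)
```
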
diff --git a/api/core/tools/tool_label_manager.py b/api/core/tools/tool_label_manager.py
new file mode 100644
index 0000000000..97788a7a07
--- /dev/null
+++ b/api/core/tools/tool_label_manager.py
@@ -0,0 +1,96 @@
+from core.tools.entities.values import default_tool_label_name_list
+from core.tools.provider.api_tool_provider import ApiToolProviderController
+from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
+from core.tools.provider.tool_provider import ToolProviderController
+from core.tools.provider.workflow_tool_provider import WorkflowToolProviderController
+from extensions.ext_database import db
+from models.tools import ToolLabelBinding
+
+
+class ToolLabelManager:
+ @classmethod
+ def filter_tool_labels(cls, tool_labels: list[str]) -> list[str]:
+ """
+ Filter tool labels
+ """
+ tool_labels = [label for label in tool_labels if label in default_tool_label_name_list]
+ return list(set(tool_labels))
+
+ @classmethod
+ def update_tool_labels(cls, controller: ToolProviderController, labels: list[str]):
+ """
+ Update tool labels
+ """
+ labels = cls.filter_tool_labels(labels)
+
+ if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
+ provider_id = controller.provider_id
+ else:
+ raise ValueError('Unsupported tool type')
+
+ # delete old labels
+ db.session.query(ToolLabelBinding).filter(
+ ToolLabelBinding.tool_id == provider_id
+ ).delete()
+
+ # insert new labels
+ for label in labels:
+ db.session.add(ToolLabelBinding(
+ tool_id=provider_id,
+ tool_type=controller.provider_type.value,
+ label_name=label,
+ ))
+
+ db.session.commit()
+
+ @classmethod
+ def get_tool_labels(cls, controller: ToolProviderController) -> list[str]:
+ """
+ Get tool labels
+ """
+ if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
+ provider_id = controller.provider_id
+ elif isinstance(controller, BuiltinToolProviderController):
+ return controller.tool_labels
+ else:
+ raise ValueError('Unsupported tool type')
+
+ labels: list[ToolLabelBinding] = db.session.query(ToolLabelBinding.label_name).filter(
+ ToolLabelBinding.tool_id == provider_id,
+ ToolLabelBinding.tool_type == controller.provider_type.value,
+ ).all()
+
+ return [label.label_name for label in labels]
+
+ @classmethod
+ def get_tools_labels(cls, tool_providers: list[ToolProviderController]) -> dict[str, list[str]]:
+ """
+ Get tools labels
+
+ :param tool_providers: list of tool providers
+
+ :return: dict of tool labels
+ :key: tool id
+ :value: list of tool labels
+ """
+ if not tool_providers:
+ return {}
+
+ for controller in tool_providers:
+ if not isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
+ raise ValueError('Unsupported tool type')
+
+ provider_ids = [controller.provider_id for controller in tool_providers]
+
+ labels: list[ToolLabelBinding] = db.session.query(ToolLabelBinding).filter(
+ ToolLabelBinding.tool_id.in_(provider_ids)
+ ).all()
+
+ tool_labels = {
+ label.tool_id: [] for label in labels
+ }
+
+ for label in labels:
+ tool_labels[label.tool_id].append(label.label_name)
+
+ return tool_labels
\ No newline at end of file
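
`update_tool_labels` uses a replace-all pattern: filter the requested labels against the allowed set, delete every existing binding for the provider, then insert one row per surviving label. A database-free sketch of the same flow, with an in-memory dict standing in for the `ToolLabelBinding` table:

```python
ALLOWED_LABELS = {'search', 'image', 'productivity'}  # stand-in for default_tool_label_name_list

bindings: dict[str, list[str]] = {}  # provider_id -> label names

def update_tool_labels(provider_id: str, labels: list[str]) -> None:
    filtered = list({label for label in labels if label in ALLOWED_LABELS})
    bindings[provider_id] = []              # "delete old labels"
    bindings[provider_id].extend(filtered)  # "insert new labels"

update_tool_labels('provider-1', ['search', 'search', 'not-a-label'])
assert bindings['provider-1'] == ['search']
```
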
diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py
index a29bdfcd11..a0ca9f692a 100644
--- a/api/core/tools/tool_manager.py
+++ b/api/core/tools/tool_manager.py
@@ -9,32 +9,33 @@ from typing import Any, Union
from flask import current_app
from core.agent.entities import AgentToolEntity
+from core.app.entities.app_invoke_entities import InvokeFrom
+from core.helper.module_import_helper import load_single_subclass_from_source
from core.model_runtime.utils.encoders import jsonable_encoder
-from core.provider_manager import ProviderManager
-from core.tools import *
+from core.tools.entities.api_entities import UserToolProvider, UserToolProviderTypeLiteral
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import (
ApiProviderAuthType,
+ ToolInvokeFrom,
ToolParameter,
)
-from core.tools.entities.user_entities import UserToolProvider
from core.tools.errors import ToolProviderNotFoundError
-from core.tools.provider.api_tool_provider import ApiBasedToolProviderController
+from core.tools.provider.api_tool_provider import ApiToolProviderController
from core.tools.provider.builtin._positions import BuiltinToolProviderSort
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
-from core.tools.provider.model_tool_provider import ModelToolProviderController
from core.tools.tool.api_tool import ApiTool
from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.tool.tool import Tool
+from core.tools.tool_label_manager import ToolLabelManager
from core.tools.utils.configuration import (
ToolConfigurationManager,
ToolParameterConfigurationManager,
)
-from core.utils.module_import_helper import load_single_subclass_from_source
+from core.tools.utils.tool_parameter_converter import ToolParameterConverter
from core.workflow.nodes.tool.entities import ToolEntity
from extensions.ext_database import db
-from models.tools import ApiToolProvider, BuiltinToolProvider
-from services.tools_transform_service import ToolTransformService
+from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
+from services.tools.tools_transform_service import ToolTransformService
logger = logging.getLogger(__name__)
@@ -101,7 +102,12 @@ class ToolManager:
raise ToolProviderNotFoundError(f'provider type {provider_type} not found')
@classmethod
- def get_tool_runtime(cls, provider_type: str, provider_name: str, tool_name: str, tenant_id: str) \
+ def get_tool_runtime(cls, provider_type: str,
+ provider_id: str,
+ tool_name: str,
+ tenant_id: str,
+ invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
+ tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT) \
-> Union[BuiltinTool, ApiTool]:
"""
get the tool runtime
@@ -113,64 +119,76 @@ class ToolManager:
:return: the tool
"""
if provider_type == 'builtin':
- builtin_tool = cls.get_builtin_tool(provider_name, tool_name)
+ builtin_tool = cls.get_builtin_tool(provider_id, tool_name)
# check if the builtin tool need credentials
- provider_controller = cls.get_builtin_provider(provider_name)
+ provider_controller = cls.get_builtin_provider(provider_id)
if not provider_controller.need_credentials:
- return builtin_tool.fork_tool_runtime(meta={
+ return builtin_tool.fork_tool_runtime(runtime={
'tenant_id': tenant_id,
'credentials': {},
+ 'invoke_from': invoke_from,
+ 'tool_invoke_from': tool_invoke_from,
})
# get credentials
builtin_provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
BuiltinToolProvider.tenant_id == tenant_id,
- BuiltinToolProvider.provider == provider_name,
+ BuiltinToolProvider.provider == provider_id,
).first()
if builtin_provider is None:
- raise ToolProviderNotFoundError(f'builtin provider {provider_name} not found')
+ raise ToolProviderNotFoundError(f'builtin provider {provider_id} not found')
# decrypt the credentials
credentials = builtin_provider.credentials
- controller = cls.get_builtin_provider(provider_name)
+ controller = cls.get_builtin_provider(provider_id)
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller)
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
- return builtin_tool.fork_tool_runtime(meta={
+ return builtin_tool.fork_tool_runtime(runtime={
'tenant_id': tenant_id,
'credentials': decrypted_credentials,
- 'runtime_parameters': {}
+ 'runtime_parameters': {},
+ 'invoke_from': invoke_from,
+ 'tool_invoke_from': tool_invoke_from,
})
elif provider_type == 'api':
if tenant_id is None:
raise ValueError('tenant id is required for api provider')
- api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_name)
+ api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id)
# decrypt the credentials
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=api_provider)
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
- return api_provider.get_tool(tool_name).fork_tool_runtime(meta={
+ return api_provider.get_tool(tool_name).fork_tool_runtime(runtime={
'tenant_id': tenant_id,
'credentials': decrypted_credentials,
+ 'invoke_from': invoke_from,
+ 'tool_invoke_from': tool_invoke_from,
})
- elif provider_type == 'model':
- if tenant_id is None:
- raise ValueError('tenant id is required for model provider')
- # get model provider
- model_provider = cls.get_model_provider(tenant_id, provider_name)
+ elif provider_type == 'workflow':
+ workflow_provider = db.session.query(WorkflowToolProvider).filter(
+ WorkflowToolProvider.tenant_id == tenant_id,
+ WorkflowToolProvider.id == provider_id
+ ).first()
- # get tool
- model_tool = model_provider.get_tool(tool_name)
+ if workflow_provider is None:
+ raise ToolProviderNotFoundError(f'workflow provider {provider_id} not found')
- return model_tool.fork_tool_runtime(meta={
+ controller = ToolTransformService.workflow_provider_to_controller(
+ db_provider=workflow_provider
+ )
+
+ return controller.get_tools(user_id=None, tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(runtime={
'tenant_id': tenant_id,
- 'credentials': model_tool.model_configuration['model_instance'].credentials
+ 'credentials': {},
+ 'invoke_from': invoke_from,
+ 'tool_invoke_from': tool_invoke_from,
})
elif provider_type == 'app':
raise NotImplementedError('app provider not implemented')
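
`get_tool_runtime` now takes a `provider_id` plus the two invocation-context enums and dispatches on `provider_type`, forking a runtime that carries tenant, credentials, and context alongside the tool. A toy sketch of the dispatch shape with stand-in values, not Dify's controllers:

```python
from enum import Enum

class ToolInvokeFrom(Enum):  # stand-in for core.tools.entities.tool_entities.ToolInvokeFrom
    AGENT = 'agent'
    WORKFLOW = 'workflow'

def get_tool_runtime(provider_type: str, provider_id: str, tool_name: str,
                     tenant_id: str, tool_invoke_from: ToolInvokeFrom) -> dict:
    runtime = {'tenant_id': tenant_id, 'tool_invoke_from': tool_invoke_from}
    if provider_type == 'builtin':
        return {'tool': f'{provider_id}/{tool_name}', **runtime}
    elif provider_type == 'api':
        return {'tool': f'{provider_id}/{tool_name}', 'credentials': {}, **runtime}
    elif provider_type == 'workflow':
        # workflow tools always resolve to the provider's single published workflow
        return {'tool': provider_id, 'credentials': {}, **runtime}
    raise ValueError(f'provider type {provider_type} not found')

runtime = get_tool_runtime('workflow', 'wf-1', '', 'tenant-1', ToolInvokeFrom.WORKFLOW)
assert runtime['tool_invoke_from'] is ToolInvokeFrom.WORKFLOW
```
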
@@ -196,44 +214,28 @@ class ToolManager:
raise ValueError(
f"tool parameter {parameter_rule.name} value {parameter_value} not in options {options}")
- # convert tool parameter config to correct type
- try:
- if parameter_rule.type == ToolParameter.ToolParameterType.NUMBER:
- # check if tool parameter is integer
- if isinstance(parameter_value, int):
- parameter_value = parameter_value
- elif isinstance(parameter_value, float):
- parameter_value = parameter_value
- elif isinstance(parameter_value, str):
- if '.' in parameter_value:
- parameter_value = float(parameter_value)
- else:
- parameter_value = int(parameter_value)
- elif parameter_rule.type == ToolParameter.ToolParameterType.BOOLEAN:
- parameter_value = bool(parameter_value)
- elif parameter_rule.type not in [ToolParameter.ToolParameterType.SELECT,
- ToolParameter.ToolParameterType.STRING]:
- parameter_value = str(parameter_value)
- elif parameter_rule.type == ToolParameter.ToolParameterType:
- parameter_value = str(parameter_value)
- except Exception as e:
- raise ValueError(f"tool parameter {parameter_rule.name} value {parameter_value} is not correct type")
-
- return parameter_value
+ return ToolParameterConverter.cast_parameter_by_type(parameter_value, parameter_rule.type)
@classmethod
- def get_agent_tool_runtime(cls, tenant_id: str, app_id: str, agent_tool: AgentToolEntity) -> Tool:
+ def get_agent_tool_runtime(cls, tenant_id: str, app_id: str, agent_tool: AgentToolEntity, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER) -> Tool:
"""
get the agent tool runtime
"""
tool_entity = cls.get_tool_runtime(
- provider_type=agent_tool.provider_type, provider_name=agent_tool.provider_id,
+ provider_type=agent_tool.provider_type,
+ provider_id=agent_tool.provider_id,
tool_name=agent_tool.tool_name,
tenant_id=tenant_id,
+ invoke_from=invoke_from,
+ tool_invoke_from=ToolInvokeFrom.AGENT
)
runtime_parameters = {}
parameters = tool_entity.get_all_runtime_parameters()
for parameter in parameters:
+ # check file types
+ if parameter.type == ToolParameter.ToolParameterType.FILE:
+ raise ValueError(f"file type parameter {parameter.name} not supported in agent")
+
if parameter.form == ToolParameter.ToolParameterForm.FORM:
# save tool parameter to tool entity memory
value = cls._init_runtime_parameter(parameter, agent_tool.tool_parameters)
@@ -253,15 +255,17 @@ class ToolManager:
return tool_entity
@classmethod
- def get_workflow_tool_runtime(cls, tenant_id: str, app_id: str, node_id: str, workflow_tool: ToolEntity):
+ def get_workflow_tool_runtime(cls, tenant_id: str, app_id: str, node_id: str, workflow_tool: ToolEntity, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER) -> Tool:
"""
get the workflow tool runtime
"""
tool_entity = cls.get_tool_runtime(
provider_type=workflow_tool.provider_type,
- provider_name=workflow_tool.provider_id,
+ provider_id=workflow_tool.provider_id,
tool_name=workflow_tool.tool_name,
tenant_id=tenant_id,
+ invoke_from=invoke_from,
+ tool_invoke_from=ToolInvokeFrom.WORKFLOW
)
runtime_parameters = {}
parameters = tool_entity.get_all_runtime_parameters()
@@ -367,49 +371,6 @@ class ToolManager:
cls._builtin_providers = {}
cls._builtin_providers_loaded = False
- # @classmethod
- # def list_model_providers(cls, tenant_id: str = None) -> list[ModelToolProviderController]:
- # """
- # list all the model providers
-
- # :return: the list of the model providers
- # """
- # tenant_id = tenant_id or 'ffffffff-ffff-ffff-ffff-ffffffffffff'
- # # get configurations
- # model_configurations = ModelToolConfigurationManager.get_all_configuration()
- # # get all providers
- # provider_manager = ProviderManager()
- # configurations = provider_manager.get_configurations(tenant_id).values()
- # # get model providers
- # model_providers: list[ModelToolProviderController] = []
- # for configuration in configurations:
- # # all the model tool should be configurated
- # if configuration.provider.provider not in model_configurations:
- # continue
- # if not ModelToolProviderController.is_configuration_valid(configuration):
- # continue
- # model_providers.append(ModelToolProviderController.from_db(configuration))
-
- # return model_providers
-
- @classmethod
- def get_model_provider(cls, tenant_id: str, provider_name: str) -> ModelToolProviderController:
- """
- get the model provider
-
- :param provider_name: the name of the provider
-
- :return: the provider
- """
- # get configurations
- provider_manager = ProviderManager()
- configurations = provider_manager.get_configurations(tenant_id)
- configuration = configurations.get(provider_name)
- if configuration is None:
- raise ToolProviderNotFoundError(f'model provider {provider_name} not found')
-
- return ModelToolProviderController.from_db(configuration)
-
@classmethod
def get_tool_label(cls, tool_name: str) -> Union[I18nObject, None]:
"""
@@ -419,7 +380,6 @@ class ToolManager:
:return: the label of the tool
"""
- cls._builtin_tools_labels
if len(cls._builtin_tools_labels) == 0:
# init the builtin providers
cls.load_builtin_providers_cache()
@@ -430,60 +390,91 @@ class ToolManager:
return cls._builtin_tools_labels[tool_name]
@classmethod
- def user_list_providers(cls, user_id: str, tenant_id: str) -> list[UserToolProvider]:
+ def user_list_providers(cls, user_id: str, tenant_id: str, typ: UserToolProviderTypeLiteral) -> list[UserToolProvider]:
result_providers: dict[str, UserToolProvider] = {}
- # get builtin providers
- builtin_providers = cls.list_builtin_providers()
-
- # get db builtin providers
- db_builtin_providers: list[BuiltinToolProvider] = db.session.query(BuiltinToolProvider). \
- filter(BuiltinToolProvider.tenant_id == tenant_id).all()
+ filters = []
+ if not typ:
+ filters.extend(['builtin', 'api', 'workflow'])
+ else:
+ filters.append(typ)
- find_db_builtin_provider = lambda provider: next(
- (x for x in db_builtin_providers if x.provider == provider),
- None
- )
+ if 'builtin' in filters:
- # append builtin providers
- for provider in builtin_providers:
- user_provider = ToolTransformService.builtin_provider_to_user_provider(
- provider_controller=provider,
- db_provider=find_db_builtin_provider(provider.identity.name),
- decrypt_credentials=False
+ # get builtin providers
+ builtin_providers = cls.list_builtin_providers()
+
+ # get db builtin providers
+ db_builtin_providers: list[BuiltinToolProvider] = db.session.query(BuiltinToolProvider). \
+ filter(BuiltinToolProvider.tenant_id == tenant_id).all()
+
+ find_db_builtin_provider = lambda provider: next(
+ (x for x in db_builtin_providers if x.provider == provider),
+ None
)
- result_providers[provider.identity.name] = user_provider
+ # append builtin providers
+ for provider in builtin_providers:
+ user_provider = ToolTransformService.builtin_provider_to_user_provider(
+ provider_controller=provider,
+ db_provider=find_db_builtin_provider(provider.identity.name),
+ decrypt_credentials=False
+ )
- # # get model tool providers
- # model_providers = cls.list_model_providers(tenant_id=tenant_id)
- # # append model providers
- # for provider in model_providers:
- # user_provider = ToolTransformService.model_provider_to_user_provider(
- # db_provider=provider,
- # )
- # result_providers[f'model_provider.{provider.identity.name}'] = user_provider
+ result_providers[provider.identity.name] = user_provider
# get db api providers
- db_api_providers: list[ApiToolProvider] = db.session.query(ApiToolProvider). \
- filter(ApiToolProvider.tenant_id == tenant_id).all()
- for db_api_provider in db_api_providers:
- provider_controller = ToolTransformService.api_provider_to_controller(
- db_provider=db_api_provider,
- )
- user_provider = ToolTransformService.api_provider_to_user_provider(
- provider_controller=provider_controller,
- db_provider=db_api_provider,
- decrypt_credentials=False
- )
- result_providers[db_api_provider.name] = user_provider
+ if 'api' in filters:
+ db_api_providers: list[ApiToolProvider] = db.session.query(ApiToolProvider). \
+ filter(ApiToolProvider.tenant_id == tenant_id).all()
+
+ api_provider_controllers = [{
+ 'provider': provider,
+ 'controller': ToolTransformService.api_provider_to_controller(provider)
+ } for provider in db_api_providers]
+
+ # get labels
+ labels = ToolLabelManager.get_tools_labels([x['controller'] for x in api_provider_controllers])
+
+ for api_provider_controller in api_provider_controllers:
+ user_provider = ToolTransformService.api_provider_to_user_provider(
+ provider_controller=api_provider_controller['controller'],
+ db_provider=api_provider_controller['provider'],
+ decrypt_credentials=False,
+ labels=labels.get(api_provider_controller['controller'].provider_id, [])
+ )
+ result_providers[f'api_provider.{user_provider.name}'] = user_provider
+
+ if 'workflow' in filters:
+ # get workflow providers
+ workflow_providers: list[WorkflowToolProvider] = db.session.query(WorkflowToolProvider). \
+ filter(WorkflowToolProvider.tenant_id == tenant_id).all()
+
+ workflow_provider_controllers = []
+ for provider in workflow_providers:
+ try:
+ workflow_provider_controllers.append(
+ ToolTransformService.workflow_provider_to_controller(db_provider=provider)
+ )
+                except Exception:
+ # app has been deleted
+ pass
+
+ labels = ToolLabelManager.get_tools_labels(workflow_provider_controllers)
+
+ for provider_controller in workflow_provider_controllers:
+ user_provider = ToolTransformService.workflow_provider_to_user_provider(
+ provider_controller=provider_controller,
+ labels=labels.get(provider_controller.provider_id, []),
+ )
+ result_providers[f'workflow_provider.{user_provider.name}'] = user_provider
return BuiltinToolProviderSort.sort(list(result_providers.values()))
@classmethod
def get_api_provider_controller(cls, tenant_id: str, provider_id: str) -> tuple[
- ApiBasedToolProviderController, dict[str, Any]]:
+ ApiToolProviderController, dict[str, Any]]:
"""
get the api provider
@@ -499,7 +490,7 @@ class ToolManager:
if provider is None:
raise ToolProviderNotFoundError(f'api provider {provider_id} not found')
- controller = ApiBasedToolProviderController.from_db(
+ controller = ApiToolProviderController.from_db(
provider,
ApiProviderAuthType.API_KEY if provider.credentials['auth_type'] == 'api_key' else
ApiProviderAuthType.NONE
@@ -530,7 +521,7 @@ class ToolManager:
credentials = {}
# package tool provider controller
- controller = ApiBasedToolProviderController.from_db(
+ controller = ApiToolProviderController.from_db(
provider, ApiProviderAuthType.API_KEY if credentials['auth_type'] == 'api_key' else ApiProviderAuthType.NONE
)
# init tool configuration
@@ -547,6 +538,9 @@ class ToolManager:
"content": "\ud83d\ude01"
}
+ # add tool labels
+ labels = ToolLabelManager.get_tool_labels(controller)
+
return jsonable_encoder({
'schema_type': provider.schema_type,
'schema': provider.schema,
@@ -554,7 +548,9 @@ class ToolManager:
'icon': icon,
'description': provider.description,
'credentials': masked_credentials,
- 'privacy_policy': provider.privacy_policy
+ 'privacy_policy': provider.privacy_policy,
+ 'custom_disclaimer': provider.custom_disclaimer,
+ 'labels': labels,
})
@classmethod
@@ -586,6 +582,15 @@ class ToolManager:
"background": "#252525",
"content": "\ud83d\ude01"
}
+ elif provider_type == 'workflow':
+ provider: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter(
+ WorkflowToolProvider.tenant_id == tenant_id,
+ WorkflowToolProvider.id == provider_id
+ ).first()
+ if provider is None:
+ raise ToolProviderNotFoundError(f'workflow provider {provider_id} not found')
+
+ return json.loads(provider.icon)
else:
raise ValueError(f"provider type {provider_type} not found")
diff --git a/api/core/tools/utils/__init__.py b/api/core/tools/utils/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/tools/utils/configuration.py b/api/core/tools/utils/configuration.py
index 917f8411c4..b213879e96 100644
--- a/api/core/tools/utils/configuration.py
+++ b/api/core/tools/utils/configuration.py
@@ -1,16 +1,12 @@
-import os
from copy import deepcopy
-from typing import Any, Union
+from typing import Any
from pydantic import BaseModel
-from yaml import FullLoader, load
from core.helper import encrypter
from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType
from core.helper.tool_provider_cache import ToolProviderCredentialsCache, ToolProviderCredentialsCacheType
from core.tools.entities.tool_entities import (
- ModelToolConfiguration,
- ModelToolProviderConfiguration,
ToolParameter,
ToolProviderCredentials,
)
@@ -27,7 +23,7 @@ class ToolConfigurationManager(BaseModel):
deep copy credentials
"""
return deepcopy(credentials)
-
+
def encrypt_tool_credentials(self, credentials: dict[str, str]) -> dict[str, str]:
"""
encrypt tool credentials with tenant id
@@ -43,9 +39,9 @@ class ToolConfigurationManager(BaseModel):
if field_name in credentials:
encrypted = encrypter.encrypt_token(self.tenant_id, credentials[field_name])
credentials[field_name] = encrypted
-
+
return credentials
-
+
def mask_tool_credentials(self, credentials: dict[str, Any]) -> dict[str, Any]:
"""
mask tool credentials
@@ -62,7 +58,7 @@ class ToolConfigurationManager(BaseModel):
if len(credentials[field_name]) > 6:
credentials[field_name] = \
credentials[field_name][:2] + \
- '*' * (len(credentials[field_name]) - 4) +\
+ '*' * (len(credentials[field_name]) - 4) + \
credentials[field_name][-2:]
else:
credentials[field_name] = '*' * len(credentials[field_name])
@@ -77,7 +73,7 @@ class ToolConfigurationManager(BaseModel):
"""
cache = ToolProviderCredentialsCache(
tenant_id=self.tenant_id,
- identity_id=f'{self.provider_controller.app_type.value}.{self.provider_controller.identity.name}',
+ identity_id=f'{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}',
cache_type=ToolProviderCredentialsCacheType.PROVIDER
)
cached_credentials = cache.get()
@@ -96,11 +92,11 @@ class ToolConfigurationManager(BaseModel):
cache.set(credentials)
return credentials
-
+
def delete_tool_credentials_cache(self):
cache = ToolProviderCredentialsCache(
tenant_id=self.tenant_id,
- identity_id=f'{self.provider_controller.app_type.value}.{self.provider_controller.identity.name}',
+ identity_id=f'{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}',
cache_type=ToolProviderCredentialsCacheType.PROVIDER
)
cache.delete()
@@ -120,7 +116,7 @@ class ToolParameterConfigurationManager(BaseModel):
deep copy parameters
"""
return deepcopy(parameters)
-
+
def _merge_parameters(self) -> list[ToolParameter]:
"""
merge parameters
@@ -143,7 +139,7 @@ class ToolParameterConfigurationManager(BaseModel):
current_parameters.append(runtime_parameter)
return current_parameters
-
+
def mask_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
"""
mask tool parameters
@@ -161,13 +157,13 @@ class ToolParameterConfigurationManager(BaseModel):
if len(parameters[parameter.name]) > 6:
parameters[parameter.name] = \
parameters[parameter.name][:2] + \
- '*' * (len(parameters[parameter.name]) - 4) +\
+ '*' * (len(parameters[parameter.name]) - 4) + \
parameters[parameter.name][-2:]
else:
parameters[parameter.name] = '*' * len(parameters[parameter.name])
return parameters
-
+
def encrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
"""
encrypt tool parameters with tenant id
@@ -184,9 +180,9 @@ class ToolParameterConfigurationManager(BaseModel):
if parameter.name in parameters:
encrypted = encrypter.encrypt_token(self.tenant_id, parameters[parameter.name])
parameters[parameter.name] = encrypted
-
+
return parameters
-
+
def decrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
"""
decrypt tool parameters with tenant id
@@ -194,7 +190,7 @@ class ToolParameterConfigurationManager(BaseModel):
return a deep copy of parameters with decrypted values
"""
cache = ToolParameterCache(
- tenant_id=self.tenant_id,
+ tenant_id=self.tenant_id,
provider=f'{self.provider_type}.{self.provider_name}',
tool_name=self.tool_runtime.identity.name,
cache_type=ToolParameterCacheType.PARAMETER,
@@ -216,80 +212,18 @@ class ToolParameterConfigurationManager(BaseModel):
parameters[parameter.name] = encrypter.decrypt_token(self.tenant_id, parameters[parameter.name])
except:
pass
-
+
if has_secret_input:
cache.set(parameters)
return parameters
-
+
def delete_tool_parameters_cache(self):
cache = ToolParameterCache(
- tenant_id=self.tenant_id,
+ tenant_id=self.tenant_id,
provider=f'{self.provider_type}.{self.provider_name}',
tool_name=self.tool_runtime.identity.name,
cache_type=ToolParameterCacheType.PARAMETER,
identity_id=self.identity_id
)
cache.delete()
-
-class ModelToolConfigurationManager:
- """
- Model as tool configuration
- """
- _configurations: dict[str, ModelToolProviderConfiguration] = {}
- _model_configurations: dict[str, ModelToolConfiguration] = {}
- _inited = False
-
- @classmethod
- def _init_configuration(cls):
- """
- init configuration
- """
-
- absolute_path = os.path.abspath(os.path.dirname(__file__))
- model_tools_path = os.path.join(absolute_path, '..', 'model_tools')
-
- # get all .yaml file
- files = [f for f in os.listdir(model_tools_path) if f.endswith('.yaml')]
-
- for file in files:
- provider = file.split('.')[0]
- with open(os.path.join(model_tools_path, file), encoding='utf-8') as f:
- configurations = ModelToolProviderConfiguration(**load(f, Loader=FullLoader))
- models = configurations.models or []
- for model in models:
- model_key = f'{provider}.{model.model}'
- cls._model_configurations[model_key] = model
-
- cls._configurations[provider] = configurations
- cls._inited = True
-
- @classmethod
- def get_configuration(cls, provider: str) -> Union[ModelToolProviderConfiguration, None]:
- """
- get configuration by provider
- """
- if not cls._inited:
- cls._init_configuration()
- return cls._configurations.get(provider, None)
-
- @classmethod
- def get_all_configuration(cls) -> dict[str, ModelToolProviderConfiguration]:
- """
- get all configurations
- """
- if not cls._inited:
- cls._init_configuration()
- return cls._configurations
-
- @classmethod
- def get_model_configuration(cls, provider: str, model: str) -> Union[ModelToolConfiguration, None]:
- """
- get model configuration
- """
- key = f'{provider}.{model}'
-
- if not cls._inited:
- cls._init_configuration()
-
- return cls._model_configurations.get(key, None)
\ No newline at end of file
diff --git a/api/core/tools/utils/message_transformer.py b/api/core/tools/utils/message_transformer.py
index 3f456b4eb6..ef9e5b67ae 100644
--- a/api/core/tools/utils/message_transformer.py
+++ b/api/core/tools/utils/message_transformer.py
@@ -1,14 +1,15 @@
import logging
from mimetypes import guess_extension
+from core.file.file_obj import FileTransferMethod, FileType, FileVar
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool_file_manager import ToolFileManager
logger = logging.getLogger(__name__)
class ToolFileMessageTransformer:
- @staticmethod
- def transform_tool_invoke_messages(messages: list[ToolInvokeMessage],
+ @classmethod
+ def transform_tool_invoke_messages(cls, messages: list[ToolInvokeMessage],
user_id: str,
tenant_id: str,
conversation_id: str) -> list[ToolInvokeMessage]:
@@ -62,7 +63,7 @@ class ToolFileMessageTransformer:
mimetype=mimetype
)
- url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".bin"}'
+ url = cls.get_tool_file_url(file.id, guess_extension(file.mimetype))
# check if file is image
if 'image' in mimetype:
@@ -79,7 +80,30 @@ class ToolFileMessageTransformer:
save_as=message.save_as,
meta=message.meta.copy() if message.meta is not None else {},
))
+ elif message.type == ToolInvokeMessage.MessageType.FILE_VAR:
+ file_var: FileVar = message.meta.get('file_var')
+ if file_var:
+ if file_var.transfer_method == FileTransferMethod.TOOL_FILE:
+ url = cls.get_tool_file_url(file_var.related_id, file_var.extension)
+ if file_var.type == FileType.IMAGE:
+ result.append(ToolInvokeMessage(
+ type=ToolInvokeMessage.MessageType.IMAGE_LINK,
+ message=url,
+ save_as=message.save_as,
+ meta=message.meta.copy() if message.meta is not None else {},
+ ))
+ else:
+ result.append(ToolInvokeMessage(
+ type=ToolInvokeMessage.MessageType.LINK,
+ message=url,
+ save_as=message.save_as,
+ meta=message.meta.copy() if message.meta is not None else {},
+ ))
else:
result.append(message)
- return result
\ No newline at end of file
+ return result
+
+ @classmethod
+ def get_tool_file_url(cls, tool_file_id: str, extension: str) -> str:
+ return f'/files/tools/{tool_file_id}{extension or ".bin"}'
\ No newline at end of file
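
Tool-file URL construction is now centralized in `get_tool_file_url`, with `.bin` as the fallback extension. Its behavior in isolation:

```python
def get_tool_file_url(tool_file_id: str, extension: str | None) -> str:
    # mirrors ToolFileMessageTransformer.get_tool_file_url
    return f'/files/tools/{tool_file_id}{extension or ".bin"}'

assert get_tool_file_url('abc123', '.png') == '/files/tools/abc123.png'
assert get_tool_file_url('abc123', None) == '/files/tools/abc123.bin'
```
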
diff --git a/api/core/tools/model/tool_model_manager.py b/api/core/tools/utils/model_invocation_utils.py
similarity index 88%
rename from api/core/tools/model/tool_model_manager.py
rename to api/core/tools/utils/model_invocation_utils.py
index e97d78d699..9e8ef47823 100644
--- a/api/core/tools/model/tool_model_manager.py
+++ b/api/core/tools/utils/model_invocation_utils.py
@@ -20,12 +20,14 @@ from core.model_runtime.errors.invoke import (
)
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel, ModelPropertyKey
from core.model_runtime.utils.encoders import jsonable_encoder
-from core.tools.model.errors import InvokeModelError
from extensions.ext_database import db
from models.tools import ToolModelInvoke
-class ToolModelManager:
+class InvokeModelError(Exception):
+ pass
+
+class ModelInvocationUtils:
@staticmethod
def get_max_llm_context_tokens(
tenant_id: str,
@@ -71,10 +73,8 @@ class ToolModelManager:
if not model_instance:
raise InvokeModelError('Model not found')
- llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
-
# get tokens
- tokens = llm_model.get_num_tokens(model_instance.model, model_instance.credentials, prompt_messages)
+ tokens = model_instance.get_llm_num_tokens(prompt_messages)
return tokens
@@ -106,13 +106,8 @@ class ToolModelManager:
tenant_id=tenant_id, model_type=ModelType.LLM,
)
- llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
-
- # get model credentials
- model_credentials = model_instance.credentials
-
# get prompt tokens
- prompt_tokens = llm_model.get_num_tokens(model_instance.model, model_credentials, prompt_messages)
+ prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages)
model_parameters = {
'temperature': 0.8,
@@ -142,9 +137,7 @@ class ToolModelManager:
db.session.commit()
try:
- response: LLMResult = llm_model.invoke(
- model=model_instance.model,
- credentials=model_credentials,
+ response: LLMResult = model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=[], stop=[], stream=False, user=user_id, callbacks=[]
@@ -174,4 +167,4 @@ class ToolModelManager:
db.session.commit()
- return response
\ No newline at end of file
+ return response
diff --git a/api/core/tools/utils/parser.py b/api/core/tools/utils/parser.py
index a96d8a6b7c..40ae6c66d5 100644
--- a/api/core/tools/utils/parser.py
+++ b/api/core/tools/utils/parser.py
@@ -9,14 +9,14 @@ from requests import get
from yaml import YAMLError, safe_load
from core.tools.entities.common_entities import I18nObject
-from core.tools.entities.tool_bundle import ApiBasedToolBundle
+from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import ApiProviderSchemaType, ToolParameter
from core.tools.errors import ToolApiSchemaError, ToolNotSupportedError, ToolProviderNotFoundError
class ApiBasedToolSchemaParser:
@staticmethod
- def parse_openapi_to_tool_bundle(openapi: dict, extra_info: dict = None, warning: dict = None) -> list[ApiBasedToolBundle]:
+ def parse_openapi_to_tool_bundle(openapi: dict, extra_info: dict = None, warning: dict = None) -> list[ApiToolBundle]:
warning = warning if warning is not None else {}
extra_info = extra_info if extra_info is not None else {}
@@ -145,7 +145,7 @@ class ApiBasedToolSchemaParser:
interface['operation']['operationId'] = f'{path}_{interface["method"]}'
- bundles.append(ApiBasedToolBundle(
+ bundles.append(ApiToolBundle(
server_url=server_url + interface['path'],
method=interface['method'],
summary=interface['operation']['description'] if 'description' in interface['operation'] else
@@ -176,7 +176,7 @@ class ApiBasedToolSchemaParser:
return ToolParameter.ToolParameterType.STRING
@staticmethod
- def parse_openapi_yaml_to_tool_bundle(yaml: str, extra_info: dict = None, warning: dict = None) -> list[ApiBasedToolBundle]:
+ def parse_openapi_yaml_to_tool_bundle(yaml: str, extra_info: dict = None, warning: dict = None) -> list[ApiToolBundle]:
"""
parse openapi yaml to tool bundle
@@ -258,7 +258,7 @@ class ApiBasedToolSchemaParser:
return openapi
@staticmethod
- def parse_openai_plugin_json_to_tool_bundle(json: str, extra_info: dict = None, warning: dict = None) -> list[ApiBasedToolBundle]:
+ def parse_openai_plugin_json_to_tool_bundle(json: str, extra_info: dict = None, warning: dict = None) -> list[ApiToolBundle]:
"""
parse openapi plugin yaml to tool bundle
@@ -290,7 +290,7 @@ class ApiBasedToolSchemaParser:
return ApiBasedToolSchemaParser.parse_openapi_yaml_to_tool_bundle(response.text, extra_info=extra_info, warning=warning)
@staticmethod
- def auto_parse_to_tool_bundle(content: str, extra_info: dict = None, warning: dict = None) -> tuple[list[ApiBasedToolBundle], str]:
+ def auto_parse_to_tool_bundle(content: str, extra_info: dict = None, warning: dict = None) -> tuple[list[ApiToolBundle], str]:
"""
auto parse to tool bundle
diff --git a/api/core/tools/utils/tool_parameter_converter.py b/api/core/tools/utils/tool_parameter_converter.py
new file mode 100644
index 0000000000..55535be930
--- /dev/null
+++ b/api/core/tools/utils/tool_parameter_converter.py
@@ -0,0 +1,66 @@
+from typing import Any
+
+from core.tools.entities.tool_entities import ToolParameter
+
+
+class ToolParameterConverter:
+ @staticmethod
+ def get_parameter_type(parameter_type: str | ToolParameter.ToolParameterType) -> str:
+ match parameter_type:
+ case ToolParameter.ToolParameterType.STRING \
+ | ToolParameter.ToolParameterType.SECRET_INPUT \
+ | ToolParameter.ToolParameterType.SELECT:
+ return 'string'
+
+ case ToolParameter.ToolParameterType.BOOLEAN:
+ return 'boolean'
+
+ case ToolParameter.ToolParameterType.NUMBER:
+ return 'number'
+
+ case _:
+ raise ValueError(f"Unsupported parameter type {parameter_type}")
+
+ @staticmethod
+ def cast_parameter_by_type(value: Any, parameter_type: str) -> Any:
+ # convert tool parameter config to correct type
+ try:
+ match parameter_type:
+ case ToolParameter.ToolParameterType.STRING \
+ | ToolParameter.ToolParameterType.SECRET_INPUT \
+ | ToolParameter.ToolParameterType.SELECT:
+ if value is None:
+ return ''
+ else:
+ return value if isinstance(value, str) else str(value)
+
+ case ToolParameter.ToolParameterType.BOOLEAN:
+ if value is None:
+ return False
+ elif isinstance(value, str):
+ # Allowed YAML boolean value strings: https://yaml.org/type/bool.html
+ # and also '0' for False and '1' for True
+ match value.lower():
+ case 'true' | 'yes' | 'y' | '1':
+ return True
+ case 'false' | 'no' | 'n' | '0':
+ return False
+ case _:
+ return bool(value)
+ else:
+ return value if isinstance(value, bool) else bool(value)
+
+ case ToolParameter.ToolParameterType.NUMBER:
+                    if isinstance(value, int | float):
+ return value
+ elif isinstance(value, str):
+ if '.' in value:
+ return float(value)
+ else:
+ return int(value)
+
+ case _:
+ return str(value)
+
+ except Exception:
+            raise ValueError(f"The tool parameter value {value} could not be cast to type {parameter_type}.")
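
The boolean branch is the subtle one: it accepts the YAML boolean strings plus `'0'`/`'1'` before falling back to Python truthiness. A self-contained restatement of just that branch, for illustration:

```python
def cast_bool(value) -> bool:
    # mirrors the BOOLEAN case of ToolParameterConverter.cast_parameter_by_type
    if value is None:
        return False
    if isinstance(value, str):
        match value.lower():
            case 'true' | 'yes' | 'y' | '1':
                return True
            case 'false' | 'no' | 'n' | '0':
                return False
            case _:
                return bool(value)
    return bool(value)

assert cast_bool('YES') is True
assert cast_bool('0') is False
assert cast_bool(None) is False
```
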
diff --git a/api/core/tools/utils/web_reader_tool.py b/api/core/tools/utils/web_reader_tool.py
index 4c6fbb2780..4c69c6eddc 100644
--- a/api/core/tools/utils/web_reader_tool.py
+++ b/api/core/tools/utils/web_reader_tool.py
@@ -1,5 +1,6 @@
import hashlib
import json
+import mimetypes
import os
import re
import site
@@ -7,6 +8,7 @@ import subprocess
import tempfile
import unicodedata
from contextlib import contextmanager
+from urllib.parse import unquote
import requests
from bs4 import BeautifulSoup, CData, Comment, NavigableString
@@ -39,22 +41,34 @@ def get_url(url: str, user_agent: str = None) -> str:
}
if user_agent:
headers["User-Agent"] = user_agent
-
- supported_content_types = extract_processor.SUPPORT_URL_CONTENT_TYPES + ["text/html"]
- response = requests.get(url, headers=headers, allow_redirects=True, timeout=(5, 10))
+ main_content_type = None
+ supported_content_types = extract_processor.SUPPORT_URL_CONTENT_TYPES + ["text/html"]
+ response = requests.head(url, headers=headers, allow_redirects=True, timeout=(5, 10))
if response.status_code != 200:
return "URL returned status code {}.".format(response.status_code)
# check content-type
- main_content_type = response.headers.get('Content-Type').split(';')[0].strip()
+ content_type = response.headers.get('Content-Type')
+ if content_type:
+        main_content_type = content_type.split(';')[0].strip()
+ else:
+ content_disposition = response.headers.get('Content-Disposition', '')
+ filename_match = re.search(r'filename="([^"]+)"', content_disposition)
+ if filename_match:
+ filename = unquote(filename_match.group(1))
+ extension = re.search(r'\.(\w+)$', filename)
+ if extension:
+ main_content_type = mimetypes.guess_type(filename)[0]
+
if main_content_type not in supported_content_types:
return "Unsupported content-type [{}] of URL.".format(main_content_type)
if main_content_type in extract_processor.SUPPORT_URL_CONTENT_TYPES:
return ExtractProcessor.load_from_url(url, return_text=True)
+ response = requests.get(url, headers=headers, allow_redirects=True, timeout=(120, 300))
a = extract_using_readabilipy(response.text)
if not a['plain_text'] or not a['plain_text'].strip():
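
The reader now probes with a cheap HEAD request first; when the server omits Content-Type it guesses the type from the Content-Disposition filename, and only then downloads the full body with a much longer timeout. A standalone sketch of the filename fallback:

```python
import mimetypes
import re
from urllib.parse import unquote

def mimetype_from_content_disposition(header: str) -> str | None:
    # e.g. 'attachment; filename="report%20final.pdf"' -> 'application/pdf'
    match = re.search(r'filename="([^"]+)"', header)
    if not match:
        return None
    filename = unquote(match.group(1))
    return mimetypes.guess_type(filename)[0]

assert mimetype_from_content_disposition(
    'attachment; filename="report%20final.pdf"') == 'application/pdf'
assert mimetype_from_content_disposition('inline') is None
```
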
@@ -118,17 +132,17 @@ def extract_using_readabilipy(html):
}
# Populate article fields from readability fields where present
if input_json:
- if "title" in input_json and input_json["title"]:
+ if input_json.get("title"):
article_json["title"] = input_json["title"]
- if "byline" in input_json and input_json["byline"]:
+ if input_json.get("byline"):
article_json["byline"] = input_json["byline"]
- if "date" in input_json and input_json["date"]:
+ if input_json.get("date"):
article_json["date"] = input_json["date"]
- if "content" in input_json and input_json["content"]:
+ if input_json.get("content"):
article_json["content"] = input_json["content"]
article_json["plain_content"] = plain_content(article_json["content"], False, False)
article_json["plain_text"] = extract_text_blocks_as_plain_text(article_json["plain_content"])
- if "textContent" in input_json and input_json["textContent"]:
+ if input_json.get("textContent"):
article_json["plain_text"] = input_json["textContent"]
article_json["plain_text"] = re.sub(r'\n\s*\n', '\n', article_json["plain_text"])
diff --git a/api/core/tools/utils/workflow_configuration_sync.py b/api/core/tools/utils/workflow_configuration_sync.py
new file mode 100644
index 0000000000..ff5505bbbf
--- /dev/null
+++ b/api/core/tools/utils/workflow_configuration_sync.py
@@ -0,0 +1,48 @@
+from core.app.app_config.entities import VariableEntity
+from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
+
+
+class WorkflowToolConfigurationUtils:
+ @classmethod
+ def check_parameter_configurations(cls, configurations: list[dict]):
+ """
+ check parameter configurations
+ """
+ for configuration in configurations:
+            try:
+                WorkflowToolParameterConfiguration(**configuration)
+            except Exception:
+                raise ValueError('invalid parameter configuration')
+
+ @classmethod
+ def get_workflow_graph_variables(cls, graph: dict) -> list[VariableEntity]:
+ """
+ get workflow graph variables
+ """
+ nodes = graph.get('nodes', [])
+ start_node = next(filter(lambda x: x.get('data', {}).get('type') == 'start', nodes), None)
+
+ if not start_node:
+ return []
+
+ return [
+ VariableEntity(**variable) for variable in start_node.get('data', {}).get('variables', [])
+ ]
+
+ @classmethod
+ def check_is_synced(cls,
+ variables: list[VariableEntity],
+                        tool_configurations: list[WorkflowToolParameterConfiguration]) -> bool:
+        """
+        check whether the workflow variables and the tool configurations are synced
+
+        raise ValueError if not synced, return True otherwise
+ """
+ variable_names = [variable.variable for variable in variables]
+
+ if len(tool_configurations) != len(variables):
+ raise ValueError('parameter configuration mismatch, please republish the tool to update')
+
+ for parameter in tool_configurations:
+ if parameter.name not in variable_names:
+ raise ValueError('parameter configuration mismatch, please republish the tool to update')
+
+ return True
\ No newline at end of file
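
`check_is_synced` reduces to two comparisons: the configured parameters and the workflow's start-node variables must have the same count, and every configured name must still exist among the variables. The same logic over plain name lists:

```python
def check_is_synced(variable_names: list[str], configured_names: list[str]) -> bool:
    # mirrors WorkflowToolConfigurationUtils.check_is_synced with plain strings
    if len(configured_names) != len(variable_names):
        raise ValueError('parameter configuration mismatch, please republish the tool to update')
    for name in configured_names:
        if name not in variable_names:
            raise ValueError('parameter configuration mismatch, please republish the tool to update')
    return True

assert check_is_synced(['query', 'topic'], ['topic', 'query'])
```
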
diff --git a/api/core/tools/utils/yaml_utils.py b/api/core/tools/utils/yaml_utils.py
new file mode 100644
index 0000000000..22e4d3d128
--- /dev/null
+++ b/api/core/tools/utils/yaml_utils.py
@@ -0,0 +1,34 @@
+import logging
+import os
+
+import yaml
+from yaml import YAMLError
+
+
+def load_yaml_file(file_path: str, ignore_error: bool = False) -> dict:
+ """
+    Safely load a YAML file into a dict
+    :param file_path: the path of the YAML file
+    :param ignore_error:
+        if True, return an empty dict and log the error at warning level when loading fails
+        if False, re-raise the error when loading fails
+        (note: a missing file always returns an empty dict, logged at debug level)
+    :return: a dict of the YAML content
+ """
+ try:
+ if not file_path or not os.path.exists(file_path):
+ raise FileNotFoundError(f'Failed to load YAML file {file_path}: file not found')
+
+ with open(file_path, encoding='utf-8') as file:
+ try:
+ return yaml.safe_load(file)
+ except Exception as e:
+ raise YAMLError(f'Failed to load YAML file {file_path}: {e}')
+ except FileNotFoundError as e:
+ logging.debug(f'Failed to load YAML file {file_path}: {e}')
+ return {}
+ except Exception as e:
+ if ignore_error:
+ logging.warning(f'Failed to load YAML file {file_path}: {e}')
+ return {}
+ else:
+ raise e
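
Expected usage, assuming the module is importable as `core.tools.utils.yaml_utils` (the path this new file lives at):

```python
import tempfile

from core.tools.utils.yaml_utils import load_yaml_file

with tempfile.NamedTemporaryFile('w', suffix='.yaml', delete=False) as f:
    f.write('identity:\n  name: demo\n')
    path = f.name

assert load_yaml_file(path)['identity']['name'] == 'demo'
assert load_yaml_file('/nonexistent.yaml') == {}  # missing file: empty dict, no raise
```
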
diff --git a/api/core/workflow/callbacks/base_workflow_callback.py b/api/core/workflow/callbacks/base_workflow_callback.py
index dd5a30f611..3b0d51d868 100644
--- a/api/core/workflow/callbacks/base_workflow_callback.py
+++ b/api/core/workflow/callbacks/base_workflow_callback.py
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
-from typing import Optional
+from typing import Any, Optional
from core.app.entities.queue_entities import AppQueueEvent
from core.workflow.entities.base_node_data_entities import BaseNodeData
@@ -71,6 +71,42 @@ class BaseWorkflowCallback(ABC):
Publish text chunk
"""
raise NotImplementedError
+
+ @abstractmethod
+ def on_workflow_iteration_started(self,
+ node_id: str,
+ node_type: NodeType,
+ node_run_index: int = 1,
+ node_data: Optional[BaseNodeData] = None,
+ inputs: dict = None,
+ predecessor_node_id: Optional[str] = None,
+ metadata: Optional[dict] = None) -> None:
+ """
+ Publish iteration started
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def on_workflow_iteration_next(self, node_id: str,
+ node_type: NodeType,
+ index: int,
+ node_run_index: int,
+ output: Optional[Any],
+ ) -> None:
+ """
+ Publish iteration next
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def on_workflow_iteration_completed(self, node_id: str,
+ node_type: NodeType,
+ node_run_index: int,
+ outputs: dict) -> None:
+ """
+ Publish iteration completed
+ """
+ raise NotImplementedError
@abstractmethod
def on_event(self, event: AppQueueEvent) -> None:
diff --git a/api/core/workflow/entities/base_node_data_entities.py b/api/core/workflow/entities/base_node_data_entities.py
index fc6ee231ff..6bf0c11c7d 100644
--- a/api/core/workflow/entities/base_node_data_entities.py
+++ b/api/core/workflow/entities/base_node_data_entities.py
@@ -7,3 +7,16 @@ from pydantic import BaseModel
class BaseNodeData(ABC, BaseModel):
title: str
desc: Optional[str] = None
+
+class BaseIterationNodeData(BaseNodeData):
+ start_node_id: str
+
+class BaseIterationState(BaseModel):
+ iteration_node_id: str
+ index: int
+ inputs: dict
+
+ class MetaData(BaseModel):
+ pass
+
+ metadata: MetaData
\ No newline at end of file
diff --git a/api/core/workflow/entities/node_entities.py b/api/core/workflow/entities/node_entities.py
index 7eb9488792..ae86463407 100644
--- a/api/core/workflow/entities/node_entities.py
+++ b/api/core/workflow/entities/node_entities.py
@@ -21,7 +21,11 @@ class NodeType(Enum):
QUESTION_CLASSIFIER = 'question-classifier'
HTTP_REQUEST = 'http-request'
TOOL = 'tool'
+ VARIABLE_AGGREGATOR = 'variable-aggregator'
VARIABLE_ASSIGNER = 'variable-assigner'
+ LOOP = 'loop'
+ ITERATION = 'iteration'
+ PARAMETER_EXTRACTOR = 'parameter-extractor'
@classmethod
def value_of(cls, value: str) -> 'NodeType':
@@ -68,6 +72,8 @@ class NodeRunMetadataKey(Enum):
TOTAL_PRICE = 'total_price'
CURRENCY = 'currency'
TOOL_INFO = 'tool_info'
+ ITERATION_ID = 'iteration_id'
+ ITERATION_INDEX = 'iteration_index'
class NodeRunResult(BaseModel):
diff --git a/api/core/workflow/entities/variable_pool.py b/api/core/workflow/entities/variable_pool.py
index 690bdddaf6..c04770616c 100644
--- a/api/core/workflow/entities/variable_pool.py
+++ b/api/core/workflow/entities/variable_pool.py
@@ -90,3 +90,12 @@ class VariablePool:
raise ValueError(f'Invalid value type: {target_value_type.value}')
return value
+
+ def clear_node_variables(self, node_id: str) -> None:
+ """
+ Clear node variables
+ :param node_id: node id
+ :return:
+ """
+ if node_id in self.variables_mapping:
+ self.variables_mapping.pop(node_id)
\ No newline at end of file
diff --git a/api/core/workflow/entities/workflow_entities.py b/api/core/workflow/entities/workflow_entities.py
index e1c5eb6752..9b35b8df8a 100644
--- a/api/core/workflow/entities/workflow_entities.py
+++ b/api/core/workflow/entities/workflow_entities.py
@@ -1,5 +1,9 @@
from typing import Optional
+from pydantic import BaseModel
+
+from core.app.entities.app_invoke_entities import InvokeFrom
+from core.workflow.entities.base_node_data_entities import BaseIterationState
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode, UserFrom
@@ -22,6 +26,9 @@ class WorkflowRunState:
workflow_type: WorkflowType
user_id: str
user_from: UserFrom
+ invoke_from: InvokeFrom
+
+ workflow_call_depth: int
start_at: float
variable_pool: VariablePool
@@ -30,20 +37,37 @@ class WorkflowRunState:
workflow_nodes_and_results: list[WorkflowNodeAndResult]
+ class NodeRun(BaseModel):
+ node_id: str
+ iteration_node_id: str
+
+ workflow_node_runs: list[NodeRun]
+ workflow_node_steps: int
+
+ current_iteration_state: Optional[BaseIterationState]
+
def __init__(self, workflow: Workflow,
start_at: float,
variable_pool: VariablePool,
user_id: str,
- user_from: UserFrom):
+ user_from: UserFrom,
+ invoke_from: InvokeFrom,
+ workflow_call_depth: int):
self.workflow_id = workflow.id
self.tenant_id = workflow.tenant_id
self.app_id = workflow.app_id
self.workflow_type = WorkflowType.value_of(workflow.type)
self.user_id = user_id
self.user_from = user_from
+ self.invoke_from = invoke_from
+ self.workflow_call_depth = workflow_call_depth
self.start_at = start_at
self.variable_pool = variable_pool
self.total_tokens = 0
self.workflow_nodes_and_results = []
+
+ self.current_iteration_state = None
+ self.workflow_node_steps = 1
+ self.workflow_node_runs = []
\ No newline at end of file
diff --git a/api/core/workflow/nodes/base_node.py b/api/core/workflow/nodes/base_node.py
index 7cc9c6ee3d..fa7d6424f1 100644
--- a/api/core/workflow/nodes/base_node.py
+++ b/api/core/workflow/nodes/base_node.py
@@ -2,8 +2,9 @@ from abc import ABC, abstractmethod
from enum import Enum
from typing import Optional
+from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
-from core.workflow.entities.base_node_data_entities import BaseNodeData
+from core.workflow.entities.base_node_data_entities import BaseIterationState, BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
@@ -37,6 +38,9 @@ class BaseNode(ABC):
workflow_id: str
user_id: str
user_from: UserFrom
+ invoke_from: InvokeFrom
+
+ workflow_call_depth: int
node_id: str
node_data: BaseNodeData
@@ -49,13 +53,17 @@ class BaseNode(ABC):
workflow_id: str,
user_id: str,
user_from: UserFrom,
+ invoke_from: InvokeFrom,
config: dict,
- callbacks: list[BaseWorkflowCallback] = None) -> None:
+ callbacks: list[BaseWorkflowCallback] = None,
+ workflow_call_depth: int = 0) -> None:
self.tenant_id = tenant_id
self.app_id = app_id
self.workflow_id = workflow_id
self.user_id = user_id
self.user_from = user_from
+ self.invoke_from = invoke_from
+ self.workflow_call_depth = workflow_call_depth
self.node_id = config.get("id")
if not self.node_id:
@@ -140,3 +148,38 @@ class BaseNode(ABC):
:return:
"""
return self._node_type
+
+class BaseIterationNode(BaseNode):
+ @abstractmethod
+ def _run(self, variable_pool: VariablePool) -> BaseIterationState:
+ """
+ Run node
+ :param variable_pool: variable pool
+ :return:
+ """
+ raise NotImplementedError
+
+ def run(self, variable_pool: VariablePool) -> BaseIterationState:
+ """
+ Run node entry
+ :param variable_pool: variable pool
+ :return:
+ """
+ return self._run(variable_pool=variable_pool)
+
+ def get_next_iteration(self, variable_pool: VariablePool, state: BaseIterationState) -> NodeRunResult | str:
+ """
+ Get next iteration start node id based on the graph.
+        :param variable_pool: variable pool
+        :param state: iteration state
+        :return: next iteration start node id, or a NodeRunResult when the iteration is finished
+ """
+ return self._get_next_iteration(variable_pool, state)
+
+ @abstractmethod
+ def _get_next_iteration(self, variable_pool: VariablePool, state: BaseIterationState) -> NodeRunResult | str:
+ """
+ Get next iteration start node id based on the graph.
+        :param variable_pool: variable pool
+        :param state: iteration state
+        :return: next iteration start node id, or a NodeRunResult when the iteration is finished
+ """
+ raise NotImplementedError
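
`BaseIterationNode` splits an iteration into two hooks: `_run` builds the initial state, and `_get_next_iteration` either returns the next start-node id or a final result once the loop is exhausted. A toy illustration of that contract with stand-in types, not Dify's classes:

```python
# Toy illustration of the run / get_next_iteration contract; stand-in types only.
class IterState:
    def __init__(self, items: list):
        self.items = items
        self.index = 0

class ListIterationNode:
    start_node_id = 'iteration-start'

    def run(self, items: list) -> IterState:
        return IterState(items)               # _run: initialize iteration state

    def get_next_iteration(self, state: IterState):
        if state.index < len(state.items):    # more items: hand back the loop entry node
            state.index += 1
            return self.start_node_id
        return {'outputs': state.items}       # done: stand-in for a NodeRunResult

node = ListIterationNode()
state = node.run(['a', 'b'])
assert node.get_next_iteration(state) == 'iteration-start'
assert node.get_next_iteration(state) == 'iteration-start'
assert node.get_next_iteration(state) == {'outputs': ['a', 'b']}
```
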
diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py
index 2c1529f492..610a23e704 100644
--- a/api/core/workflow/nodes/code/code_node.py
+++ b/api/core/workflow/nodes/code/code_node.py
@@ -1,7 +1,10 @@
import os
from typing import Optional, Union, cast
-from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor
+from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor, CodeLanguage
+from core.helper.code_executor.code_node_provider import CodeNodeProvider
+from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
+from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
@@ -17,16 +20,6 @@ MAX_STRING_ARRAY_LENGTH = int(os.environ.get('CODE_MAX_STRING_ARRAY_LENGTH', '30
MAX_OBJECT_ARRAY_LENGTH = int(os.environ.get('CODE_MAX_OBJECT_ARRAY_LENGTH', '30'))
MAX_NUMBER_ARRAY_LENGTH = int(os.environ.get('CODE_MAX_NUMBER_ARRAY_LENGTH', '1000'))
-JAVASCRIPT_DEFAULT_CODE = """function main({arg1, arg2}) {
- return {
- result: arg1 + arg2
- }
-}"""
-
-PYTHON_DEFAULT_CODE = """def main(arg1: int, arg2: int) -> dict:
- return {
- "result": arg1 + arg2,
- }"""
class CodeNode(BaseNode):
_node_data_cls = CodeNodeData
@@ -39,54 +32,15 @@ class CodeNode(BaseNode):
:param filters: filter by node config parameters.
:return:
"""
- if filters and filters.get("code_language") == "javascript":
- return {
- "type": "code",
- "config": {
- "variables": [
- {
- "variable": "arg1",
- "value_selector": []
- },
- {
- "variable": "arg2",
- "value_selector": []
- }
- ],
- "code_language": "javascript",
- "code": JAVASCRIPT_DEFAULT_CODE,
- "outputs": {
- "result": {
- "type": "string",
- "children": None
- }
- }
- }
- }
+ code_language = CodeLanguage.PYTHON3
+ if filters:
+            code_language = filters.get("code_language", CodeLanguage.PYTHON3)
- return {
- "type": "code",
- "config": {
- "variables": [
- {
- "variable": "arg1",
- "value_selector": []
- },
- {
- "variable": "arg2",
- "value_selector": []
- }
- ],
- "code_language": "python3",
- "code": PYTHON_DEFAULT_CODE,
- "outputs": {
- "result": {
- "type": "string",
- "children": None
- }
- }
- }
- }
+ providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider]
+ code_provider: type[CodeNodeProvider] = next(p for p in providers
+ if p.is_accept_language(code_language))
+
+ return code_provider.get_default_config()
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
"""
@@ -115,7 +69,8 @@ class CodeNode(BaseNode):
result = CodeExecutor.execute_workflow_code_template(
language=code_language,
code=code,
- inputs=variables
+ inputs=variables,
+ dependencies=node_data.dependencies
)
# Transform result
diff --git a/api/core/workflow/nodes/code/entities.py b/api/core/workflow/nodes/code/entities.py
index 555bb3918e..03044268ab 100644
--- a/api/core/workflow/nodes/code/entities.py
+++ b/api/core/workflow/nodes/code/entities.py
@@ -2,6 +2,8 @@ from typing import Literal, Optional
from pydantic import BaseModel
+from core.helper.code_executor.code_executor import CodeLanguage
+from core.helper.code_executor.entities import CodeDependency
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.variable_entities import VariableSelector
@@ -15,6 +17,7 @@ class CodeNodeData(BaseNodeData):
children: Optional[dict[str, 'Output']]
variables: list[VariableSelector]
- code_language: Literal['python3', 'javascript']
+ code_language: Literal[CodeLanguage.PYTHON3, CodeLanguage.JAVASCRIPT]
code: str
- outputs: dict[str, Output]
\ No newline at end of file
+ outputs: dict[str, Output]
+ dependencies: Optional[list[CodeDependency]] = None
\ No newline at end of file
diff --git a/api/core/workflow/nodes/http_request/entities.py b/api/core/workflow/nodes/http_request/entities.py
index d88ad999b7..4a81a4176d 100644
--- a/api/core/workflow/nodes/http_request/entities.py
+++ b/api/core/workflow/nodes/http_request/entities.py
@@ -1,9 +1,13 @@
+import os
from typing import Literal, Optional, Union
from pydantic import BaseModel, validator
from core.workflow.entities.base_node_data_entities import BaseNodeData
+MAX_CONNECT_TIMEOUT = int(os.environ.get('HTTP_REQUEST_MAX_CONNECT_TIMEOUT', '300'))
+MAX_READ_TIMEOUT = int(os.environ.get('HTTP_REQUEST_MAX_READ_TIMEOUT', '600'))
+MAX_WRITE_TIMEOUT = int(os.environ.get('HTTP_REQUEST_MAX_WRITE_TIMEOUT', '600'))
class HttpRequestNodeData(BaseNodeData):
"""
@@ -36,9 +40,9 @@ class HttpRequestNodeData(BaseNodeData):
data: Union[None, str]
class Timeout(BaseModel):
- connect: int
- read: int
- write: int
+ connect: Optional[int] = MAX_CONNECT_TIMEOUT
+ read: Optional[int] = MAX_READ_TIMEOUT
+ write: Optional[int] = MAX_WRITE_TIMEOUT
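+        # e.g. with no env overrides, Timeout() defaults to connect=300, read=600, write=600 seconds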
method: Literal['get', 'post', 'put', 'patch', 'delete', 'head']
url: str
@@ -46,4 +50,5 @@ class HttpRequestNodeData(BaseNodeData):
headers: str
params: str
body: Optional[Body]
- timeout: Optional[Timeout]
\ No newline at end of file
+ timeout: Optional[Timeout]
+ mask_authorization_header: Optional[bool] = True
diff --git a/api/core/workflow/nodes/http_request/http_executor.py b/api/core/workflow/nodes/http_request/http_executor.py
index c2beb7a383..10002216f1 100644
--- a/api/core/workflow/nodes/http_request/http_executor.py
+++ b/api/core/workflow/nodes/http_request/http_executor.py
@@ -1,4 +1,5 @@
import json
+import os
from copy import deepcopy
from random import randint
from typing import Any, Optional, Union
@@ -13,10 +14,10 @@ from core.workflow.entities.variable_pool import ValueType, VariablePool
from core.workflow.nodes.http_request.entities import HttpRequestNodeData
from core.workflow.utils.variable_template_parser import VariableTemplateParser
-MAX_BINARY_SIZE = 1024 * 1024 * 10 # 10MB
-READABLE_MAX_BINARY_SIZE = '10MB'
-MAX_TEXT_SIZE = 1024 * 1024 // 10 # 0.1MB
-READABLE_MAX_TEXT_SIZE = '0.1MB'
+MAX_BINARY_SIZE = int(os.environ.get('HTTP_REQUEST_NODE_MAX_BINARY_SIZE', 1024 * 1024 * 10)) # 10MB
+READABLE_MAX_BINARY_SIZE = f'{MAX_BINARY_SIZE / 1024 / 1024:.2f}MB'
+MAX_TEXT_SIZE = int(os.environ.get('HTTP_REQUEST_NODE_MAX_TEXT_SIZE', 1024 * 1024)) # 1MB
+READABLE_MAX_TEXT_SIZE = f'{MAX_TEXT_SIZE / 1024 / 1024:.2f}MB'
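+# e.g. setting HTTP_REQUEST_NODE_MAX_BINARY_SIZE=5242880 caps binary responses at '5.00MB' (illustrative)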
class HttpExecutorResponse:
@@ -24,18 +25,10 @@ class HttpExecutorResponse:
response: Union[httpx.Response, requests.Response]
def __init__(self, response: Union[httpx.Response, requests.Response] = None):
- """
- init
- """
- headers = {}
- if isinstance(response, httpx.Response):
+ self.headers = {}
+ if isinstance(response, httpx.Response | requests.Response):
for k, v in response.headers.items():
- headers[k] = v
- elif isinstance(response, requests.Response):
- for k, v in response.headers.items():
- headers[k] = v
-
- self.headers = headers
+ self.headers[k] = v
self.response = response
@property
@@ -45,21 +38,11 @@ class HttpExecutorResponse:
"""
content_type = self.get_content_type()
file_content_types = ['image', 'audio', 'video']
- for v in file_content_types:
- if v in content_type:
- return True
-
- return False
+
+ return any(v in content_type for v in file_content_types)
def get_content_type(self) -> str:
- """
- get content type
- """
- for key, val in self.headers.items():
- if key.lower() == 'content-type':
- return val
-
- return ''
+        # header keys are lowercased by httpx but keep their original casing with requests,
+        # so match case-insensitively and fall back to an empty string
+        return next((v for k, v in self.headers.items() if k.lower() == 'content-type'), '')
def extract_file(self) -> tuple[str, bytes]:
"""
@@ -67,29 +50,25 @@ class HttpExecutorResponse:
"""
if self.is_file:
return self.get_content_type(), self.body
-
+
return '', b''
-
+
@property
def content(self) -> str:
"""
get content
"""
- if isinstance(self.response, httpx.Response):
- return self.response.text
- elif isinstance(self.response, requests.Response):
+ if isinstance(self.response, httpx.Response | requests.Response):
return self.response.text
else:
raise ValueError(f'Invalid response type {type(self.response)}')
-
+
@property
def body(self) -> bytes:
"""
get body
"""
- if isinstance(self.response, httpx.Response):
- return self.response.content
- elif isinstance(self.response, requests.Response):
+ if isinstance(self.response, httpx.Response | requests.Response):
return self.response.content
else:
raise ValueError(f'Invalid response type {type(self.response)}')
@@ -99,20 +78,18 @@ class HttpExecutorResponse:
"""
get status code
"""
- if isinstance(self.response, httpx.Response):
- return self.response.status_code
- elif isinstance(self.response, requests.Response):
+ if isinstance(self.response, httpx.Response | requests.Response):
return self.response.status_code
else:
raise ValueError(f'Invalid response type {type(self.response)}')
-
+
@property
def size(self) -> int:
"""
get size
"""
return len(self.body)
-
+
@property
def readable_size(self) -> str:
"""
@@ -138,10 +115,8 @@ class HttpExecutor:
variable_selectors: list[VariableSelector]
timeout: HttpRequestNodeData.Timeout
- def __init__(self, node_data: HttpRequestNodeData, timeout: HttpRequestNodeData.Timeout, variable_pool: Optional[VariablePool] = None):
- """
- init
- """
+ def __init__(self, node_data: HttpRequestNodeData, timeout: HttpRequestNodeData.Timeout,
+ variable_pool: Optional[VariablePool] = None):
self.server_url = node_data.url
self.method = node_data.method
self.authorization = node_data.authorization
@@ -155,7 +130,8 @@ class HttpExecutor:
self.variable_selectors = []
self._init_template(node_data, variable_pool)
- def _is_json_body(self, body: HttpRequestNodeData.Body):
+ @staticmethod
+ def _is_json_body(body: HttpRequestNodeData.Body):
"""
check if body is json
"""
@@ -165,55 +141,48 @@ class HttpExecutor:
return True
except:
return False
-
+
return False
+ @staticmethod
+ def _to_dict(convert_item: str, convert_text: str, maxsplit: int = -1):
+ """
+        Convert a string like `aa:bb\ncc:dd` into the dict `{'aa': 'bb', 'cc': 'dd'}`.
+        :param convert_item: A label for the item being converted (params, headers or body), used in error messages.
+ :param convert_text: The string containing key-value pairs separated by '\n'.
+ :param maxsplit: The maximum number of splits allowed for the ':' character in each key-value pair. Default is -1 (no limit).
+ :return: A dictionary containing the key-value pairs from the input string.
+ """
+        kv_pairs = convert_text.split('\n')
+        result = {}
+        for kv in kv_pairs:
+ if not kv.strip():
+ continue
+
+ kv = kv.split(':', maxsplit=maxsplit)
+ if len(kv) >= 3:
+ k, v = kv[0], ":".join(kv[1:])
+ elif len(kv) == 2:
+ k, v = kv
+ elif len(kv) == 1:
+ k, v = kv[0], ''
+ else:
+ raise ValueError(f'Invalid {convert_item} {kv}')
+ result[k.strip()] = v
+ return result
+
def _init_template(self, node_data: HttpRequestNodeData, variable_pool: Optional[VariablePool] = None):
- """
- init template
- """
- variable_selectors = []
# extract all template in url
self.server_url, server_url_variable_selectors = self._format_template(node_data.url, variable_pool)
# extract all template in params
params, params_variable_selectors = self._format_template(node_data.params, variable_pool)
-
- # fill in params
- kv_paris = params.split('\n')
- for kv in kv_paris:
- if not kv.strip():
- continue
-
- kv = kv.split(':')
- if len(kv) == 2:
- k, v = kv
- elif len(kv) == 1:
- k, v = kv[0], ''
- else:
- raise ValueError(f'Invalid params {kv}')
-
- self.params[k.strip()] = v
+ self.params = self._to_dict("params", params)
# extract all template in headers
headers, headers_variable_selectors = self._format_template(node_data.headers, variable_pool)
-
- # fill in headers
- kv_paris = headers.split('\n')
- for kv in kv_paris:
- if not kv.strip():
- continue
-
- kv = kv.split(':')
- if len(kv) == 2:
- k, v = kv
- elif len(kv) == 1:
- k, v = kv[0], ''
- else:
- raise ValueError(f'Invalid headers {kv}')
-
- self.headers[k.strip()] = v.strip()
+ self.headers = self._to_dict("headers", headers)
# extract all template in body
body_data_variable_selectors = []
@@ -231,18 +200,7 @@ class HttpExecutor:
self.headers['Content-Type'] = 'application/x-www-form-urlencoded'
if node_data.body.type in ['form-data', 'x-www-form-urlencoded']:
- body = {}
- kv_paris = body_data.split('\n')
- for kv in kv_paris:
- if not kv.strip():
- continue
- kv = kv.split(':')
- if len(kv) == 2:
- body[kv[0].strip()] = kv[1]
- elif len(kv) == 1:
- body[kv[0].strip()] = ''
- else:
- raise ValueError(f'Invalid body {kv}')
+ body = self._to_dict("body", body_data, 1)
if node_data.body.type == 'form-data':
self.files = {
@@ -261,14 +219,14 @@ class HttpExecutor:
self.variable_selectors = (server_url_variable_selectors + params_variable_selectors
+ headers_variable_selectors + body_data_variable_selectors)
-
+
def _assembling_headers(self) -> dict[str, Any]:
authorization = deepcopy(self.authorization)
headers = deepcopy(self.headers) or {}
if self.authorization.type == 'api-key':
if self.authorization.config.api_key is None:
raise ValueError('api_key is required')
-
+
if not self.authorization.config.header:
authorization.config.header = 'Authorization'
@@ -278,9 +236,9 @@ class HttpExecutor:
headers[authorization.config.header] = f'Basic {authorization.config.api_key}'
elif self.authorization.config.type == 'custom':
headers[authorization.config.header] = authorization.config.api_key
-
+
return headers
-
+
def _validate_and_parse_response(self, response: Union[httpx.Response, requests.Response]) -> HttpExecutorResponse:
"""
validate the response
@@ -289,21 +247,22 @@ class HttpExecutor:
executor_response = HttpExecutorResponse(response)
else:
raise ValueError(f'Invalid response type {type(response)}')
-
+
if executor_response.is_file:
if executor_response.size > MAX_BINARY_SIZE:
- raise ValueError(f'File size is too large, max size is {READABLE_MAX_BINARY_SIZE}, but current size is {executor_response.readable_size}.')
+ raise ValueError(
+ f'File size is too large, max size is {READABLE_MAX_BINARY_SIZE}, but current size is {executor_response.readable_size}.')
else:
if executor_response.size > MAX_TEXT_SIZE:
- raise ValueError(f'Text size is too large, max size is {READABLE_MAX_TEXT_SIZE}, but current size is {executor_response.readable_size}.')
-
+ raise ValueError(
+ f'Text size is too large, max size is {READABLE_MAX_TEXT_SIZE}, but current size is {executor_response.readable_size}.')
+
return executor_response
-
+
def _do_http_request(self, headers: dict[str, Any]) -> httpx.Response:
"""
do http request depending on api bundle
"""
- # do http request
kwargs = {
'url': self.server_url,
'headers': headers,
@@ -312,25 +271,14 @@ class HttpExecutor:
'follow_redirects': True
}
- if self.method == 'get':
- response = ssrf_proxy.get(**kwargs)
- elif self.method == 'post':
- response = ssrf_proxy.post(data=self.body, files=self.files, **kwargs)
- elif self.method == 'put':
- response = ssrf_proxy.put(data=self.body, files=self.files, **kwargs)
- elif self.method == 'delete':
- response = ssrf_proxy.delete(data=self.body, files=self.files, **kwargs)
- elif self.method == 'patch':
- response = ssrf_proxy.patch(data=self.body, files=self.files, **kwargs)
- elif self.method == 'head':
- response = ssrf_proxy.head(**kwargs)
- elif self.method == 'options':
- response = ssrf_proxy.options(**kwargs)
+ if self.method in ('get', 'head', 'options'):
+ response = getattr(ssrf_proxy, self.method)(**kwargs)
+ elif self.method in ('post', 'put', 'delete', 'patch'):
+ response = getattr(ssrf_proxy, self.method)(data=self.body, files=self.files, **kwargs)
else:
raise ValueError(f'Invalid http method {self.method}')
-
return response
-
+
def invoke(self) -> HttpExecutorResponse:
"""
invoke http request
@@ -343,8 +291,8 @@ class HttpExecutor:
# validate response
return self._validate_and_parse_response(response)
-
- def to_raw_request(self) -> str:
+
+ def to_raw_request(self, mask_authorization_header: Optional[bool] = True) -> str:
"""
convert to raw request
"""
@@ -356,6 +304,17 @@ class HttpExecutor:
headers = self._assembling_headers()
for k, v in headers.items():
+            if mask_authorization_header and self.authorization.type == 'api-key':
+                # resolve the (possibly custom) authorization header name; guarding on the
+                # auth type keeps authorization_header from being referenced unbound
+                authorization_header = 'Authorization'
+                if self.authorization.config and self.authorization.config.header:
+                    authorization_header = self.authorization.config.header
+
+                if k.lower() == authorization_header.lower():
+                    raw_request += f'{k}: {"*" * len(v)}\n'
+                    continue
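+                # e.g. an api-key header 'Authorization: Bearer abc' is written as
+                # 'Authorization: **********' in the raw request (one '*' per character)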
+
raw_request += f'{k}: {v}\n'
raw_request += '\n'
diff --git a/api/core/workflow/nodes/http_request/http_request_node.py b/api/core/workflow/nodes/http_request/http_request_node.py
index cba1a11a8a..d983a30695 100644
--- a/api/core/workflow/nodes/http_request/http_request_node.py
+++ b/api/core/workflow/nodes/http_request/http_request_node.py
@@ -1,5 +1,4 @@
import logging
-import os
from mimetypes import guess_extension
from os import path
from typing import cast
@@ -9,14 +8,15 @@ from core.tools.tool_file_manager import ToolFileManager
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
-from core.workflow.nodes.http_request.entities import HttpRequestNodeData
+from core.workflow.nodes.http_request.entities import (
+ MAX_CONNECT_TIMEOUT,
+ MAX_READ_TIMEOUT,
+ MAX_WRITE_TIMEOUT,
+ HttpRequestNodeData,
+)
from core.workflow.nodes.http_request.http_executor import HttpExecutor, HttpExecutorResponse
from models.workflow import WorkflowNodeExecutionStatus
-MAX_CONNECT_TIMEOUT = int(os.environ.get('HTTP_REQUEST_MAX_CONNECT_TIMEOUT', '300'))
-MAX_READ_TIMEOUT = int(os.environ.get('HTTP_REQUEST_MAX_READ_TIMEOUT', '600'))
-MAX_WRITE_TIMEOUT = int(os.environ.get('HTTP_REQUEST_MAX_WRITE_TIMEOUT', '600'))
-
HTTP_REQUEST_DEFAULT_TIMEOUT = HttpRequestNodeData.Timeout(connect=min(10, MAX_CONNECT_TIMEOUT),
read=min(60, MAX_READ_TIMEOUT),
write=min(20, MAX_WRITE_TIMEOUT))
@@ -63,7 +63,9 @@ class HttpRequestNode(BaseNode):
process_data = {}
if http_executor:
process_data = {
- 'request': http_executor.to_raw_request(),
+ 'request': http_executor.to_raw_request(
+ mask_authorization_header=node_data.mask_authorization_header
+ ),
}
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
@@ -82,7 +84,9 @@ class HttpRequestNode(BaseNode):
'files': files,
},
process_data={
- 'request': http_executor.to_raw_request(),
+ 'request': http_executor.to_raw_request(
+ mask_authorization_header=node_data.mask_authorization_header
+ ),
}
)
@@ -91,8 +95,14 @@ class HttpRequestNode(BaseNode):
if timeout is None:
return HTTP_REQUEST_DEFAULT_TIMEOUT
+ if timeout.connect is None:
+ timeout.connect = HTTP_REQUEST_DEFAULT_TIMEOUT.connect
timeout.connect = min(timeout.connect, MAX_CONNECT_TIMEOUT)
+ if timeout.read is None:
+ timeout.read = HTTP_REQUEST_DEFAULT_TIMEOUT.read
timeout.read = min(timeout.read, MAX_READ_TIMEOUT)
+ if timeout.write is None:
+ timeout.write = HTTP_REQUEST_DEFAULT_TIMEOUT.write
timeout.write = min(timeout.write, MAX_WRITE_TIMEOUT)
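+        # e.g. with default env limits, Timeout(connect=None, read=9999, write=None)
+        # resolves to connect=10, read=600, write=20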
return timeout
diff --git a/api/core/workflow/nodes/iteration/__init__.py b/api/core/workflow/nodes/iteration/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/workflow/nodes/iteration/entities.py b/api/core/workflow/nodes/iteration/entities.py
new file mode 100644
index 0000000000..c85aa66c7b
--- /dev/null
+++ b/api/core/workflow/nodes/iteration/entities.py
@@ -0,0 +1,39 @@
+from typing import Any, Optional
+
+from core.workflow.entities.base_node_data_entities import BaseIterationNodeData, BaseIterationState
+
+
+class IterationNodeData(BaseIterationNodeData):
+ """
+ Iteration Node Data.
+ """
+ parent_loop_id: Optional[str] # redundant field, not used currently
+ iterator_selector: list[str] # variable selector
+ output_selector: list[str] # output selector
+
+class IterationState(BaseIterationState):
+ """
+ Iteration State.
+ """
+    outputs: list[Any] = []  # pydantic copies mutable defaults per instance
+ current_output: Optional[Any] = None
+
+ class MetaData(BaseIterationState.MetaData):
+ """
+        Iteration metadata.
+ """
+ iterator_length: int
+
+ def get_last_output(self) -> Optional[Any]:
+ """
+ Get last output.
+ """
+ if self.outputs:
+ return self.outputs[-1]
+ return None
+
+ def get_current_output(self) -> Optional[Any]:
+ """
+ Get current output.
+ """
+ return self.current_output
\ No newline at end of file
diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py
new file mode 100644
index 0000000000..12d792f297
--- /dev/null
+++ b/api/core/workflow/nodes/iteration/iteration_node.py
@@ -0,0 +1,119 @@
+from typing import cast
+
+from core.model_runtime.utils.encoders import jsonable_encoder
+from core.workflow.entities.base_node_data_entities import BaseIterationState
+from core.workflow.entities.node_entities import NodeRunResult, NodeType
+from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.nodes.base_node import BaseIterationNode
+from core.workflow.nodes.iteration.entities import IterationNodeData, IterationState
+from models.workflow import WorkflowNodeExecutionStatus
+
+
+class IterationNode(BaseIterationNode):
+ """
+ Iteration Node.
+ """
+ _node_data_cls = IterationNodeData
+ _node_type = NodeType.ITERATION
+
+ def _run(self, variable_pool: VariablePool) -> BaseIterationState:
+ """
+ Run the node.
+ """
+ iterator = variable_pool.get_variable_value(cast(IterationNodeData, self.node_data).iterator_selector)
+
+ if not isinstance(iterator, list):
+ raise ValueError(f"Invalid iterator value: {iterator}, please provide a list.")
+
+ state = IterationState(iteration_node_id=self.node_id, index=-1, inputs={
+ 'iterator_selector': iterator
+ }, outputs=[], metadata=IterationState.MetaData(
+ iterator_length=len(iterator) if iterator is not None else 0
+ ))
+
+ self._set_current_iteration_variable(variable_pool, state)
+ return state
+
+ def _get_next_iteration(self, variable_pool: VariablePool, state: IterationState) -> NodeRunResult | str:
+ """
+ Get next iteration start node id based on the graph.
+        :param variable_pool: variable pool
+        :param state: iteration state
+        :return: next node id, or a NodeRunResult when all iterations are done
+ """
+ # resolve current output
+ self._resolve_current_output(variable_pool, state)
+ # move to next iteration
+ self._next_iteration(variable_pool, state)
+
+ node_data = cast(IterationNodeData, self.node_data)
+ if self._reached_iteration_limit(variable_pool, state):
+ return NodeRunResult(
+ status=WorkflowNodeExecutionStatus.SUCCEEDED,
+ outputs={
+ 'output': jsonable_encoder(state.outputs)
+ }
+ )
+
+ return node_data.start_node_id
+
+ def _set_current_iteration_variable(self, variable_pool: VariablePool, state: IterationState):
+ """
+ Set current iteration variable.
+        :param variable_pool: variable pool
+        :param state: iteration state
+ """
+ node_data = cast(IterationNodeData, self.node_data)
+
+ variable_pool.append_variable(self.node_id, ['index'], state.index)
+ # get the iterator value
+ iterator = variable_pool.get_variable_value(node_data.iterator_selector)
+
+ if iterator is None or not isinstance(iterator, list):
+ return
+
+ if state.index < len(iterator):
+ variable_pool.append_variable(self.node_id, ['item'], iterator[state.index])
+
+ def _next_iteration(self, variable_pool: VariablePool, state: IterationState):
+ """
+ Move to next iteration.
+        :param variable_pool: variable pool
+        :param state: iteration state
+ """
+ state.index += 1
+ self._set_current_iteration_variable(variable_pool, state)
+
+ def _reached_iteration_limit(self, variable_pool: VariablePool, state: IterationState):
+ """
+ Check if iteration limit is reached.
+ :return: True if iteration limit is reached, False otherwise
+ """
+ node_data = cast(IterationNodeData, self.node_data)
+ iterator = variable_pool.get_variable_value(node_data.iterator_selector)
+
+ if iterator is None or not isinstance(iterator, list):
+ return True
+
+ return state.index >= len(iterator)
+
+ def _resolve_current_output(self, variable_pool: VariablePool, state: IterationState):
+ """
+ Resolve current output.
+        :param variable_pool: variable pool
+        :param state: iteration state
+ """
+ output_selector = cast(IterationNodeData, self.node_data).output_selector
+ output = variable_pool.get_variable_value(output_selector)
+ # clear the output for this iteration
+ variable_pool.append_variable(self.node_id, output_selector[1:], None)
+ state.current_output = output
+ if output is not None:
+ state.outputs.append(output)
+
+ @classmethod
+ def _extract_variable_selector_to_variable_mapping(cls, node_data: IterationNodeData) -> dict[str, list[str]]:
+ """
+ Extract variable selector to variable mapping
+ :param node_data: node data
+ :return:
+ """
+ return {
+ 'input_selector': node_data.iterator_selector,
+ }
\ No newline at end of file
diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
index be3cec9152..1a0f3b0495 100644
--- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
+++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
@@ -1,5 +1,7 @@
from typing import Any, cast
+from sqlalchemy import func
+
from core.app.app_config.entities import DatasetRetrieveConfigEntity
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.agent_entities import PlanningStrategy
@@ -73,30 +75,33 @@ class KnowledgeRetrievalNode(BaseNode):
def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: str) -> list[
dict[str, Any]]:
- """
- A dataset tool is a tool that can be used to retrieve information from a dataset
- :param node_data: node data
- :param query: query
- """
- tools = []
available_datasets = []
dataset_ids = node_data.dataset_ids
- for dataset_id in dataset_ids:
- # get dataset from dataset id
- dataset = db.session.query(Dataset).filter(
- Dataset.tenant_id == self.tenant_id,
- Dataset.id == dataset_id
- ).first()
+ # Subquery: Count the number of available documents for each dataset
+ subquery = db.session.query(
+ Document.dataset_id,
+ func.count(Document.id).label('available_document_count')
+ ).filter(
+ Document.indexing_status == 'completed',
+ Document.enabled == True,
+ Document.archived == False,
+ Document.dataset_id.in_(dataset_ids)
+ ).group_by(Document.dataset_id).having(
+ func.count(Document.id) > 0
+ ).subquery()
+
+ results = db.session.query(Dataset).join(
+ subquery, Dataset.id == subquery.c.dataset_id
+ ).filter(
+ Dataset.tenant_id == self.tenant_id,
+ Dataset.id.in_(dataset_ids)
+ ).all()
+
+ for dataset in results:
# pass if dataset is not available
if not dataset:
continue
-
- # pass if dataset is not available
- if (dataset and dataset.available_document_count == 0
- and dataset.available_document_count == 0):
- continue
-
available_datasets.append(dataset)
all_documents = []
dataset_retrieval = DatasetRetrieval()
@@ -143,10 +148,9 @@ class KnowledgeRetrievalNode(BaseNode):
if all_documents:
document_score_list = {}
for item in all_documents:
- if 'score' in item.metadata and item.metadata['score']:
+ if item.metadata.get('score'):
document_score_list[item.metadata['doc_id']] = item.metadata['score']
- document_context_list = []
index_node_ids = [document.metadata['doc_id'] for document in all_documents]
segments = DocumentSegment.query.filter(
DocumentSegment.dataset_id.in_(dataset_ids),
@@ -160,11 +164,6 @@ class KnowledgeRetrievalNode(BaseNode):
sorted_segments = sorted(segments,
key=lambda segment: index_node_id_to_position.get(segment.index_node_id,
float('inf')))
- for segment in sorted_segments:
- if segment.answer:
- document_context_list.append(f'question:{segment.content} answer:{segment.answer}')
- else:
- document_context_list.append(segment.content)
for segment in sorted_segments:
dataset = Dataset.query.filter_by(
@@ -197,9 +196,9 @@ class KnowledgeRetrievalNode(BaseNode):
'title': document.name
}
if segment.answer:
- source['content'] = f'question:{segment.content} \nanswer:{segment.answer}'
+ source['content'] = f'question:{segment.get_sign_content()} \nanswer:{segment.answer}'
else:
- source['content'] = segment.content
+ source['content'] = segment.get_sign_content()
context_list.append(source)
resource_number += 1
return context_list
diff --git a/api/core/workflow/nodes/llm/entities.py b/api/core/workflow/nodes/llm/entities.py
index c390aaf8c9..1e48a10bc7 100644
--- a/api/core/workflow/nodes/llm/entities.py
+++ b/api/core/workflow/nodes/llm/entities.py
@@ -4,6 +4,7 @@ from pydantic import BaseModel
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
from core.workflow.entities.base_node_data_entities import BaseNodeData
+from core.workflow.entities.variable_entities import VariableSelector
class ModelConfig(BaseModel):
@@ -37,13 +38,31 @@ class VisionConfig(BaseModel):
enabled: bool
configs: Optional[Configs] = None
+class PromptConfig(BaseModel):
+ """
+ Prompt Config.
+ """
+ jinja2_variables: Optional[list[VariableSelector]] = None
+
+class LLMNodeChatModelMessage(ChatModelMessage):
+ """
+ LLM Node Chat Model Message.
+ """
+ jinja2_text: Optional[str] = None
+
+class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate):
+ """
+    LLM Node Completion Model Prompt Template.
+ """
+ jinja2_text: Optional[str] = None
class LLMNodeData(BaseNodeData):
"""
LLM Node Data.
"""
model: ModelConfig
- prompt_template: Union[list[ChatModelMessage], CompletionModelPromptTemplate]
+ prompt_template: Union[list[LLMNodeChatModelMessage], LLMNodeCompletionModelPromptTemplate]
+ prompt_config: Optional[PromptConfig] = None
memory: Optional[MemoryConfig] = None
context: ContextConfig
vision: VisionConfig
diff --git a/api/core/workflow/nodes/llm/llm_node.py b/api/core/workflow/nodes/llm/llm_node.py
index c8b7f279ab..fef09c1385 100644
--- a/api/core/workflow/nodes/llm/llm_node.py
+++ b/api/core/workflow/nodes/llm/llm_node.py
@@ -1,4 +1,6 @@
+import json
from collections.abc import Generator
+from copy import deepcopy
from typing import Optional, cast
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
@@ -17,11 +19,15 @@ from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
from core.prompt.utils.prompt_message_util import PromptMessageUtil
-from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
-from core.workflow.nodes.llm.entities import LLMNodeData, ModelConfig
+from core.workflow.nodes.llm.entities import (
+ LLMNodeChatModelMessage,
+ LLMNodeCompletionModelPromptTemplate,
+ LLMNodeData,
+ ModelConfig,
+)
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from extensions.ext_database import db
from models.model import Conversation
@@ -39,16 +45,24 @@ class LLMNode(BaseNode):
:param variable_pool: variable pool
:return:
"""
- node_data = self.node_data
- node_data = cast(self._node_data_cls, node_data)
+ node_data = cast(LLMNodeData, deepcopy(self.node_data))
node_inputs = None
process_data = None
try:
+ # init messages template
+ node_data.prompt_template = self._transform_chat_messages(node_data.prompt_template)
+
# fetch variables and fetch values from variable pool
inputs = self._fetch_inputs(node_data, variable_pool)
+ # fetch jinja2 inputs
+ jinja_inputs = self._fetch_jinja_inputs(node_data, variable_pool)
+
+ # merge inputs
+ inputs.update(jinja_inputs)
+
node_inputs = {}
# fetch files
@@ -183,6 +197,86 @@ class LLMNode(BaseNode):
usage = LLMUsage.empty_usage()
return full_text, usage
+
+ def _transform_chat_messages(self,
+ messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
+ ) -> list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
+ """
+ Transform chat messages
+
+ :param messages: chat messages
+ :return:
+ """
+
+ if isinstance(messages, LLMNodeCompletionModelPromptTemplate):
+ if messages.edition_type == 'jinja2':
+ messages.text = messages.jinja2_text
+
+ return messages
+
+ for message in messages:
+ if message.edition_type == 'jinja2':
+ message.text = message.jinja2_text
+
+ return messages
+
+ def _fetch_jinja_inputs(self, node_data: LLMNodeData, variable_pool: VariablePool) -> dict[str, str]:
+ """
+ Fetch jinja inputs
+ :param node_data: node data
+ :param variable_pool: variable pool
+ :return:
+ """
+ variables = {}
+
+ if not node_data.prompt_config:
+ return variables
+
+ for variable_selector in node_data.prompt_config.jinja2_variables or []:
+ variable = variable_selector.variable
+ value = variable_pool.get_variable_value(
+ variable_selector=variable_selector.value_selector
+ )
+
+ def parse_dict(d: dict) -> str:
+ """
+ Parse dict into string
+ """
+ # check if it's a context structure
+ if 'metadata' in d and '_source' in d['metadata'] and 'content' in d:
+ return d['content']
+
+ # else, parse the dict
+ try:
+ return json.dumps(d, ensure_ascii=False)
+ except Exception:
+ return str(d)
+
+            if isinstance(value, list):
+                result = ''
+                for item in value:
+                    if isinstance(item, dict):
+                        result += parse_dict(item)
+                    else:
+                        result += str(item)
+                    result += '\n'
+                value = result.strip()
+            elif isinstance(value, dict):
+                value = parse_dict(value)
+            elif not isinstance(value, str):
+                value = str(value)
+
+ variables[variable] = value
+
+ return variables
def _fetch_inputs(self, node_data: LLMNodeData, variable_pool: VariablePool) -> dict[str, str]:
"""
@@ -531,25 +625,25 @@ class LLMNode(BaseNode):
db.session.commit()
@classmethod
- def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
+ def _extract_variable_selector_to_variable_mapping(cls, node_data: LLMNodeData) -> dict[str, list[str]]:
"""
Extract variable selector to variable mapping
:param node_data: node data
:return:
"""
- node_data = node_data
- node_data = cast(cls._node_data_cls, node_data)
prompt_template = node_data.prompt_template
variable_selectors = []
if isinstance(prompt_template, list):
for prompt in prompt_template:
- variable_template_parser = VariableTemplateParser(template=prompt.text)
- variable_selectors.extend(variable_template_parser.extract_variable_selectors())
+ if prompt.edition_type != 'jinja2':
+ variable_template_parser = VariableTemplateParser(template=prompt.text)
+ variable_selectors.extend(variable_template_parser.extract_variable_selectors())
else:
- variable_template_parser = VariableTemplateParser(template=prompt_template.text)
- variable_selectors = variable_template_parser.extract_variable_selectors()
+ if prompt_template.edition_type != 'jinja2':
+ variable_template_parser = VariableTemplateParser(template=prompt_template.text)
+ variable_selectors = variable_template_parser.extract_variable_selectors()
variable_mapping = {}
for variable_selector in variable_selectors:
@@ -571,6 +665,22 @@ class LLMNode(BaseNode):
if node_data.memory:
variable_mapping['#sys.query#'] = ['sys', SystemVariable.QUERY.value]
+ if node_data.prompt_config:
+ enable_jinja = False
+
+ if isinstance(prompt_template, list):
+ for prompt in prompt_template:
+ if prompt.edition_type == 'jinja2':
+ enable_jinja = True
+ break
+ else:
+ if prompt_template.edition_type == 'jinja2':
+ enable_jinja = True
+
+ if enable_jinja:
+ for variable_selector in node_data.prompt_config.jinja2_variables or []:
+ variable_mapping[variable_selector.variable] = variable_selector.value_selector
+
return variable_mapping
@classmethod
@@ -588,7 +698,8 @@ class LLMNode(BaseNode):
"prompts": [
{
"role": "system",
- "text": "You are a helpful AI assistant."
+ "text": "You are a helpful AI assistant.",
+ "edition_type": "basic"
}
]
},
@@ -600,7 +711,8 @@ class LLMNode(BaseNode):
"prompt": {
"text": "Here is the chat histories between human and assistant, inside "
" XML tags.\n\n\n{{"
- "#histories#}}\n\n\n\nHuman: {{#sys.query#}}\n\nAssistant:"
+ "#histories#}}\n\n\n\nHuman: {{#sys.query#}}\n\nAssistant:",
+ "edition_type": "basic"
},
"stop": ["Human:"]
}
diff --git a/api/core/workflow/nodes/loop/__init__.py b/api/core/workflow/nodes/loop/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/workflow/nodes/loop/entities.py b/api/core/workflow/nodes/loop/entities.py
new file mode 100644
index 0000000000..8a5684551e
--- /dev/null
+++ b/api/core/workflow/nodes/loop/entities.py
@@ -0,0 +1,13 @@
+
+from core.workflow.entities.base_node_data_entities import BaseIterationNodeData, BaseIterationState
+
+
+class LoopNodeData(BaseIterationNodeData):
+ """
+ Loop Node Data.
+ """
+
+class LoopState(BaseIterationState):
+ """
+ Loop State.
+ """
\ No newline at end of file
diff --git a/api/core/workflow/nodes/loop/loop_node.py b/api/core/workflow/nodes/loop/loop_node.py
new file mode 100644
index 0000000000..7d53c6f5f2
--- /dev/null
+++ b/api/core/workflow/nodes/loop/loop_node.py
@@ -0,0 +1,20 @@
+from core.workflow.entities.node_entities import NodeRunResult, NodeType
+from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.nodes.base_node import BaseIterationNode
+from core.workflow.nodes.loop.entities import LoopNodeData, LoopState
+
+
+class LoopNode(BaseIterationNode):
+ """
+ Loop Node.
+ """
+ _node_data_cls = LoopNodeData
+ _node_type = NodeType.LOOP
+
+ def _run(self, variable_pool: VariablePool) -> LoopState:
+ return super()._run(variable_pool)
+
+    def _get_next_iteration(self, variable_pool: VariablePool, state: LoopState) -> NodeRunResult | str:
+        """
+        Get next iteration start node id based on the graph.
+        """
diff --git a/api/core/workflow/nodes/parameter_extractor/__init__.py b/api/core/workflow/nodes/parameter_extractor/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/workflow/nodes/parameter_extractor/entities.py b/api/core/workflow/nodes/parameter_extractor/entities.py
new file mode 100644
index 0000000000..a89a6903ef
--- /dev/null
+++ b/api/core/workflow/nodes/parameter_extractor/entities.py
@@ -0,0 +1,85 @@
+from typing import Any, Literal, Optional
+
+from pydantic import BaseModel, validator
+
+from core.prompt.entities.advanced_prompt_entities import MemoryConfig
+from core.workflow.entities.base_node_data_entities import BaseNodeData
+
+
+class ModelConfig(BaseModel):
+ """
+ Model Config.
+ """
+ provider: str
+ name: str
+ mode: str
+ completion_params: dict[str, Any] = {}
+
+class ParameterConfig(BaseModel):
+ """
+ Parameter Config.
+ """
+ name: str
+ type: Literal['string', 'number', 'bool', 'select', 'array[string]', 'array[number]', 'array[object]']
+ options: Optional[list[str]]
+ description: str
+ required: bool
+
+ @validator('name', pre=True, always=True)
+ def validate_name(cls, value):
+ if not value:
+ raise ValueError('Parameter name is required')
+ if value in ['__reason', '__is_success']:
+ raise ValueError('Invalid parameter name, __reason and __is_success are reserved')
+ return value
+
+class ParameterExtractorNodeData(BaseNodeData):
+ """
+ Parameter Extractor Node Data.
+ """
+ model: ModelConfig
+ query: list[str]
+ parameters: list[ParameterConfig]
+ instruction: Optional[str]
+ memory: Optional[MemoryConfig]
+ reasoning_mode: Literal['function_call', 'prompt']
+
+ @validator('reasoning_mode', pre=True, always=True)
+ def set_reasoning_mode(cls, v):
+ return v or 'function_call'
+
+ def get_parameter_json_schema(self) -> dict:
+ """
+ Get parameter json schema.
+
+ :return: parameter json schema
+ """
+ parameters = {
+ 'type': 'object',
+ 'properties': {},
+ 'required': []
+ }
+
+ for parameter in self.parameters:
+ parameter_schema = {
+ 'description': parameter.description
+ }
+
+ if parameter.type in ['string', 'select']:
+ parameter_schema['type'] = 'string'
+ elif parameter.type.startswith('array'):
+ parameter_schema['type'] = 'array'
+ nested_type = parameter.type[6:-1]
+ parameter_schema['items'] = {'type': nested_type}
+ else:
+ parameter_schema['type'] = parameter.type
+
+ if parameter.type == 'select':
+ parameter_schema['enum'] = parameter.options
+
+ parameters['properties'][parameter.name] = parameter_schema
+
+ if parameter.required:
+ parameters['required'].append(parameter.name)
+
+ return parameters
\ No newline at end of file
diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py
new file mode 100644
index 0000000000..6e7dbd2702
--- /dev/null
+++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py
@@ -0,0 +1,711 @@
+import json
+import uuid
+from typing import Optional, cast
+
+from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
+from core.memory.token_buffer_memory import TokenBufferMemory
+from core.model_manager import ModelInstance
+from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
+from core.model_runtime.entities.message_entities import (
+ AssistantPromptMessage,
+ PromptMessage,
+ PromptMessageRole,
+ PromptMessageTool,
+ ToolPromptMessage,
+ UserPromptMessage,
+)
+from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from core.model_runtime.utils.encoders import jsonable_encoder
+from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
+from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
+from core.prompt.simple_prompt_transform import ModelMode
+from core.prompt.utils.prompt_message_util import PromptMessageUtil
+from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
+from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.nodes.llm.entities import ModelConfig
+from core.workflow.nodes.llm.llm_node import LLMNode
+from core.workflow.nodes.parameter_extractor.entities import ParameterExtractorNodeData
+from core.workflow.nodes.parameter_extractor.prompts import (
+ CHAT_EXAMPLE,
+ CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE,
+ COMPLETION_GENERATE_JSON_PROMPT,
+ FUNCTION_CALLING_EXTRACTOR_EXAMPLE,
+ FUNCTION_CALLING_EXTRACTOR_NAME,
+ FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT,
+ FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE,
+)
+from core.workflow.utils.variable_template_parser import VariableTemplateParser
+from extensions.ext_database import db
+from models.workflow import WorkflowNodeExecutionStatus
+
+
+class ParameterExtractorNode(LLMNode):
+ """
+ Parameter Extractor Node.
+ """
+ _node_data_cls = ParameterExtractorNodeData
+ _node_type = NodeType.PARAMETER_EXTRACTOR
+
+ _model_instance: Optional[ModelInstance] = None
+ _model_config: Optional[ModelConfigWithCredentialsEntity] = None
+
+ @classmethod
+ def get_default_config(cls, filters: Optional[dict] = None) -> dict:
+ return {
+ "model": {
+ "prompt_templates": {
+ "completion_model": {
+ "conversation_histories_role": {
+ "user_prefix": "Human",
+ "assistant_prefix": "Assistant"
+ },
+ "stop": ["Human:"]
+ }
+ }
+ }
+ }
+
+
+ def _run(self, variable_pool: VariablePool) -> NodeRunResult:
+ """
+ Run the node.
+ """
+
+ node_data = cast(ParameterExtractorNodeData, self.node_data)
+ query = variable_pool.get_variable_value(node_data.query)
+ if not query:
+ raise ValueError("Query not found")
+
+        inputs = {
+ 'query': query,
+ 'parameters': jsonable_encoder(node_data.parameters),
+ 'instruction': jsonable_encoder(node_data.instruction),
+ }
+
+ model_instance, model_config = self._fetch_model_config(node_data.model)
+ if not isinstance(model_instance.model_type_instance, LargeLanguageModel):
+ raise ValueError("Model is not a Large Language Model")
+
+ llm_model = model_instance.model_type_instance
+ model_schema = llm_model.get_model_schema(model_config.model, model_config.credentials)
+ if not model_schema:
+ raise ValueError("Model schema not found")
+
+ # fetch memory
+ memory = self._fetch_memory(node_data.memory, variable_pool, model_instance)
+
+        if set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL} \
+ and node_data.reasoning_mode == 'function_call':
+ # use function call
+ prompt_messages, prompt_message_tools = self._generate_function_call_prompt(
+ node_data, query, variable_pool, model_config, memory
+ )
+ else:
+ # use prompt engineering
+ prompt_messages = self._generate_prompt_engineering_prompt(node_data, query, variable_pool, model_config, memory)
+ prompt_message_tools = []
+
+ process_data = {
+ 'model_mode': model_config.mode,
+ 'prompts': PromptMessageUtil.prompt_messages_to_prompt_for_saving(
+ model_mode=model_config.mode,
+ prompt_messages=prompt_messages
+ ),
+ 'usage': None,
+ 'function': {} if not prompt_message_tools else jsonable_encoder(prompt_message_tools[0]),
+ 'tool_call': None,
+ }
+
+ try:
+ text, usage, tool_call = self._invoke_llm(
+ node_data_model=node_data.model,
+ model_instance=model_instance,
+ prompt_messages=prompt_messages,
+ tools=prompt_message_tools,
+ stop=model_config.stop,
+ )
+ process_data['usage'] = jsonable_encoder(usage)
+ process_data['tool_call'] = jsonable_encoder(tool_call)
+ process_data['llm_text'] = text
+ except Exception as e:
+ return NodeRunResult(
+ status=WorkflowNodeExecutionStatus.FAILED,
+ inputs=inputs,
+ process_data={},
+ outputs={
+ '__is_success': 0,
+ '__reason': str(e)
+ },
+ error=str(e),
+ metadata={}
+ )
+
+ error = None
+
+ if tool_call:
+ result = self._extract_json_from_tool_call(tool_call)
+ else:
+ result = self._extract_complete_json_response(text)
+ if not result:
+ result = self._generate_default_result(node_data)
+ error = "Failed to extract result from function call or text response, using empty result."
+
+ try:
+ result = self._validate_result(node_data, result)
+ except Exception as e:
+ error = str(e)
+
+ # transform result into standard format
+ result = self._transform_result(node_data, result)
+
+ return NodeRunResult(
+ status=WorkflowNodeExecutionStatus.SUCCEEDED,
+ inputs=inputs,
+ process_data=process_data,
+ outputs={
+ '__is_success': 1 if not error else 0,
+ '__reason': error,
+ **result
+ },
+ metadata={
+ NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens,
+ NodeRunMetadataKey.TOTAL_PRICE: usage.total_price,
+ NodeRunMetadataKey.CURRENCY: usage.currency
+ }
+ )
+
+ def _invoke_llm(self, node_data_model: ModelConfig,
+ model_instance: ModelInstance,
+ prompt_messages: list[PromptMessage],
+ tools: list[PromptMessageTool],
+ stop: list[str]) -> tuple[str, LLMUsage, Optional[AssistantPromptMessage.ToolCall]]:
+ """
+ Invoke large language model
+ :param node_data_model: node data model
+ :param model_instance: model instance
+ :param prompt_messages: prompt messages
+ :param stop: stop
+ :return:
+ """
+ db.session.close()
+
+ invoke_result = model_instance.invoke_llm(
+ prompt_messages=prompt_messages,
+ model_parameters=node_data_model.completion_params,
+ tools=tools,
+ stop=stop,
+ stream=False,
+ user=self.user_id,
+ )
+
+ # handle invoke result
+ if not isinstance(invoke_result, LLMResult):
+ raise ValueError(f"Invalid invoke result: {invoke_result}")
+
+ text = invoke_result.message.content
+ usage = invoke_result.usage
+ tool_call = invoke_result.message.tool_calls[0] if invoke_result.message.tool_calls else None
+
+ # deduct quota
+ self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
+
+ return text, usage, tool_call
+
+ def _generate_function_call_prompt(self,
+ node_data: ParameterExtractorNodeData,
+ query: str,
+ variable_pool: VariablePool,
+ model_config: ModelConfigWithCredentialsEntity,
+ memory: Optional[TokenBufferMemory],
+ ) -> tuple[list[PromptMessage], list[PromptMessageTool]]:
+ """
+ Generate function call prompt.
+ """
+ query = FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE.format(content=query, structure=json.dumps(node_data.get_parameter_json_schema()))
+
+ prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
+ rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, '')
+ prompt_template = self._get_function_calling_prompt_template(node_data, query, variable_pool, memory, rest_token)
+ prompt_messages = prompt_transform.get_prompt(
+ prompt_template=prompt_template,
+ inputs={},
+ query='',
+ files=[],
+ context='',
+ memory_config=node_data.memory,
+ memory=None,
+ model_config=model_config
+ )
+
+ # find last user message
+ last_user_message_idx = -1
+ for i, prompt_message in enumerate(prompt_messages):
+ if prompt_message.role == PromptMessageRole.USER:
+ last_user_message_idx = i
+
+ # add function call messages before last user message
+ example_messages = []
+ for example in FUNCTION_CALLING_EXTRACTOR_EXAMPLE:
+            tool_call_id = uuid.uuid4().hex
+            example_messages.extend([
+                UserPromptMessage(content=example['user']['query']),
+                AssistantPromptMessage(
+                    content=example['assistant']['text'],
+                    tool_calls=[
+                        AssistantPromptMessage.ToolCall(
+                            id=tool_call_id,
+                            type='function',
+                            function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                                name=example['assistant']['function_call']['name'],
+                                arguments=json.dumps(example['assistant']['function_call']['parameters'])
+                            )
+                        )
+                    ]
+                ),
+                ToolPromptMessage(
+                    content='Great! You have called the function with the correct parameters.',
+                    tool_call_id=tool_call_id
+                ),
+                AssistantPromptMessage(
+                    content='I have extracted the parameters, let\'s move on.',
+                )
+            ])
+
+ prompt_messages = prompt_messages[:last_user_message_idx] + \
+ example_messages + prompt_messages[last_user_message_idx:]
+
+ # generate tool
+ tool = PromptMessageTool(
+ name=FUNCTION_CALLING_EXTRACTOR_NAME,
+ description='Extract parameters from the natural language text',
+ parameters=node_data.get_parameter_json_schema(),
+ )
+
+ return prompt_messages, [tool]
+
+ def _generate_prompt_engineering_prompt(self,
+ data: ParameterExtractorNodeData,
+ query: str,
+ variable_pool: VariablePool,
+ model_config: ModelConfigWithCredentialsEntity,
+ memory: Optional[TokenBufferMemory],
+ ) -> list[PromptMessage]:
+ """
+ Generate prompt engineering prompt.
+ """
+ model_mode = ModelMode.value_of(data.model.mode)
+
+ if model_mode == ModelMode.COMPLETION:
+ return self._generate_prompt_engineering_completion_prompt(
+ data, query, variable_pool, model_config, memory
+ )
+ elif model_mode == ModelMode.CHAT:
+ return self._generate_prompt_engineering_chat_prompt(
+ data, query, variable_pool, model_config, memory
+ )
+ else:
+ raise ValueError(f"Invalid model mode: {model_mode}")
+
+ def _generate_prompt_engineering_completion_prompt(self,
+ node_data: ParameterExtractorNodeData,
+ query: str,
+ variable_pool: VariablePool,
+ model_config: ModelConfigWithCredentialsEntity,
+ memory: Optional[TokenBufferMemory],
+ ) -> list[PromptMessage]:
+ """
+ Generate completion prompt.
+ """
+ prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
+ rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, '')
+ prompt_template = self._get_prompt_engineering_prompt_template(node_data, query, variable_pool, memory, rest_token)
+ prompt_messages = prompt_transform.get_prompt(
+ prompt_template=prompt_template,
+ inputs={
+ 'structure': json.dumps(node_data.get_parameter_json_schema())
+ },
+ query='',
+ files=[],
+ context='',
+ memory_config=node_data.memory,
+ memory=memory,
+ model_config=model_config
+ )
+
+ return prompt_messages
+
+ def _generate_prompt_engineering_chat_prompt(self,
+ node_data: ParameterExtractorNodeData,
+ query: str,
+ variable_pool: VariablePool,
+ model_config: ModelConfigWithCredentialsEntity,
+ memory: Optional[TokenBufferMemory],
+ ) -> list[PromptMessage]:
+ """
+ Generate chat prompt.
+ """
+ prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
+ rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, '')
+ prompt_template = self._get_prompt_engineering_prompt_template(
+ node_data,
+ CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE.format(
+ structure=json.dumps(node_data.get_parameter_json_schema()),
+ text=query
+ ),
+ variable_pool, memory, rest_token
+ )
+
+ prompt_messages = prompt_transform.get_prompt(
+ prompt_template=prompt_template,
+ inputs={},
+ query='',
+ files=[],
+ context='',
+ memory_config=node_data.memory,
+ memory=memory,
+ model_config=model_config
+ )
+
+ # find last user message
+ last_user_message_idx = -1
+ for i, prompt_message in enumerate(prompt_messages):
+ if prompt_message.role == PromptMessageRole.USER:
+ last_user_message_idx = i
+
+ # add example messages before last user message
+ example_messages = []
+ for example in CHAT_EXAMPLE:
+ example_messages.extend([
+ UserPromptMessage(content=CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE.format(
+ structure=json.dumps(example['user']['json']),
+ text=example['user']['query'],
+ )),
+ AssistantPromptMessage(
+ content=json.dumps(example['assistant']['json']),
+ )
+ ])
+
+ prompt_messages = prompt_messages[:last_user_message_idx] + \
+ example_messages + prompt_messages[last_user_message_idx:]
+
+ return prompt_messages
+
+ def _validate_result(self, data: ParameterExtractorNodeData, result: dict) -> dict:
+ """
+ Validate result.
+ """
+ if len(data.parameters) != len(result):
+ raise ValueError("Invalid number of parameters")
+
+ for parameter in data.parameters:
+ if parameter.required and parameter.name not in result:
+ raise ValueError(f"Parameter {parameter.name} is required")
+
+ if parameter.type == 'select' and parameter.options and result.get(parameter.name) not in parameter.options:
+ raise ValueError(f"Invalid `select` value for parameter {parameter.name}")
+
+ if parameter.type == 'number' and not isinstance(result.get(parameter.name), int | float):
+ raise ValueError(f"Invalid `number` value for parameter {parameter.name}")
+
+ if parameter.type == 'bool' and not isinstance(result.get(parameter.name), bool):
+ raise ValueError(f"Invalid `bool` value for parameter {parameter.name}")
+
+ if parameter.type == 'string' and not isinstance(result.get(parameter.name), str):
+ raise ValueError(f"Invalid `string` value for parameter {parameter.name}")
+
+ if parameter.type.startswith('array'):
+ if not isinstance(result.get(parameter.name), list):
+ raise ValueError(f"Invalid `array` value for parameter {parameter.name}")
+ nested_type = parameter.type[6:-1]
+ for item in result.get(parameter.name):
+ if nested_type == 'number' and not isinstance(item, int | float):
+ raise ValueError(f"Invalid `array[number]` value for parameter {parameter.name}")
+ if nested_type == 'string' and not isinstance(item, str):
+ raise ValueError(f"Invalid `array[string]` value for parameter {parameter.name}")
+ if nested_type == 'object' and not isinstance(item, dict):
+ raise ValueError(f"Invalid `array[object]` value for parameter {parameter.name}")
+ return result
+
+ def _transform_result(self, data: ParameterExtractorNodeData, result: dict) -> dict:
+ """
+ Transform result into standard format.
+ """
+ transformed_result = {}
+ for parameter in data.parameters:
+ if parameter.name in result:
+ # transform value
+ if parameter.type == 'number':
+ if isinstance(result[parameter.name], int | float):
+ transformed_result[parameter.name] = result[parameter.name]
+                    elif isinstance(result[parameter.name], str):
+                        try:
+                            if '.' in result[parameter.name]:
+                                result[parameter.name] = float(result[parameter.name])
+                            else:
+                                result[parameter.name] = int(result[parameter.name])
+                            # propagate the successfully coerced value
+                            transformed_result[parameter.name] = result[parameter.name]
+                        except ValueError:
+                            pass
+ # TODO: bool is not supported in the current version
+ # elif parameter.type == 'bool':
+ # if isinstance(result[parameter.name], bool):
+ # transformed_result[parameter.name] = bool(result[parameter.name])
+ # elif isinstance(result[parameter.name], str):
+ # if result[parameter.name].lower() in ['true', 'false']:
+ # transformed_result[parameter.name] = bool(result[parameter.name].lower() == 'true')
+ # elif isinstance(result[parameter.name], int):
+ # transformed_result[parameter.name] = bool(result[parameter.name])
+ elif parameter.type in ['string', 'select']:
+ if isinstance(result[parameter.name], str):
+ transformed_result[parameter.name] = result[parameter.name]
+ elif parameter.type.startswith('array'):
+ if isinstance(result[parameter.name], list):
+ nested_type = parameter.type[6:-1]
+ transformed_result[parameter.name] = []
+ for item in result[parameter.name]:
+ if nested_type == 'number':
+ if isinstance(item, int | float):
+ transformed_result[parameter.name].append(item)
+ elif isinstance(item, str):
+ try:
+ if '.' in item:
+ transformed_result[parameter.name].append(float(item))
+ else:
+ transformed_result[parameter.name].append(int(item))
+ except ValueError:
+ pass
+ elif nested_type == 'string':
+ if isinstance(item, str):
+ transformed_result[parameter.name].append(item)
+ elif nested_type == 'object':
+ if isinstance(item, dict):
+ transformed_result[parameter.name].append(item)
+
+ if parameter.name not in transformed_result:
+ if parameter.type == 'number':
+ transformed_result[parameter.name] = 0
+ elif parameter.type == 'bool':
+ transformed_result[parameter.name] = False
+ elif parameter.type in ['string', 'select']:
+ transformed_result[parameter.name] = ''
+ elif parameter.type.startswith('array'):
+ transformed_result[parameter.name] = []
+
+ return transformed_result
+
+ def _extract_complete_json_response(self, result: str) -> Optional[dict]:
+ """
+ Extract complete json response.
+ """
+ def extract_json(text):
+ """
+ From a given JSON started from '{' or '[' extract the complete JSON object.
+ """
+ stack = []
+ for i, c in enumerate(text):
+ if c == '{' or c == '[':
+ stack.append(c)
+ elif c == '}' or c == ']':
+ # check if stack is empty
+ if not stack:
+ return text[:i]
+ # check if the last element in stack is matching
+ if (c == '}' and stack[-1] == '{') or (c == ']' and stack[-1] == '['):
+ stack.pop()
+ if not stack:
+ return text[:i+1]
+ else:
+ return text[:i]
+ return None
+
+ # extract json from the text
+ for idx in range(len(result)):
+ if result[idx] == '{' or result[idx] == '[':
+ json_str = extract_json(result[idx:])
+ if json_str:
+ try:
+ return json.loads(json_str)
+ except Exception:
+ pass
+
+ def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCall) -> Optional[dict]:
+ """
+ Extract json from tool call.
+ """
+ if not tool_call or not tool_call.function.arguments:
+ return None
+
+ return json.loads(tool_call.function.arguments)
+
+ def _generate_default_result(self, data: ParameterExtractorNodeData) -> dict:
+ """
+ Generate default result.
+ """
+ result = {}
+ for parameter in data.parameters:
+ if parameter.type == 'number':
+ result[parameter.name] = 0
+ elif parameter.type == 'bool':
+ result[parameter.name] = False
+ elif parameter.type in ['string', 'select']:
+ result[parameter.name] = ''
+
+ return result
+
+ def _render_instruction(self, instruction: str, variable_pool: VariablePool) -> str:
+ """
+ Render instruction.
+ """
+ variable_template_parser = VariableTemplateParser(instruction)
+ inputs = {}
+ for selector in variable_template_parser.extract_variable_selectors():
+ inputs[selector.variable] = variable_pool.get_variable_value(selector.value_selector)
+
+ return variable_template_parser.format(inputs)
+
+ def _get_function_calling_prompt_template(self, node_data: ParameterExtractorNodeData, query: str,
+ variable_pool: VariablePool,
+ memory: Optional[TokenBufferMemory],
+ max_token_limit: int = 2000) \
+ -> list[ChatModelMessage]:
+ model_mode = ModelMode.value_of(node_data.model.mode)
+ input_text = query
+ memory_str = ''
+ instruction = self._render_instruction(node_data.instruction or '', variable_pool)
+
+ if memory:
+ memory_str = memory.get_history_prompt_text(max_token_limit=max_token_limit,
+ message_limit=node_data.memory.window.size)
+ if model_mode == ModelMode.CHAT:
+ system_prompt_messages = ChatModelMessage(
+ role=PromptMessageRole.SYSTEM,
+ text=FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT.format(histories=memory_str, instruction=instruction)
+ )
+ user_prompt_message = ChatModelMessage(
+ role=PromptMessageRole.USER,
+ text=input_text
+ )
+ return [system_prompt_messages, user_prompt_message]
+ else:
+            raise ValueError(f"Model mode {model_mode} not supported.")
+
+ def _get_prompt_engineering_prompt_template(self, node_data: ParameterExtractorNodeData, query: str,
+ variable_pool: VariablePool,
+ memory: Optional[TokenBufferMemory],
+ max_token_limit: int = 2000) \
+        -> list[ChatModelMessage] | CompletionModelPromptTemplate:
+
+ model_mode = ModelMode.value_of(node_data.model.mode)
+ input_text = query
+ memory_str = ''
+ instruction = self._render_instruction(node_data.instruction or '', variable_pool)
+
+ if memory:
+ memory_str = memory.get_history_prompt_text(max_token_limit=max_token_limit,
+ message_limit=node_data.memory.window.size)
+ if model_mode == ModelMode.CHAT:
+            # prompt engineering mode should use the JSON-generation prompt;
+            # the function-calling system prompt belongs to the tool-call path
+            system_prompt_messages = ChatModelMessage(
+                role=PromptMessageRole.SYSTEM,
+                text=CHAT_GENERATE_JSON_PROMPT.format(histories=memory_str)
+                    .replace('{instructions}', instruction)
+            )
+ user_prompt_message = ChatModelMessage(
+ role=PromptMessageRole.USER,
+ text=input_text
+ )
+ return [system_prompt_messages, user_prompt_message]
+ elif model_mode == ModelMode.COMPLETION:
+ return CompletionModelPromptTemplate(
+ text=COMPLETION_GENERATE_JSON_PROMPT.format(histories=memory_str,
+ text=input_text,
+ instruction=instruction)
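+                    # strip the γγγ sentinel markers (together with their marked braces) from the formatted prompt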
+ .replace('{γγγ', '')
+ .replace('}γγγ', '')
+ )
+ else:
+            raise ValueError(f"Model mode {model_mode} not supported.")
+
+ def _calculate_rest_token(self, node_data: ParameterExtractorNodeData, query: str,
+ variable_pool: VariablePool,
+ model_config: ModelConfigWithCredentialsEntity,
+ context: Optional[str]) -> int:
+ prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
+
+ model_instance, model_config = self._fetch_model_config(node_data.model)
+ if not isinstance(model_instance.model_type_instance, LargeLanguageModel):
+ raise ValueError("Model is not a Large Language Model")
+
+ llm_model = model_instance.model_type_instance
+ model_schema = llm_model.get_model_schema(model_config.model, model_config.credentials)
+ if not model_schema:
+ raise ValueError("Model schema not found")
+
+        if set(model_schema.features or []) & {ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL}:
+ prompt_template = self._get_function_calling_prompt_template(node_data, query, variable_pool, None, 2000)
+ else:
+ prompt_template = self._get_prompt_engineering_prompt_template(node_data, query, variable_pool, None, 2000)
+
+ prompt_messages = prompt_transform.get_prompt(
+ prompt_template=prompt_template,
+ inputs={},
+ query='',
+ files=[],
+ context=context,
+ memory_config=node_data.memory,
+ memory=None,
+ model_config=model_config
+ )
+ rest_tokens = 2000
+
+ model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
+ if model_context_tokens:
+ model_type_instance = model_config.provider_model_bundle.model_type_instance
+ model_type_instance = cast(LargeLanguageModel, model_type_instance)
+
+ curr_message_tokens = model_type_instance.get_num_tokens(
+ model_config.model,
+ model_config.credentials,
+ prompt_messages
+ ) + 1000 # add 1000 to ensure tool call messages
+
+ max_tokens = 0
+ for parameter_rule in model_config.model_schema.parameter_rules:
+ if (parameter_rule.name == 'max_tokens'
+ or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
+ max_tokens = (model_config.parameters.get(parameter_rule.name)
+ or model_config.parameters.get(parameter_rule.use_template)) or 0
+
+ rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
+ rest_tokens = max(rest_tokens, 0)
+
+ return rest_tokens
+
+ def _fetch_model_config(self, node_data_model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
+ """
+ Fetch model config.
+ """
+ if not self._model_instance or not self._model_config:
+ self._model_instance, self._model_config = super()._fetch_model_config(node_data_model)
+
+ return self._model_instance, self._model_config
+
+ @classmethod
+ def _extract_variable_selector_to_variable_mapping(cls, node_data: ParameterExtractorNodeData) -> dict[str, list[str]]:
+ """
+ Extract variable selector to variable mapping
+ :param node_data: node data
+ :return:
+ """
+ node_data = node_data
+
+ variable_mapping = {
+ 'query': node_data.query
+ }
+
+ if node_data.instruction:
+ variable_template_parser = VariableTemplateParser(template=node_data.instruction)
+ for selector in variable_template_parser.extract_variable_selectors():
+ variable_mapping[selector.variable] = selector.value_selector
+
+ return variable_mapping
\ No newline at end of file
diff --git a/api/core/workflow/nodes/parameter_extractor/prompts.py b/api/core/workflow/nodes/parameter_extractor/prompts.py
new file mode 100644
index 0000000000..499c58d505
--- /dev/null
+++ b/api/core/workflow/nodes/parameter_extractor/prompts.py
@@ -0,0 +1,206 @@
+FUNCTION_CALLING_EXTRACTOR_NAME = 'extract_parameters'
+
+FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT = f"""You are a helpful assistant tasked with extracting structured information based on specific criteria provided. Follow the guidelines below to ensure consistency and accuracy.
+### Task
+Always call the `{FUNCTION_CALLING_EXTRACTOR_NAME}` function with the correct parameters. Ensure that the information extraction is contextual and aligns with the provided criteria.
+### Memory
+Here is the chat history between the human and assistant, provided inside <histories></histories> XML tags:
+<histories>
+\x7bhistories\x7d
+</histories>
+### Instructions:
+Some additional information is provided below. Always adhere to these instructions as closely as possible:
+<instruction>
+\x7binstruction\x7d
+</instruction>
+Steps:
+1. Review the chat history provided within the <histories></histories> tags.
+2. Extract the relevant information based on the criteria given; output multiple values if there are multiple pieces of relevant information that match the criteria in the given text.
+3. Generate a well-formatted output using the defined functions and arguments.
+4. Use the `{FUNCTION_CALLING_EXTRACTOR_NAME}` function to create structured outputs with appropriate parameters.
+5. Do not include any XML tags in your output.
+### Example
+To illustrate, if the task involves extracting a user's name and their request, your function call should mirror the structure of the provided examples. Ensure your output follows a similar structure.
+### Final Output
+Produce well-formatted function calls in JSON without XML tags, as shown in the example.
+"""
+
+FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE = f"""extract structured information from the context inside <context></context> XML tags by calling the function {FUNCTION_CALLING_EXTRACTOR_NAME} with the correct parameters, following the structure inside <structure></structure> XML tags.
+<context>
+\x7bcontent\x7d
+</context>
+
+<structure>
+\x7bstructure\x7d
+</structure>
+"""
+
+FUNCTION_CALLING_EXTRACTOR_EXAMPLE = [{
+ 'user': {
+ 'query': 'What is the weather today in SF?',
+ 'function': {
+ 'name': FUNCTION_CALLING_EXTRACTOR_NAME,
+ 'parameters': {
+ 'type': 'object',
+ 'properties': {
+ 'location': {
+ 'type': 'string',
+ 'description': 'The location to get the weather information',
+ 'required': True
+ },
+ },
+ 'required': ['location']
+ }
+ }
+ },
+ 'assistant': {
+        'text': 'I always need to call the function with the correct parameters. In this case, I need to call the function with the location parameter.',
+ 'function_call' : {
+ 'name': FUNCTION_CALLING_EXTRACTOR_NAME,
+ 'parameters': {
+ 'location': 'San Francisco'
+ }
+ }
+ }
+}, {
+ 'user': {
+ 'query': 'I want to eat some apple pie.',
+ 'function': {
+ 'name': FUNCTION_CALLING_EXTRACTOR_NAME,
+ 'parameters': {
+ 'type': 'object',
+ 'properties': {
+ 'food': {
+ 'type': 'string',
+ 'description': 'The food to eat',
+ 'required': True
+ }
+ },
+ 'required': ['food']
+ }
+ }
+ },
+ 'assistant': {
+        'text': 'I always need to call the function with the correct parameters. In this case, I need to call the function with the food parameter.',
+ 'function_call' : {
+ 'name': FUNCTION_CALLING_EXTRACTOR_NAME,
+ 'parameters': {
+ 'food': 'apple pie'
+ }
+ }
+ }
+}]
+
+COMPLETION_GENERATE_JSON_PROMPT = """### Instructions:
+Some extra information is provided below; I should always follow these instructions as closely as I can.
+<instruction>
+{instruction}
+</instruction>
+
+### Extract parameter Workflow
+I need to extract the following information from the input text. The <structure> tag specifies the 'type', 'description' and 'required' of each piece of information to be extracted.
+<structure>
+{{ structure }}
+</structure>
+
+Step 1: Carefully read the input and understand the structure of the expected output.
+Step 2: Extract the relevant parameters from the provided text based on the name and description of each object.
+Step 3: Structure the extracted parameters into a JSON object as specified in <structure>.
+Step 4: Ensure that the JSON object is properly formatted and valid. The output should not contain any XML tags. Only the JSON object should be outputted.
+
+### Memory
+Here are the chat histories between human and assistant, inside <histories></histories> XML tags.
+<histories>
+{histories}
+</histories>
+
+### Structure
+Here is the structure of the expected output; I should always follow the output structure.
+{{γγγ
+ 'properties1': 'relevant text extracted from input',
+ 'properties2': 'relevant text extracted from input',
+}}γγγ
+
+### Input Text
+Inside <text></text> XML tags, there is a text from which I should extract parameters and convert them to a JSON object.
+<text>
+{text}
+</text>
+
+### Answer
+I should always output a valid JSON object. Output nothing other than the JSON object.
+```JSON
+"""
+
+CHAT_GENERATE_JSON_PROMPT = """You should always follow the instructions and output a valid JSON object.
+You can find the structure of the JSON object in the instructions.
+
+### Memory
+Here are the chat histories between human and assistant, inside <histories></histories> XML tags.
+<histories>
+{histories}
+</histories>
+
+### Instructions:
+Some extra information is provided below; you should always follow these instructions as closely as you can.
+<instructions>
+{{instructions}}
+</instructions>
+"""
+
+CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE = """### Structure
+Here is the structure of the JSON object; you should always follow it.
+<structure>
+{structure}
+</structure>
+
+### Text to be converted to JSON
+Inside <text></text> XML tags, there is a text that you should convert to a JSON object.
+<text>
+{text}
+</text>
+"""
+
+CHAT_EXAMPLE = [{
+ 'user': {
+ 'query': 'What is the weather today in SF?',
+ 'json': {
+ 'type': 'object',
+ 'properties': {
+ 'location': {
+ 'type': 'string',
+ 'description': 'The location to get the weather information',
+ 'required': True
+ }
+ },
+ 'required': ['location']
+ }
+ },
+ 'assistant': {
+ 'text': 'I need to output a valid JSON object.',
+ 'json': {
+ 'location': 'San Francisco'
+ }
+ }
+}, {
+ 'user': {
+ 'query': 'I want to eat some apple pie.',
+ 'json': {
+ 'type': 'object',
+ 'properties': {
+ 'food': {
+ 'type': 'string',
+ 'description': 'The food to eat',
+ 'required': True
+ }
+ },
+ 'required': ['food']
+ }
+ },
+ 'assistant': {
+ 'text': 'I need to output a valid JSON object.',
+ 'json': {
+            'food': 'apple pie'
+ }
+ }
+}]
\ No newline at end of file
diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py
index 9ec0df721c..76f3dec836 100644
--- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py
+++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py
@@ -1,16 +1,18 @@
+import json
import logging
from typing import Optional, Union, cast
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.memory.token_buffer_memory import TokenBufferMemory
+from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole
from core.model_runtime.entities.model_entities import ModelPropertyKey
-from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
+from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
@@ -25,6 +27,7 @@ from core.workflow.nodes.question_classifier.template_prompts import (
QUESTION_CLASSIFIER_USER_PROMPT_2,
QUESTION_CLASSIFIER_USER_PROMPT_3,
)
+from core.workflow.utils.variable_template_parser import VariableTemplateParser
from libs.json_in_md_parser import parse_and_check_json_markdown
from models.workflow import WorkflowNodeExecutionStatus
@@ -46,6 +49,9 @@ class QuestionClassifierNode(LLMNode):
model_instance, model_config = self._fetch_model_config(node_data.model)
# fetch memory
memory = self._fetch_memory(node_data.memory, variable_pool, model_instance)
+ # fetch instruction
+ instruction = self._format_instruction(node_data.instruction, variable_pool) if node_data.instruction else ''
+ node_data.instruction = instruction
# fetch prompt messages
prompt_messages, stop = self._fetch_prompt(
node_data=node_data,
@@ -62,13 +68,20 @@ class QuestionClassifierNode(LLMNode):
prompt_messages=prompt_messages,
stop=stop
)
- categories = [_class.name for _class in node_data.classes]
+ category_name = node_data.classes[0].name
+ category_id = node_data.classes[0].id
try:
result_text_json = parse_and_check_json_markdown(result_text, [])
- #result_text_json = json.loads(result_text.strip('```JSON\n'))
- categories_result = result_text_json.get('categories', [])
- if categories_result:
- categories = categories_result
+ # result_text_json = json.loads(result_text.strip('```JSON\n'))
+ if 'category_name' in result_text_json and 'category_id' in result_text_json:
+ category_id_result = result_text_json['category_id']
+ classes = node_data.classes
+ classes_map = {class_.id: class_.name for class_ in classes}
+ category_ids = [_class.id for _class in classes]
+ if category_id_result in category_ids:
+ category_name = classes_map[category_id_result]
+ category_id = category_id_result
+
except Exception:
logging.error(f"Failed to parse result text: {result_text}")
try:
@@ -81,17 +94,15 @@ class QuestionClassifierNode(LLMNode):
'usage': jsonable_encoder(usage),
}
outputs = {
- 'class_name': categories[0] if categories else ''
+ 'class_name': category_name
}
- classes = node_data.classes
- classes_map = {class_.name: class_.id for class_ in classes}
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=variables,
process_data=process_data,
outputs=outputs,
- edge_source_handle=classes_map.get(categories[0], None),
+ edge_source_handle=category_id,
metadata={
NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens,
NodeRunMetadataKey.TOTAL_PRICE: usage.total_price,
@@ -116,6 +127,12 @@ class QuestionClassifierNode(LLMNode):
node_data = node_data
node_data = cast(cls._node_data_cls, node_data)
variable_mapping = {'query': node_data.query_variable_selector}
+ variable_selectors = []
+ if node_data.instruction:
+ variable_template_parser = VariableTemplateParser(template=node_data.instruction)
+ variable_selectors.extend(variable_template_parser.extract_variable_selectors())
+ for variable_selector in variable_selectors:
+ variable_mapping[variable_selector.variable] = variable_selector.value_selector
return variable_mapping
@classmethod
@@ -183,12 +200,12 @@ class QuestionClassifierNode(LLMNode):
model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
if model_context_tokens:
- model_type_instance = model_config.provider_model_bundle.model_type_instance
- model_type_instance = cast(LargeLanguageModel, model_type_instance)
+ model_instance = ModelInstance(
+ provider_model_bundle=model_config.provider_model_bundle,
+ model=model_config.model
+ )
- curr_message_tokens = model_type_instance.get_num_tokens(
- model_config.model,
- model_config.credentials,
+ curr_message_tokens = model_instance.get_llm_num_tokens(
prompt_messages
)
@@ -210,8 +227,13 @@ class QuestionClassifierNode(LLMNode):
-> Union[list[ChatModelMessage], CompletionModelPromptTemplate]:
model_mode = ModelMode.value_of(node_data.model.mode)
classes = node_data.classes
- class_names = [class_.name for class_ in classes]
- class_names_str = ','.join(f'"{name}"' for name in class_names)
+ categories = []
+ for class_ in classes:
+ category = {
+ 'category_id': class_.id,
+ 'category_name': class_.name
+ }
+ categories.append(category)
instruction = node_data.instruction if node_data.instruction else ''
input_text = query
memory_str = ''
@@ -248,7 +270,7 @@ class QuestionClassifierNode(LLMNode):
user_prompt_message_3 = ChatModelMessage(
role=PromptMessageRole.USER,
text=QUESTION_CLASSIFIER_USER_PROMPT_3.format(input_text=input_text,
- categories=class_names_str,
+ categories=json.dumps(categories, ensure_ascii=False),
classification_instructions=instruction)
)
prompt_messages.append(user_prompt_message_3)
@@ -257,9 +279,31 @@ class QuestionClassifierNode(LLMNode):
return CompletionModelPromptTemplate(
text=QUESTION_CLASSIFIER_COMPLETION_PROMPT.format(histories=memory_str,
input_text=input_text,
- categories=class_names_str,
- classification_instructions=instruction)
+                                                                  categories=json.dumps(categories, ensure_ascii=False),
+                                                                  classification_instructions=instruction)
)
else:
raise ValueError(f"Model mode {model_mode} not support.")
+
+ def _format_instruction(self, instruction: str, variable_pool: VariablePool) -> str:
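+        """
+        Render the variable templates inside the instruction with values from the variable pool.
+        """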
+ inputs = {}
+
+ variable_selectors = []
+ variable_template_parser = VariableTemplateParser(template=instruction)
+ variable_selectors.extend(variable_template_parser.extract_variable_selectors())
+ for variable_selector in variable_selectors:
+ variable_value = variable_pool.get_variable_value(variable_selector.value_selector)
+ if variable_value is None:
+ raise ValueError(f'Variable {variable_selector.variable} not found')
+
+ inputs[variable_selector.variable] = variable_value
+
+ prompt_template = PromptTemplateParser(template=instruction, with_variable_tmpl=True)
+ prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
+
+ instruction = prompt_template.format(
+ prompt_inputs
+ )
+ return instruction
diff --git a/api/core/workflow/nodes/question_classifier/template_prompts.py b/api/core/workflow/nodes/question_classifier/template_prompts.py
index 5bef0250e3..23bd05a809 100644
--- a/api/core/workflow/nodes/question_classifier/template_prompts.py
+++ b/api/core/workflow/nodes/question_classifier/template_prompts.py
@@ -6,7 +6,7 @@ QUESTION_CLASSIFIER_SYSTEM_PROMPT = """
### Task
Your task is to assign one categories ONLY to the input text and only one category may be assigned returned in the output.Additionally, you need to extract the key words from the text that are related to the classification.
### Format
- The input text is in the variable text_field.Categories are specified as a comma-separated list in the variable categories or left empty for automatic determination.Classification instructions may be included to improve the classification accuracy.
+ The input text is in the variable text_field. Categories are specified as a category list with two fields, category_id and category_name, in the variable categories. Classification instructions may be included to improve the classification accuracy.
### Constraint
DO NOT include anything other than the JSON array in your response.
### Memory
@@ -18,33 +18,35 @@ QUESTION_CLASSIFIER_SYSTEM_PROMPT = """
QUESTION_CLASSIFIER_USER_PROMPT_1 = """
{ "input_text": ["I recently had a great experience with your company. The service was prompt and the staff was very friendly."],
- "categories": ["Customer Service", "Satisfaction", "Sales", "Product"],
+ "categories": [{"category_id":"f5660049-284f-41a7-b301-fd24176a711c","category_name":"Customer Service"},{"category_id":"8d007d06-f2c9-4be5-8ff6-cd4381c13c60","category_name":"Satisfaction"},{"category_id":"5fbbbb18-9843-466d-9b8e-b9bfbb9482c8","category_name":"Sales"},{"category_id":"23623c75-7184-4a2e-8226-466c2e4631e4","category_name":"Product"}],
"classification_instructions": ["classify the text based on the feedback provided by customer"]}
"""
QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1 = """
```json
{"keywords": ["recently", "great experience", "company", "service", "prompt", "staff", "friendly"],
- "categories": ["Customer Service"]}
+ "category_id": "f5660049-284f-41a7-b301-fd24176a711c",
+ "category_name": "Customer Service"}
```
"""
QUESTION_CLASSIFIER_USER_PROMPT_2 = """
{"input_text": ["bad service, slow to bring the food"],
- "categories": ["Food Quality", "Experience", "Price" ],
+ "categories": [{"category_id":"80fb86a0-4454-4bf5-924c-f253fdd83c02","category_name":"Food Quality"},{"category_id":"f6ff5bc3-aca0-4e4a-8627-e760d0aca78f","category_name":"Experience"},{"category_id":"cc771f63-74e7-4c61-882e-3eda9d8ba5d7","category_name":"Price"}],
"classification_instructions": []}
"""
QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2 = """
```json
{"keywords": ["bad service", "slow", "food", "tip", "terrible", "waitresses"],
- "categories": ["Experience"]}
+ "category_id": "f6ff5bc3-aca0-4e4a-8627-e760d0aca78f",
+ "category_name": "Experience"}
```
"""
QUESTION_CLASSIFIER_USER_PROMPT_3 = """
'{{"input_text": ["{input_text}"],',
- '"categories": ["{categories}" ], ',
+ '"categories": {categories}, ',
'"classification_instructions": ["{classification_instructions}"]}}'
"""
@@ -54,16 +56,16 @@ You are a text classification engine that analyzes text data and assigns categor
### Task
Your task is to assign one categories ONLY to the input text and only one category may be assigned returned in the output. Additionally, you need to extract the key words from the text that are related to the classification.
### Format
-The input text is in the variable text_field. Categories are specified as a comma-separated list in the variable categories or left empty for automatic determination. Classification instructions may be included to improve the classification accuracy.
+The input text is in the variable text_field. Categories are specified as a category list with two fields, category_id and category_name, in the variable categories. Classification instructions may be included to improve the classification accuracy.
### Constraint
DO NOT include anything other than the JSON array in your response.
### Example
Here is the chat example between human and assistant, inside XML tags.
-User:{{"input_text": ["I recently had a great experience with your company. The service was prompt and the staff was very friendly."],"categories": ["Customer Service, Satisfaction, Sales, Product"], "classification_instructions": ["classify the text based on the feedback provided by customer"]}}
-Assistant:{{"keywords": ["recently", "great experience", "company", "service", "prompt", "staff", "friendly"],"categories": ["Customer Service"]}}
-User:{{"input_text": ["bad service, slow to bring the food"],"categories": ["Food Quality, Experience, Price" ], "classification_instructions": []}}
-Assistant:{{"keywords": ["recently", "great experience", "company", "service", "prompt", "staff", "friendly"],"categories": ["Customer Service"]}}{{"keywords": ["bad service", "slow", "food", "tip", "terrible", "waitresses"],"categories": ["Experience""]}}
+User:{{"input_text": ["I recently had a great experience with your company. The service was prompt and the staff was very friendly."], "categories": [{{"category_id":"f5660049-284f-41a7-b301-fd24176a711c","category_name":"Customer Service"}},{{"category_id":"8d007d06-f2c9-4be5-8ff6-cd4381c13c60","category_name":"Satisfaction"}},{{"category_id":"5fbbbb18-9843-466d-9b8e-b9bfbb9482c8","category_name":"Sales"}},{{"category_id":"23623c75-7184-4a2e-8226-466c2e4631e4","category_name":"Product"}}], "classification_instructions": ["classify the text based on the feedback provided by customer"]}}
+Assistant:{{"keywords": ["recently", "great experience", "company", "service", "prompt", "staff", "friendly"],"category_id": "f5660049-284f-41a7-b301-fd24176a711c","category_name": "Customer Service"}}
+User:{{"input_text": ["bad service, slow to bring the food"], "categories": [{{"category_id":"80fb86a0-4454-4bf5-924c-f253fdd83c02","category_name":"Food Quality"}},{{"category_id":"f6ff5bc3-aca0-4e4a-8627-e760d0aca78f","category_name":"Experience"}},{{"category_id":"cc771f63-74e7-4c61-882e-3eda9d8ba5d7","category_name":"Price"}}], "classification_instructions": []}}
+Assistant:{{"keywords": ["bad service", "slow", "food", "tip", "terrible", "waitresses"],"category_id": "f6ff5bc3-aca0-4e4a-8627-e760d0aca78f","category_name": "Experience"}}
### Memory
Here is the chat histories between human and assistant, inside XML tags.
@@ -71,6 +73,6 @@ Here is the chat histories between human and assistant, inside
### User Input
-{{"input_text" : ["{input_text}"], "categories" : ["{categories}"],"classification_instruction" : ["{classification_instructions}"]}}
+{{"input_text" : ["{input_text}"], "categories" : {categories},"classification_instruction" : ["{classification_instructions}"]}}
### Assistant Output
-"""
\ No newline at end of file
+"""
diff --git a/api/core/workflow/nodes/template_transform/template_transform_node.py b/api/core/workflow/nodes/template_transform/template_transform_node.py
index 9e5cc0c889..2c4a2257f5 100644
--- a/api/core/workflow/nodes/template_transform/template_transform_node.py
+++ b/api/core/workflow/nodes/template_transform/template_transform_node.py
@@ -1,7 +1,7 @@
import os
from typing import Optional, cast
-from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor
+from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor, CodeLanguage
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
@@ -53,7 +53,7 @@ class TemplateTransformNode(BaseNode):
# Run code
try:
result = CodeExecutor.execute_workflow_code_template(
- language='jinja2',
+ language=CodeLanguage.JINJA2,
code=node_data.template,
inputs=variables
)
diff --git a/api/core/workflow/nodes/tool/entities.py b/api/core/workflow/nodes/tool/entities.py
index 97fbe8a999..98b28ac4f1 100644
--- a/api/core/workflow/nodes/tool/entities.py
+++ b/api/core/workflow/nodes/tool/entities.py
@@ -7,7 +7,7 @@ from core.workflow.entities.base_node_data_entities import BaseNodeData
class ToolEntity(BaseModel):
provider_id: str
- provider_type: Literal['builtin', 'api']
+ provider_type: Literal['builtin', 'api', 'workflow']
provider_name: str # redundancy
tool_name: str
tool_label: str # redundancy
diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py
index d183dbe17b..2a472fc8d2 100644
--- a/api/core/workflow/nodes/tool/tool_node.py
+++ b/api/core/workflow/nodes/tool/tool_node.py
@@ -1,13 +1,14 @@
from os import path
-from typing import cast
+from typing import Optional, cast
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.file.file_obj import FileTransferMethod, FileType, FileVar
-from core.tools.entities.tool_entities import ToolInvokeMessage
+from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
+from core.tools.tool.tool import Tool
from core.tools.tool_engine import ToolEngine
from core.tools.tool_manager import ToolManager
from core.tools.utils.message_transformer import ToolFileMessageTransformer
-from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
+from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.tool.entities import ToolNodeData
@@ -35,21 +36,23 @@ class ToolNode(BaseNode):
'provider_id': node_data.provider_id
}
- # get parameters
- parameters = self._generate_parameters(variable_pool, node_data)
# get tool runtime
try:
- self.app_id
- tool_runtime = ToolManager.get_workflow_tool_runtime(self.tenant_id, self.app_id, self.node_id, node_data)
+ tool_runtime = ToolManager.get_workflow_tool_runtime(
+ self.tenant_id, self.app_id, self.node_id, node_data, self.invoke_from
+ )
except Exception as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
- inputs=parameters,
+ inputs={},
metadata={
NodeRunMetadataKey.TOOL_INFO: tool_info
},
error=f'Failed to get tool runtime: {str(e)}'
)
+
+ # get parameters
+ parameters = self._generate_parameters(variable_pool, node_data, tool_runtime)
try:
messages = ToolEngine.workflow_invoke(
@@ -57,7 +60,8 @@ class ToolNode(BaseNode):
tool_parameters=parameters,
user_id=self.user_id,
workflow_id=self.workflow_id,
- workflow_tool_callback=DifyWorkflowCallbackHandler()
+ workflow_tool_callback=DifyWorkflowCallbackHandler(),
+ workflow_call_depth=self.workflow_call_depth,
)
except Exception as e:
return NodeRunResult(
@@ -84,19 +88,32 @@ class ToolNode(BaseNode):
inputs=parameters
)
- def _generate_parameters(self, variable_pool: VariablePool, node_data: ToolNodeData) -> dict:
+ def _generate_parameters(self, variable_pool: VariablePool, node_data: ToolNodeData, tool_runtime: Tool) -> dict:
"""
Generate parameters
"""
+ tool_parameters = tool_runtime.get_all_runtime_parameters()
+
+ def fetch_parameter(name: str) -> Optional[ToolParameter]:
+ return next((parameter for parameter in tool_parameters if parameter.name == name), None)
+
result = {}
for parameter_name in node_data.tool_parameters:
- input = node_data.tool_parameters[parameter_name]
- if input.type == 'mixed':
- result[parameter_name] = self._format_variable_template(input.value, variable_pool)
- elif input.type == 'variable':
- result[parameter_name] = variable_pool.get_variable_value(input.value)
- elif input.type == 'constant':
- result[parameter_name] = input.value
+ parameter = fetch_parameter(parameter_name)
+ if not parameter:
+ continue
+ if parameter.type == ToolParameter.ToolParameterType.FILE:
+ result[parameter_name] = [
+ v.to_dict() for v in self._fetch_files(variable_pool)
+ ]
+ else:
+ input = node_data.tool_parameters[parameter_name]
+ if input.type == 'mixed':
+ result[parameter_name] = self._format_variable_template(input.value, variable_pool)
+ elif input.type == 'variable':
+ result[parameter_name] = variable_pool.get_variable_value(input.value)
+ elif input.type == 'constant':
+ result[parameter_name] = input.value
return result
@@ -110,6 +127,13 @@ class ToolNode(BaseNode):
inputs[selector.variable] = variable_pool.get_variable_value(selector.value_selector)
return template_parser.format(inputs)
+
+ def _fetch_files(self, variable_pool: VariablePool) -> list[FileVar]:
+ files = variable_pool.get_variable_value(['sys', SystemVariable.FILES.value])
+ if not files:
+ return []
+
+ return files
def _convert_tool_messages(self, messages: list[ToolInvokeMessage]) -> tuple[str, list[FileVar]]:
"""
@@ -174,11 +198,12 @@ class ToolNode(BaseNode):
"""
Extract tool response text
"""
- return ''.join([
- f'{message.message}\n' if message.type == ToolInvokeMessage.MessageType.TEXT else
- f'Link: {message.message}\n' if message.type == ToolInvokeMessage.MessageType.LINK else ''
+ return '\n'.join([
+ f'{message.message}' if message.type == ToolInvokeMessage.MessageType.TEXT else
+ f'Link: {message.message}' if message.type == ToolInvokeMessage.MessageType.LINK else ''
for message in tool_response
])
+
@classmethod
def _extract_variable_selector_to_variable_mapping(cls, node_data: ToolNodeData) -> dict[str, list[str]]:
diff --git a/api/core/workflow/nodes/variable_aggregator/__init__.py b/api/core/workflow/nodes/variable_aggregator/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/workflow/nodes/variable_aggregator/entities.py b/api/core/workflow/nodes/variable_aggregator/entities.py
new file mode 100644
index 0000000000..d38e934451
--- /dev/null
+++ b/api/core/workflow/nodes/variable_aggregator/entities.py
@@ -0,0 +1,33 @@
+
+
+from typing import Literal, Optional
+
+from pydantic import BaseModel
+
+from core.workflow.entities.base_node_data_entities import BaseNodeData
+
+
+class AdvancedSettings(BaseModel):
+ """
+ Advanced setting.
+ """
+ group_enabled: bool
+
+ class Group(BaseModel):
+ """
+ Group.
+ """
+ output_type: Literal['string', 'number', 'array', 'object']
+ variables: list[list[str]]
+ group_name: str
+
+ groups: list[Group]
+
+class VariableAssignerNodeData(BaseNodeData):
+ """
+    Variable Assigner Node Data.
+ """
+ type: str = 'variable-assigner'
+ output_type: str
+ variables: list[list[str]]
+ advanced_settings: Optional[AdvancedSettings]
\ No newline at end of file
diff --git a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py
new file mode 100644
index 0000000000..63ce790625
--- /dev/null
+++ b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py
@@ -0,0 +1,54 @@
+from typing import cast
+
+from core.workflow.entities.base_node_data_entities import BaseNodeData
+from core.workflow.entities.node_entities import NodeRunResult, NodeType
+from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.nodes.base_node import BaseNode
+from core.workflow.nodes.variable_aggregator.entities import VariableAssignerNodeData
+from models.workflow import WorkflowNodeExecutionStatus
+
+
+class VariableAggregatorNode(BaseNode):
+ _node_data_cls = VariableAssignerNodeData
+ _node_type = NodeType.VARIABLE_AGGREGATOR
+
+ def _run(self, variable_pool: VariablePool) -> NodeRunResult:
+ node_data = cast(VariableAssignerNodeData, self.node_data)
+ # Get variables
+ outputs = {}
+ inputs = {}
+
+ if not node_data.advanced_settings or not node_data.advanced_settings.group_enabled:
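+            # ungrouped mode: the first variable that resolves to a non-null value becomes the single output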
+ for variable in node_data.variables:
+ value = variable_pool.get_variable_value(variable)
+
+ if value is not None:
+ outputs = {
+ "output": value
+ }
+
+ inputs = {
+ '.'.join(variable[1:]): value
+ }
+ break
+ else:
+ for group in node_data.advanced_settings.groups:
+ for variable in group.variables:
+ value = variable_pool.get_variable_value(variable)
+
+ if value is not None:
+ outputs[group.group_name] = {
+ 'output': value
+ }
+ inputs['.'.join(variable[1:])] = value
+ break
+
+ return NodeRunResult(
+ status=WorkflowNodeExecutionStatus.SUCCEEDED,
+ outputs=outputs,
+ inputs=inputs
+ )
+
+ @classmethod
+ def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
+ return {}
diff --git a/api/core/workflow/nodes/variable_assigner/entities.py b/api/core/workflow/nodes/variable_assigner/entities.py
deleted file mode 100644
index 035618bd66..0000000000
--- a/api/core/workflow/nodes/variable_assigner/entities.py
+++ /dev/null
@@ -1,12 +0,0 @@
-
-
-from core.workflow.entities.base_node_data_entities import BaseNodeData
-
-
-class VariableAssignerNodeData(BaseNodeData):
- """
- Knowledge retrieval Node Data.
- """
- type: str = 'variable-assigner'
- output_type: str
- variables: list[list[str]]
diff --git a/api/core/workflow/nodes/variable_assigner/variable_assigner_node.py b/api/core/workflow/nodes/variable_assigner/variable_assigner_node.py
deleted file mode 100644
index d0a1c9789c..0000000000
--- a/api/core/workflow/nodes/variable_assigner/variable_assigner_node.py
+++ /dev/null
@@ -1,41 +0,0 @@
-from typing import cast
-
-from core.workflow.entities.base_node_data_entities import BaseNodeData
-from core.workflow.entities.node_entities import NodeRunResult, NodeType
-from core.workflow.entities.variable_pool import VariablePool
-from core.workflow.nodes.base_node import BaseNode
-from core.workflow.nodes.variable_assigner.entities import VariableAssignerNodeData
-from models.workflow import WorkflowNodeExecutionStatus
-
-
-class VariableAssignerNode(BaseNode):
- _node_data_cls = VariableAssignerNodeData
- _node_type = NodeType.VARIABLE_ASSIGNER
-
- def _run(self, variable_pool: VariablePool) -> NodeRunResult:
- node_data: VariableAssignerNodeData = cast(self._node_data_cls, self.node_data)
- # Get variables
- outputs = {}
- inputs = {}
- for variable in node_data.variables:
- value = variable_pool.get_variable_value(variable)
-
- if value is not None:
- outputs = {
- "output": value
- }
-
- inputs = {
- '.'.join(variable[1:]): value
- }
- break
-
- return NodeRunResult(
- status=WorkflowNodeExecutionStatus.SUCCEEDED,
- outputs=outputs,
- inputs=inputs
- )
-
- @classmethod
- def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
- return {}
diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py
index 9390ffa2a4..afd19abfde 100644
--- a/api/core/workflow/workflow_engine_manager.py
+++ b/api/core/workflow/workflow_engine_manager.py
@@ -2,8 +2,11 @@ import logging
import time
from typing import Optional, cast
+from flask import current_app
+
from core.app.app_config.entities import FileExtraConfig
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException
+from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.file_obj import FileTransferMethod, FileType, FileVar
from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
@@ -11,19 +14,22 @@ from core.workflow.entities.variable_pool import VariablePool, VariableValue
from core.workflow.entities.workflow_entities import WorkflowNodeAndResult, WorkflowRunState
from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.nodes.answer.answer_node import AnswerNode
-from core.workflow.nodes.base_node import BaseNode, UserFrom
+from core.workflow.nodes.base_node import BaseIterationNode, BaseNode, UserFrom
from core.workflow.nodes.code.code_node import CodeNode
from core.workflow.nodes.end.end_node import EndNode
from core.workflow.nodes.http_request.http_request_node import HttpRequestNode
from core.workflow.nodes.if_else.if_else_node import IfElseNode
+from core.workflow.nodes.iteration.entities import IterationState
+from core.workflow.nodes.iteration.iteration_node import IterationNode
from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
from core.workflow.nodes.llm.entities import LLMNodeData
from core.workflow.nodes.llm.llm_node import LLMNode
+from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
from core.workflow.nodes.start.start_node import StartNode
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
from core.workflow.nodes.tool.tool_node import ToolNode
-from core.workflow.nodes.variable_assigner.variable_assigner_node import VariableAssignerNode
+from core.workflow.nodes.variable_aggregator.variable_aggregator_node import VariableAggregatorNode
from extensions.ext_database import db
from models.workflow import (
Workflow,
@@ -42,7 +48,10 @@ node_classes = {
NodeType.QUESTION_CLASSIFIER: QuestionClassifierNode,
NodeType.HTTP_REQUEST: HttpRequestNode,
NodeType.TOOL: ToolNode,
- NodeType.VARIABLE_ASSIGNER: VariableAssignerNode,
+ NodeType.VARIABLE_AGGREGATOR: VariableAggregatorNode,
+ NodeType.VARIABLE_ASSIGNER: VariableAggregatorNode,
+ NodeType.ITERATION: IterationNode,
+ NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode
}
logger = logging.getLogger(__name__)
@@ -81,18 +90,20 @@ class WorkflowEngineManager:
def run_workflow(self, workflow: Workflow,
user_id: str,
user_from: UserFrom,
+ invoke_from: InvokeFrom,
user_inputs: dict,
system_inputs: Optional[dict] = None,
- callbacks: list[BaseWorkflowCallback] = None) -> None:
+ callbacks: list[BaseWorkflowCallback] = None,
+ call_depth: Optional[int] = 0,
+ variable_pool: Optional[VariablePool] = None) -> None:
"""
- Run workflow
:param workflow: Workflow instance
:param user_id: user id
:param user_from: user from
:param user_inputs: user variables inputs
:param system_inputs: system inputs, like: query, files
:param callbacks: workflow callbacks
- :return:
+ :param call_depth: call depth
"""
# fetch workflow graph
graph = workflow.graph_dict
@@ -107,54 +118,185 @@ class WorkflowEngineManager:
if not isinstance(graph.get('edges'), list):
raise ValueError('edges in workflow graph must be a list')
+
+ # init variable pool
+ if not variable_pool:
+ variable_pool = VariablePool(
+ system_variables=system_inputs,
+ user_inputs=user_inputs
+ )
+
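+        # guard against unbounded recursion when a workflow invokes another workflow (e.g. as a workflow tool)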
+ workflow_call_max_depth = current_app.config.get("WORKFLOW_CALL_MAX_DEPTH")
+ if call_depth > workflow_call_max_depth:
+ raise ValueError('Max workflow call depth {} reached.'.format(workflow_call_max_depth))
+
+ # init workflow run state
+ workflow_run_state = WorkflowRunState(
+ workflow=workflow,
+ start_at=time.perf_counter(),
+ variable_pool=variable_pool,
+ user_id=user_id,
+ user_from=user_from,
+ invoke_from=invoke_from,
+ workflow_call_depth=call_depth
+ )
# init workflow run
if callbacks:
for callback in callbacks:
callback.on_workflow_run_started()
- # init workflow run state
- workflow_run_state = WorkflowRunState(
+ # run workflow
+ self._run_workflow(
workflow=workflow,
- start_at=time.perf_counter(),
- variable_pool=VariablePool(
- system_variables=system_inputs,
- user_inputs=user_inputs
- ),
- user_id=user_id,
- user_from=user_from
+ workflow_run_state=workflow_run_state,
+ callbacks=callbacks,
)
+ def _run_workflow(self, workflow: Workflow,
+ workflow_run_state: WorkflowRunState,
+ callbacks: list[BaseWorkflowCallback] = None,
+ start_at: Optional[str] = None,
+ end_at: Optional[str] = None) -> None:
+ """
+ Run workflow
+ :param workflow: Workflow instance
+        :param workflow_run_state: workflow run state
+        :param callbacks: workflow callbacks
+ :param start_at: force specific start node
+ :param end_at: force specific end node
+ :return:
+ """
+ graph = workflow.graph_dict
+
try:
- predecessor_node = None
+            predecessor_node: Optional[BaseNode] = None
+            current_iteration_node: Optional[BaseIterationNode] = None
has_entry_node = False
+ max_execution_steps = current_app.config.get("WORKFLOW_MAX_EXECUTION_STEPS")
+ max_execution_time = current_app.config.get("WORKFLOW_MAX_EXECUTION_TIME")
while True:
# get next node, multiple target nodes in the future
- next_node = self._get_next_node(
+ next_node = self._get_next_overall_node(
workflow_run_state=workflow_run_state,
graph=graph,
predecessor_node=predecessor_node,
- callbacks=callbacks
+ callbacks=callbacks,
+ start_at=start_at,
+ end_at=end_at
)
+ if not next_node:
+ # reached loop/iteration end or overall end
+ if current_iteration_node and workflow_run_state.current_iteration_state:
+ # reached loop/iteration end
+ # get next iteration
+ next_iteration = current_iteration_node.get_next_iteration(
+ variable_pool=workflow_run_state.variable_pool,
+ state=workflow_run_state.current_iteration_state
+ )
+ self._workflow_iteration_next(
+ graph=graph,
+ current_iteration_node=current_iteration_node,
+ workflow_run_state=workflow_run_state,
+ callbacks=callbacks
+ )
+ if isinstance(next_iteration, NodeRunResult):
+ if next_iteration.outputs:
+ for variable_key, variable_value in next_iteration.outputs.items():
+ # append variables to variable pool recursively
+ self._append_variables_recursively(
+ variable_pool=workflow_run_state.variable_pool,
+ node_id=current_iteration_node.node_id,
+ variable_key_list=[variable_key],
+ variable_value=variable_value
+ )
+ self._workflow_iteration_completed(
+ current_iteration_node=current_iteration_node,
+ workflow_run_state=workflow_run_state,
+ callbacks=callbacks
+ )
+ # iteration has ended
+ next_node = self._get_next_overall_node(
+ workflow_run_state=workflow_run_state,
+ graph=graph,
+ predecessor_node=current_iteration_node,
+ callbacks=callbacks,
+ start_at=start_at,
+ end_at=end_at
+ )
+ current_iteration_node = None
+ workflow_run_state.current_iteration_state = None
+ # continue overall process
+ elif isinstance(next_iteration, str):
+ # move to next iteration
+ next_node_id = next_iteration
+ # get next id
+ next_node = self._get_node(workflow_run_state, graph, next_node_id, callbacks)
+
if not next_node:
break
# check is already ran
- if next_node.node_id in [node_and_result.node.node_id
- for node_and_result in workflow_run_state.workflow_nodes_and_results]:
+ if self._check_node_has_ran(workflow_run_state, next_node.node_id):
predecessor_node = next_node
continue
has_entry_node = True
- # max steps 30 reached
- if len(workflow_run_state.workflow_nodes_and_results) > 30:
- raise ValueError('Max steps 30 reached.')
+ # max steps reached
+ if workflow_run_state.workflow_node_steps > max_execution_steps:
+ raise ValueError('Max steps {} reached.'.format(max_execution_steps))
- # or max execution time 10min reached
- if self._is_timed_out(start_at=workflow_run_state.start_at, max_execution_time=600):
- raise ValueError('Max execution time 10min reached.')
+ # or max execution time reached
+ if self._is_timed_out(start_at=workflow_run_state.start_at, max_execution_time=max_execution_time):
+ raise ValueError('Max execution time {}s reached.'.format(max_execution_time))
+
+ # handle iteration nodes
+ if isinstance(next_node, BaseIterationNode):
+ current_iteration_node = next_node
+ workflow_run_state.current_iteration_state = next_node.run(
+ variable_pool=workflow_run_state.variable_pool
+ )
+ self._workflow_iteration_started(
+ graph=graph,
+ current_iteration_node=current_iteration_node,
+ workflow_run_state=workflow_run_state,
+ predecessor_node_id=predecessor_node.node_id if predecessor_node else None,
+ callbacks=callbacks
+ )
+ predecessor_node = next_node
+ # move to start node of iteration
+ next_node_id = next_node.get_next_iteration(
+ variable_pool=workflow_run_state.variable_pool,
+ state=workflow_run_state.current_iteration_state
+ )
+ self._workflow_iteration_next(
+ graph=graph,
+ current_iteration_node=current_iteration_node,
+ workflow_run_state=workflow_run_state,
+ callbacks=callbacks
+ )
+ if isinstance(next_node_id, NodeRunResult):
+ # iteration has ended
+ current_iteration_node.set_output(
+ variable_pool=workflow_run_state.variable_pool,
+ state=workflow_run_state.current_iteration_state
+ )
+ self._workflow_iteration_completed(
+ current_iteration_node=current_iteration_node,
+ workflow_run_state=workflow_run_state,
+ callbacks=callbacks
+ )
+ current_iteration_node = None
+ workflow_run_state.current_iteration_state = None
+ continue
+ else:
+ next_node = self._get_node(workflow_run_state, graph, next_node_id, callbacks)
# run workflow, run multiple target nodes in the future
self._run_workflow_node(
@@ -231,7 +373,9 @@ class WorkflowEngineManager:
workflow_id=workflow.id,
user_id=user_id,
user_from=UserFrom.ACCOUNT,
- config=node_config
+ invoke_from=InvokeFrom.DEBUGGER,
+ config=node_config,
+ workflow_call_depth=0
)
try:
@@ -247,49 +391,14 @@ class WorkflowEngineManager:
except NotImplementedError:
variable_mapping = {}
- for variable_key, variable_selector in variable_mapping.items():
- if variable_key not in user_inputs:
- raise ValueError(f'Variable key {variable_key} not found in user inputs.')
-
- # fetch variable node id from variable selector
- variable_node_id = variable_selector[0]
- variable_key_list = variable_selector[1:]
-
- # get value
- value = user_inputs.get(variable_key)
-
- # temp fix for image type
- if node_type == NodeType.LLM:
- new_value = []
- if isinstance(value, list):
- node_data = node_instance.node_data
- node_data = cast(LLMNodeData, node_data)
-
- detail = node_data.vision.configs.detail if node_data.vision.configs else None
-
- for item in value:
- if isinstance(item, dict) and 'type' in item and item['type'] == 'image':
- transfer_method = FileTransferMethod.value_of(item.get('transfer_method'))
- file = FileVar(
- tenant_id=workflow.tenant_id,
- type=FileType.IMAGE,
- transfer_method=transfer_method,
- url=item.get('url') if transfer_method == FileTransferMethod.REMOTE_URL else None,
- related_id=item.get(
- 'upload_file_id') if transfer_method == FileTransferMethod.LOCAL_FILE else None,
- extra_config=FileExtraConfig(image_config={'detail': detail} if detail else None),
- )
- new_value.append(file)
-
- if new_value:
- value = new_value
-
- # append variable and value to variable pool
- variable_pool.append_variable(
- node_id=variable_node_id,
- variable_key_list=variable_key_list,
- value=value
- )
+ self._mapping_user_inputs_to_variable_pool(
+ variable_mapping=variable_mapping,
+ user_inputs=user_inputs,
+ variable_pool=variable_pool,
+ tenant_id=workflow.tenant_id,
+ node_instance=node_instance
+ )
+
# run node
node_run_result = node_instance.run(
variable_pool=variable_pool
@@ -307,6 +416,126 @@ class WorkflowEngineManager:
return node_instance, node_run_result
+ def single_step_run_iteration_workflow_node(self, workflow: Workflow,
+ node_id: str,
+ user_id: str,
+ user_inputs: dict,
+ callbacks: list[BaseWorkflowCallback] = None,
+ ) -> None:
+ """
+ Single iteration run workflow node
+ """
+ # fetch node info from workflow graph
+ graph = workflow.graph_dict
+ if not graph:
+ raise ValueError('workflow graph not found')
+
+ nodes = graph.get('nodes')
+ if not nodes:
+ raise ValueError('nodes not found in workflow graph')
+
+ for node in nodes:
+ if node.get('id') == node_id:
+ if node.get('data', {}).get('type') in [
+ NodeType.ITERATION.value,
+ NodeType.LOOP.value,
+ ]:
+ node_config = node
+ else:
+ raise ValueError('node id is not an iteration node')
+
+ # init variable pool
+ variable_pool = VariablePool(
+ system_variables={},
+ user_inputs={}
+ )
+
+ # variable selector to variable mapping
+ iteration_nested_nodes = [
+ node for node in nodes
+ if node.get('data', {}).get('iteration_id') == node_id or node.get('id') == node_id
+ ]
+ iteration_nested_node_ids = [node.get('id') for node in iteration_nested_nodes]
+
+ if not iteration_nested_nodes:
+ raise ValueError('iteration has no nested nodes')
+
+ # init workflow run
+ if callbacks:
+ for callback in callbacks:
+ callback.on_workflow_run_started()
+
+ for node_config in iteration_nested_nodes:
+ # mapping user inputs to variable pool
+ node_cls = node_classes.get(NodeType.value_of(node_config.get('data', {}).get('type')))
+ try:
+ variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(node_config)
+ except NotImplementedError:
+ variable_mapping = {}
+
+ # remove iteration variables
+ variable_mapping = {
+ f'{node_config.get("id")}.{key}': value for key, value in variable_mapping.items()
+ if value[0] != node_id
+ }
+
+ # remove variable out from iteration
+ variable_mapping = {
+ key: value for key, value in variable_mapping.items()
+ if value[0] not in iteration_nested_node_ids
+ }
+
+ # append variables to variable pool
+ node_instance = node_cls(
+ tenant_id=workflow.tenant_id,
+ app_id=workflow.app_id,
+ workflow_id=workflow.id,
+ user_id=user_id,
+ user_from=UserFrom.ACCOUNT,
+ invoke_from=InvokeFrom.DEBUGGER,
+ config=node_config,
+ callbacks=callbacks,
+ workflow_call_depth=0
+ )
+
+ self._mapping_user_inputs_to_variable_pool(
+ variable_mapping=variable_mapping,
+ user_inputs=user_inputs,
+ variable_pool=variable_pool,
+ tenant_id=workflow.tenant_id,
+ node_instance=node_instance
+ )
+
+ # fetch end node of iteration
+ end_node_id = None
+ for edge in graph.get('edges'):
+ if edge.get('source') == node_id:
+ end_node_id = edge.get('target')
+ break
+
+ if not end_node_id:
+ raise ValueError('end node of iteration not found')
+
+ # init workflow run state
+ workflow_run_state = WorkflowRunState(
+ workflow=workflow,
+ start_at=time.perf_counter(),
+ variable_pool=variable_pool,
+ user_id=user_id,
+ user_from=UserFrom.ACCOUNT,
+ invoke_from=InvokeFrom.DEBUGGER,
+ workflow_call_depth=0
+ )
+
+ # run workflow
+ self._run_workflow(
+ workflow=workflow,
+ workflow_run_state=workflow_run_state,
+ callbacks=callbacks,
+ start_at=node_id,
+ end_at=end_node_id
+ )
+
def _workflow_run_success(self, callbacks: list[BaseWorkflowCallback] = None) -> None:
"""
Workflow run success
@@ -332,10 +561,96 @@ class WorkflowEngineManager:
error=error
)
- def _get_next_node(self, workflow_run_state: WorkflowRunState,
+ def _workflow_iteration_started(self, graph: dict,
+ current_iteration_node: BaseIterationNode,
+ workflow_run_state: WorkflowRunState,
+ predecessor_node_id: Optional[str] = None,
+ callbacks: list[BaseWorkflowCallback] = None) -> None:
+ """
+ Workflow iteration started
+ :param current_iteration_node: current iteration node
+ :param workflow_run_state: workflow run state
+ :param callbacks: workflow callbacks
+ :return:
+ """
+ # get nested nodes
+ iteration_nested_nodes = [
+ node for node in graph.get('nodes')
+ if node.get('data', {}).get('iteration_id') == current_iteration_node.node_id
+ ]
+
+ if not iteration_nested_nodes:
+ raise ValueError('iteration has no nested nodes')
+
+ if callbacks:
+ if isinstance(workflow_run_state.current_iteration_state, IterationState):
+ for callback in callbacks:
+ callback.on_workflow_iteration_started(
+ node_id=current_iteration_node.node_id,
+ node_type=NodeType.ITERATION,
+ node_run_index=workflow_run_state.workflow_node_steps,
+ node_data=current_iteration_node.node_data,
+ inputs=workflow_run_state.current_iteration_state.inputs,
+ predecessor_node_id=predecessor_node_id,
+ metadata=workflow_run_state.current_iteration_state.metadata.dict()
+ )
+
+ # add steps
+ workflow_run_state.workflow_node_steps += 1
+
+ def _workflow_iteration_next(self, graph: dict,
+ current_iteration_node: BaseIterationNode,
+ workflow_run_state: WorkflowRunState,
+ callbacks: list[BaseWorkflowCallback] = None) -> None:
+ """
+ Workflow iteration next
+ :param workflow_run_state: workflow run state
+ :return:
+ """
+ if callbacks:
+ if isinstance(workflow_run_state.current_iteration_state, IterationState):
+ for callback in callbacks:
+ callback.on_workflow_iteration_next(
+ node_id=current_iteration_node.node_id,
+ node_type=NodeType.ITERATION,
+ index=workflow_run_state.current_iteration_state.index,
+ node_run_index=workflow_run_state.workflow_node_steps,
+ output=workflow_run_state.current_iteration_state.get_current_output()
+ )
+ # clear ran nodes
+ workflow_run_state.workflow_node_runs = [
+ node_run for node_run in workflow_run_state.workflow_node_runs
+ if node_run.iteration_node_id != current_iteration_node.node_id
+ ]
+
+ # clear variables in current iteration
+ nodes = graph.get('nodes')
+ nodes = [node for node in nodes if node.get('data', {}).get('iteration_id') == current_iteration_node.node_id]
+
+ for node in nodes:
+ workflow_run_state.variable_pool.clear_node_variables(node_id=node.get('id'))
+
+ def _workflow_iteration_completed(self, current_iteration_node: BaseIterationNode,
+ workflow_run_state: WorkflowRunState,
+ callbacks: list[BaseWorkflowCallback] = None) -> None:
+ if callbacks:
+ if isinstance(workflow_run_state.current_iteration_state, IterationState):
+ for callback in callbacks:
+ callback.on_workflow_iteration_completed(
+ node_id=current_iteration_node.node_id,
+ node_type=NodeType.ITERATION,
+ node_run_index=workflow_run_state.workflow_node_steps,
+ outputs={
+ 'output': workflow_run_state.current_iteration_state.outputs
+ }
+ )
+
+ def _get_next_overall_node(self, workflow_run_state: WorkflowRunState,
graph: dict,
predecessor_node: Optional[BaseNode] = None,
- callbacks: list[BaseWorkflowCallback] = None) -> Optional[BaseNode]:
+ callbacks: list[BaseWorkflowCallback] = None,
+ start_at: Optional[str] = None,
+ end_at: Optional[str] = None) -> Optional[BaseNode]:
"""
Get next node
multiple target nodes in the future.
@@ -350,16 +665,26 @@ class WorkflowEngineManager:
if not predecessor_node:
for node_config in nodes:
- if node_config.get('data', {}).get('type', '') == NodeType.START.value:
- return StartNode(
+ node_cls = None
+ if start_at:
+ if node_config.get('id') == start_at:
+ node_cls = node_classes.get(NodeType.value_of(node_config.get('data', {}).get('type')))
+ else:
+ if node_config.get('data', {}).get('type', '') == NodeType.START.value:
+ node_cls = StartNode
+ if node_cls:
+ return node_cls(
tenant_id=workflow_run_state.tenant_id,
app_id=workflow_run_state.app_id,
workflow_id=workflow_run_state.workflow_id,
user_id=workflow_run_state.user_id,
user_from=workflow_run_state.user_from,
+ invoke_from=workflow_run_state.invoke_from,
config=node_config,
- callbacks=callbacks
+ callbacks=callbacks,
+ workflow_call_depth=workflow_run_state.workflow_call_depth
)
+
else:
edges = graph.get('edges')
source_node_id = predecessor_node.node_id
@@ -386,6 +711,9 @@ class WorkflowEngineManager:
target_node_id = outgoing_edge.get('target')
+ if end_at and target_node_id == end_at:
+ return None
+
# fetch target node from target node id
target_node_config = None
for node in nodes:
@@ -405,9 +733,40 @@ class WorkflowEngineManager:
workflow_id=workflow_run_state.workflow_id,
user_id=workflow_run_state.user_id,
user_from=workflow_run_state.user_from,
+ invoke_from=workflow_run_state.invoke_from,
config=target_node_config,
- callbacks=callbacks
+ callbacks=callbacks,
+ workflow_call_depth=workflow_run_state.workflow_call_depth
)
+
+ def _get_node(self, workflow_run_state: WorkflowRunState,
+ graph: dict,
+ node_id: str,
+ callbacks: list[BaseWorkflowCallback]) -> Optional[BaseNode]:
+ """
+ Get node from graph by node id
+ """
+ nodes = graph.get('nodes')
+ if not nodes:
+ return None
+
+ for node_config in nodes:
+ if node_config.get('id') == node_id:
+ node_type = NodeType.value_of(node_config.get('data', {}).get('type'))
+ node_cls = node_classes.get(node_type)
+ return node_cls(
+ tenant_id=workflow_run_state.tenant_id,
+ app_id=workflow_run_state.app_id,
+ workflow_id=workflow_run_state.workflow_id,
+ user_id=workflow_run_state.user_id,
+ user_from=workflow_run_state.user_from,
+ invoke_from=workflow_run_state.invoke_from,
+ config=node_config,
+ callbacks=callbacks,
+ workflow_call_depth=workflow_run_state.workflow_call_depth
+ )
+
+ return None
def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool:
"""
@@ -418,6 +777,15 @@ class WorkflowEngineManager:
"""
return time.perf_counter() - start_at > max_execution_time
+ def _check_node_has_ran(self, workflow_run_state: WorkflowRunState, node_id: str) -> bool:
+ """
+ Check node has ran
+ """
+ return bool([
+ node_and_result for node_and_result in workflow_run_state.workflow_node_runs
+ if node_and_result.node_id == node_id
+ ])
+
def _run_workflow_node(self, workflow_run_state: WorkflowRunState,
node: BaseNode,
predecessor_node: Optional[BaseNode] = None,
@@ -428,7 +796,7 @@ class WorkflowEngineManager:
node_id=node.node_id,
node_type=node.node_type,
node_data=node.node_data,
- node_run_index=len(workflow_run_state.workflow_nodes_and_results) + 1,
+ node_run_index=workflow_run_state.workflow_node_steps,
predecessor_node_id=predecessor_node.node_id if predecessor_node else None
)
@@ -442,6 +810,16 @@ class WorkflowEngineManager:
# add to workflow_nodes_and_results
workflow_run_state.workflow_nodes_and_results.append(workflow_nodes_and_result)
+ # add steps
+ workflow_run_state.workflow_node_steps += 1
+
+ # mark node as running
+ if workflow_run_state.current_iteration_state:
+ workflow_run_state.workflow_node_runs.append(WorkflowRunState.NodeRun(
+ node_id=node.node_id,
+ iteration_node_id=workflow_run_state.current_iteration_state.iteration_node_id
+ ))
+
try:
# run node, result must have inputs, process_data, outputs, execution_metadata
node_run_result = node.run(
@@ -561,3 +939,53 @@ class WorkflowEngineManager:
new_value[key] = new_val
return new_value
+
+ def _mapping_user_inputs_to_variable_pool(self,
+ variable_mapping: dict,
+ user_inputs: dict,
+ variable_pool: VariablePool,
+ tenant_id: str,
+ node_instance: BaseNode):
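+ """
+ Map user inputs into the variable pool
+ :param variable_mapping: variable mapping of the node
+ :param user_inputs: user inputs
+ :param variable_pool: variable pool
+ :param tenant_id: tenant id
+ :param node_instance: node instance
+ :return:
+ """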
+ for variable_key, variable_selector in variable_mapping.items():
+ if variable_key not in user_inputs:
+ raise ValueError(f'Variable key {variable_key} not found in user inputs.')
+
+ # fetch variable node id from variable selector
+ variable_node_id = variable_selector[0]
+ variable_key_list = variable_selector[1:]
+
+ # get value
+ value = user_inputs.get(variable_key)
+
+ # temp fix for image type
+ if node_instance.node_type == NodeType.LLM:
+ new_value = []
+ if isinstance(value, list):
+ node_data = node_instance.node_data
+ node_data = cast(LLMNodeData, node_data)
+
+ detail = node_data.vision.configs.detail if node_data.vision.configs else None
+
+ for item in value:
+ if isinstance(item, dict) and 'type' in item and item['type'] == 'image':
+ transfer_method = FileTransferMethod.value_of(item.get('transfer_method'))
+ file = FileVar(
+ tenant_id=tenant_id,
+ type=FileType.IMAGE,
+ transfer_method=transfer_method,
+ url=item.get('url') if transfer_method == FileTransferMethod.REMOTE_URL else None,
+ related_id=item.get(
+ 'upload_file_id') if transfer_method == FileTransferMethod.LOCAL_FILE else None,
+ extra_config=FileExtraConfig(image_config={'detail': detail} if detail else None),
+ )
+ new_value.append(file)
+
+ if new_value:
+ value = new_value
+
+ # append variable and value to variable pool
+ variable_pool.append_variable(
+ node_id=variable_node_id,
+ variable_key_list=variable_key_list,
+ value=value
+ )
\ No newline at end of file
diff --git a/api/events/event_handlers/__init__.py b/api/events/event_handlers/__init__.py
index 688b80aa8c..ceb50a252b 100644
--- a/api/events/event_handlers/__init__.py
+++ b/api/events/event_handlers/__init__.py
@@ -6,6 +6,7 @@ from .create_site_record_when_app_created import handle
from .deduct_quota_when_messaeg_created import handle
from .delete_installed_app_when_app_deleted import handle
from .delete_tool_parameters_cache_when_sync_draft_workflow import handle
+from .delete_workflow_as_tool_when_app_deleted import handle
from .update_app_dataset_join_when_app_model_config_updated import handle
from .update_app_dataset_join_when_app_published_workflow_updated import handle
from .update_provider_last_used_at_when_messaeg_created import handle
diff --git a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py
index 2a127d903e..1f6da34ee2 100644
--- a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py
+++ b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py
@@ -10,18 +10,22 @@ def handle(sender, **kwargs):
app = sender
for node_data in kwargs.get('synced_draft_workflow').graph_dict.get('nodes', []):
if node_data.get('data', {}).get('type') == NodeType.TOOL.value:
- tool_entity = ToolEntity(**node_data["data"])
- tool_runtime = ToolManager.get_tool_runtime(
- provider_type=tool_entity.provider_type,
- provider_name=tool_entity.provider_id,
- tool_name=tool_entity.tool_name,
- tenant_id=app.tenant_id,
- )
- manager = ToolParameterConfigurationManager(
- tenant_id=app.tenant_id,
- tool_runtime=tool_runtime,
- provider_name=tool_entity.provider_name,
- provider_type=tool_entity.provider_type,
- identity_id=f'WORKFLOW.{app.id}.{node_data.get("id")}'
- )
- manager.delete_tool_parameters_cache()
+ try:
+ tool_entity = ToolEntity(**node_data["data"])
+ tool_runtime = ToolManager.get_tool_runtime(
+ provider_type=tool_entity.provider_type,
+ provider_id=tool_entity.provider_id,
+ tool_name=tool_entity.tool_name,
+ tenant_id=app.tenant_id,
+ )
+ manager = ToolParameterConfigurationManager(
+ tenant_id=app.tenant_id,
+ tool_runtime=tool_runtime,
+ provider_name=tool_entity.provider_name,
+ provider_type=tool_entity.provider_type,
+ identity_id=f'WORKFLOW.{app.id}.{node_data.get("id")}'
+ )
+ manager.delete_tool_parameters_cache()
+ except Exception:
+ # tool does not exist
+ pass
diff --git a/api/events/event_handlers/delete_workflow_as_tool_when_app_deleted.py b/api/events/event_handlers/delete_workflow_as_tool_when_app_deleted.py
new file mode 100644
index 0000000000..0c56688ff6
--- /dev/null
+++ b/api/events/event_handlers/delete_workflow_as_tool_when_app_deleted.py
@@ -0,0 +1,14 @@
+from events.app_event import app_was_deleted
+from extensions.ext_database import db
+from models.tools import WorkflowToolProvider
+
+
+@app_was_deleted.connect
+def handle(sender, **kwargs):
+ app = sender
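+ # delete all workflow-as-tool providers that were published from this app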
+ workflow_tools = db.session.query(WorkflowToolProvider).filter(
+ WorkflowToolProvider.app_id == app.id
+ ).all()
+ for workflow_tool in workflow_tools:
+ db.session.delete(workflow_tool)
+ db.session.commit()
diff --git a/api/extensions/ext_mail.py b/api/extensions/ext_mail.py
index d2c6e32dfd..ec3a5cc112 100644
--- a/api/extensions/ext_mail.py
+++ b/api/extensions/ext_mail.py
@@ -1,3 +1,4 @@
+import logging
from typing import Optional
import resend
@@ -16,7 +17,7 @@ class Mail:
if app.config.get('MAIL_TYPE'):
if app.config.get('MAIL_DEFAULT_SEND_FROM'):
self._default_send_from = app.config.get('MAIL_DEFAULT_SEND_FROM')
-
+
if app.config.get('MAIL_TYPE') == 'resend':
api_key = app.config.get('RESEND_API_KEY')
if not api_key:
@@ -32,16 +33,22 @@ class Mail:
from libs.smtp import SMTPClient
if not app.config.get('SMTP_SERVER') or not app.config.get('SMTP_PORT'):
raise ValueError('SMTP_SERVER and SMTP_PORT are required for smtp mail type')
+ if not app.config.get('SMTP_USE_TLS') and app.config.get('SMTP_OPPORTUNISTIC_TLS'):
+ raise ValueError('SMTP_OPPORTUNISTIC_TLS is not supported without enabling SMTP_USE_TLS')
self._client = SMTPClient(
server=app.config.get('SMTP_SERVER'),
port=app.config.get('SMTP_PORT'),
username=app.config.get('SMTP_USERNAME'),
password=app.config.get('SMTP_PASSWORD'),
_from=app.config.get('MAIL_DEFAULT_SEND_FROM'),
- use_tls=app.config.get('SMTP_USE_TLS')
+ use_tls=app.config.get('SMTP_USE_TLS'),
+ opportunistic_tls=app.config.get('SMTP_OPPORTUNISTIC_TLS')
)
else:
raise ValueError('Unsupported mail type {}'.format(app.config.get('MAIL_TYPE')))
+ else:
+ logging.warning('MAIL_TYPE is not set')
+
def send(self, to: str, subject: str, html: str, from_: Optional[str] = None):
if not self._client:
diff --git a/api/extensions/storage/azure_storage.py b/api/extensions/storage/azure_storage.py
index 01de8bab94..b9809de640 100644
--- a/api/extensions/storage/azure_storage.py
+++ b/api/extensions/storage/azure_storage.py
@@ -5,38 +5,39 @@ from datetime import datetime, timedelta, timezone
from azure.storage.blob import AccountSasPermissions, BlobServiceClient, ResourceTypes, generate_account_sas
from flask import Flask
+from extensions.ext_redis import redis_client
from extensions.storage.base_storage import BaseStorage
class AzureStorage(BaseStorage):
"""Implementation for azure storage.
"""
+
def __init__(self, app: Flask):
super().__init__(app)
app_config = self.app.config
- self.bucket_name = app_config.get('AZURE_STORAGE_CONTAINER_NAME')
- sas_token = generate_account_sas(
- account_name=app_config.get('AZURE_BLOB_ACCOUNT_NAME'),
- account_key=app_config.get('AZURE_BLOB_ACCOUNT_KEY'),
- resource_types=ResourceTypes(service=True, container=True, object=True),
- permission=AccountSasPermissions(read=True, write=True, delete=True, list=True, add=True, create=True),
- expiry=datetime.now(timezone.utc).replace(tzinfo=None) + timedelta(hours=1)
- )
- self.client = BlobServiceClient(account_url=app_config.get('AZURE_BLOB_ACCOUNT_URL'),
- credential=sas_token)
+ self.bucket_name = app_config.get('AZURE_BLOB_CONTAINER_NAME')
+ self.account_url = app_config.get('AZURE_BLOB_ACCOUNT_URL')
+ self.account_name = app_config.get('AZURE_BLOB_ACCOUNT_NAME')
+ self.account_key = app_config.get('AZURE_BLOB_ACCOUNT_KEY')
+
def save(self, filename, data):
- blob_container = self.client.get_container_client(container=self.bucket_name)
+ client = self._sync_client()
+ blob_container = client.get_container_client(container=self.bucket_name)
blob_container.upload_blob(filename, data)
def load_once(self, filename: str) -> bytes:
- blob = self.client.get_container_client(container=self.bucket_name)
+ client = self._sync_client()
+ blob = client.get_container_client(container=self.bucket_name)
blob = blob.get_blob_client(blob=filename)
data = blob.download_blob().readall()
return data
def load_stream(self, filename: str) -> Generator:
+ client = self._sync_client()
+
def generate(filename: str = filename) -> Generator:
- blob = self.client.get_blob_client(container=self.bucket_name, blob=filename)
+ blob = client.get_blob_client(container=self.bucket_name, blob=filename)
with closing(blob.download_blob()) as blob_stream:
while chunk := blob_stream.readall(4096):
yield chunk
@@ -44,15 +45,37 @@ class AzureStorage(BaseStorage):
return generate()
def download(self, filename, target_filepath):
- blob = self.client.get_blob_client(container=self.bucket_name, blob=filename)
+ client = self._sync_client()
+
+ blob = client.get_blob_client(container=self.bucket_name, blob=filename)
with open(target_filepath, "wb") as my_blob:
blob_data = blob.download_blob()
blob_data.readinto(my_blob)
def exists(self, filename):
- blob = self.client.get_blob_client(container=self.bucket_name, blob=filename)
+ client = self._sync_client()
+
+ blob = client.get_blob_client(container=self.bucket_name, blob=filename)
return blob.exists()
def delete(self, filename):
- blob_container = self.client.get_container_client(container=self.bucket_name)
- blob_container.delete_blob(filename)
\ No newline at end of file
+ client = self._sync_client()
+
+ blob_container = client.get_container_client(container=self.bucket_name)
+ blob_container.delete_blob(filename)
+
+ def _sync_client(self):
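+ # the SAS token below is valid for 1 hour, so cache it in Redis for
+ # 3000 seconds to leave a safety margin before it expires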
+ cache_key = 'azure_blob_sas_token_{}_{}'.format(self.account_name, self.account_key)
+ cache_result = redis_client.get(cache_key)
+ if cache_result is not None:
+ sas_token = cache_result.decode('utf-8')
+ else:
+ sas_token = generate_account_sas(
+ account_name=self.account_name,
+ account_key=self.account_key,
+ resource_types=ResourceTypes(service=True, container=True, object=True),
+ permission=AccountSasPermissions(read=True, write=True, delete=True, list=True, add=True, create=True),
+ expiry=datetime.now(timezone.utc).replace(tzinfo=None) + timedelta(hours=1)
+ )
+ redis_client.set(cache_key, sas_token, ex=3000)
+ return BlobServiceClient(account_url=self.account_url, credential=sas_token)
diff --git a/api/extensions/storage/google_storage.py b/api/extensions/storage/google_storage.py
index f6c69eb0ae..97004fddab 100644
--- a/api/extensions/storage/google_storage.py
+++ b/api/extensions/storage/google_storage.py
@@ -1,4 +1,5 @@
import base64
+import io
+import json
from collections.abc import Generator
from contextlib import closing
@@ -15,14 +16,19 @@ class GoogleStorage(BaseStorage):
super().__init__(app)
app_config = self.app.config
self.bucket_name = app_config.get('GOOGLE_STORAGE_BUCKET_NAME')
- service_account_json = base64.b64decode(app_config.get('GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64')).decode(
- 'utf-8')
- self.client = GoogleCloudStorage.Client().from_service_account_json(service_account_json)
+ service_account_json_str = app_config.get('GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64')
+ # if service_account_json_str is empty, use Application Default Credentials
+ if service_account_json_str:
+ service_account_json = base64.b64decode(service_account_json_str).decode('utf-8')
+ # from_service_account_info expects a mapping, so parse the decoded JSON first
+ self.client = GoogleCloudStorage.Client.from_service_account_info(json.loads(service_account_json))
+ else:
+ self.client = GoogleCloudStorage.Client()
def save(self, filename, data):
bucket = self.client.get_bucket(self.bucket_name)
blob = bucket.blob(filename)
- blob.upload_from_file(data)
+ with io.BytesIO(data) as stream:
+ blob.upload_from_file(stream)
def load_once(self, filename: str) -> bytes:
bucket = self.client.get_bucket(self.bucket_name)
diff --git a/api/fields/app_fields.py b/api/fields/app_fields.py
index c7cfdd7939..212c3e7f17 100644
--- a/api/fields/app_fields.py
+++ b/api/fields/app_fields.py
@@ -113,6 +113,7 @@ site_fields = {
'customize_domain': fields.String,
'copyright': fields.String,
'privacy_policy': fields.String,
+ 'custom_disclaimer': fields.String,
'customize_token_strategy': fields.String,
'prompt_public': fields.Boolean,
'app_base_url': fields.String,
@@ -146,6 +147,7 @@ app_site_fields = {
'customize_domain': fields.String,
'copyright': fields.String,
'privacy_policy': fields.String,
+ 'custom_disclaimer': fields.String,
'customize_token_strategy': fields.String,
'prompt_public': fields.Boolean
}
diff --git a/api/fields/document_fields.py b/api/fields/document_fields.py
index 94d905eafe..e8215255b3 100644
--- a/api/fields/document_fields.py
+++ b/api/fields/document_fields.py
@@ -8,6 +8,7 @@ document_fields = {
'position': fields.Integer,
'data_source_type': fields.String,
'data_source_info': fields.Raw(attribute='data_source_info_dict'),
+ 'data_source_detail_dict': fields.Raw(attribute='data_source_detail_dict'),
'dataset_process_rule_id': fields.String,
'name': fields.String,
'created_from': fields.String,
@@ -31,6 +32,7 @@ document_with_segments_fields = {
'position': fields.Integer,
'data_source_type': fields.String,
'data_source_info': fields.Raw(attribute='data_source_info_dict'),
+ 'data_source_detail_dict': fields.Raw(attribute='data_source_detail_dict'),
'dataset_process_rule_id': fields.String,
'name': fields.String,
'created_from': fields.String,
diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py
index 9919a440e8..54d7ed55f8 100644
--- a/api/fields/workflow_fields.py
+++ b/api/fields/workflow_fields.py
@@ -7,8 +7,10 @@ workflow_fields = {
'id': fields.String,
'graph': fields.Raw(attribute='graph_dict'),
'features': fields.Raw(attribute='features_dict'),
+ 'hash': fields.String(attribute='unique_hash'),
'created_by': fields.Nested(simple_account_fields, attribute='created_by_account'),
'created_at': TimestampField,
'updated_by': fields.Nested(simple_account_fields, attribute='updated_by_account', allow_null=True),
- 'updated_at': TimestampField
+ 'updated_at': TimestampField,
+ 'tool_published': fields.Boolean,
}
diff --git a/api/libs/gmpy2_pkcs10aep_cipher.py b/api/libs/gmpy2_pkcs10aep_cipher.py
index c22546f602..9856875c16 100644
--- a/api/libs/gmpy2_pkcs10aep_cipher.py
+++ b/api/libs/gmpy2_pkcs10aep_cipher.py
@@ -48,7 +48,7 @@ class PKCS1OAEP_Cipher:
`Crypto.Hash.SHA1` is used.
mgfunc : callable
A mask generation function that accepts two parameters: a string to
- use as seed, and the lenth of the mask to generate, in bytes.
+ use as seed, and the length of the mask to generate, in bytes.
If not specified, the standard MGF1 consistent with ``hashAlgo`` is used (a safe choice).
label : bytes/bytearray/memoryview
A label to apply to this particular encryption. If not specified,
@@ -218,7 +218,7 @@ def new(key, hashAlgo=None, mgfunc=None, label=b'', randfunc=None):
:param mgfunc:
A mask generation function that accepts two parameters: a string to
- use as seed, and the lenth of the mask to generate, in bytes.
+ use as seed, and the length of the mask to generate, in bytes.
If not specified, the standard MGF1 consistent with ``hashAlgo`` is used (a safe choice).
:type mgfunc: callable
diff --git a/api/libs/helper.py b/api/libs/helper.py
index f9cf590b7a..f4be9c5531 100644
--- a/api/libs/helper.py
+++ b/api/libs/helper.py
@@ -46,7 +46,13 @@ def uuid_value(value):
error = ('{value} is not a valid uuid.'
.format(value=value))
raise ValueError(error)
-
+
+def alphanumeric(value: str):
+ # check that the value contains only letters, digits, and underscores
+ if re.match(r'^[a-zA-Z0-9_]+$', value):
+ return value
+
+ raise ValueError(f'{value} is not a valid alphanumeric value')
def timestamp_value(timestamp):
try:
diff --git a/api/libs/smtp.py b/api/libs/smtp.py
index 30a795bd70..bf3a1a92e9 100644
--- a/api/libs/smtp.py
+++ b/api/libs/smtp.py
@@ -1,27 +1,50 @@
+import logging
import smtplib
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
class SMTPClient:
- def __init__(self, server: str, port: int, username: str, password: str, _from: str, use_tls=False):
+ def __init__(self, server: str, port: int, username: str, password: str, _from: str, use_tls=False, opportunistic_tls=False):
self.server = server
self.port = port
self._from = _from
self.username = username
self.password = password
- self._use_tls = use_tls
+ self.use_tls = use_tls
+ self.opportunistic_tls = opportunistic_tls
def send(self, mail: dict):
- smtp = smtplib.SMTP(self.server, self.port)
- if self._use_tls:
- smtp.starttls()
- if self.username and self.password:
- smtp.login(self.username, self.password)
- msg = MIMEMultipart()
- msg['Subject'] = mail['subject']
- msg['From'] = self._from
- msg['To'] = mail['to']
- msg.attach(MIMEText(mail['html'], 'html'))
- smtp.sendmail(self.username, mail['to'], msg.as_string())
- smtp.quit()
+ smtp = None
+ try:
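+ # SMTP_SSL wraps the connection in TLS from the start, while
+ # opportunistic TLS connects in plaintext and upgrades via STARTTLS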
+ if self.use_tls:
+ if self.opportunistic_tls:
+ smtp = smtplib.SMTP(self.server, self.port, timeout=10)
+ smtp.starttls()
+ else:
+ smtp = smtplib.SMTP_SSL(self.server, self.port, timeout=10)
+ else:
+ smtp = smtplib.SMTP(self.server, self.port, timeout=10)
+
+ if self.username and self.password:
+ smtp.login(self.username, self.password)
+
+ msg = MIMEMultipart()
+ msg['Subject'] = mail['subject']
+ msg['From'] = self._from
+ msg['To'] = mail['to']
+ msg.attach(MIMEText(mail['html'], 'html'))
+
+ smtp.sendmail(self._from, mail['to'], msg.as_string())
+ except smtplib.SMTPException as e:
+ logging.error(f"SMTP error occurred: {str(e)}")
+ raise
+ except TimeoutError as e:
+ logging.error(f"Timeout occurred while sending email: {str(e)}")
+ raise
+ except Exception as e:
+ logging.error(f"Unexpected error occurred while sending email: {str(e)}")
+ raise
+ finally:
+ if smtp:
+ smtp.quit()
diff --git a/api/migrations/env.py b/api/migrations/env.py
index 18485c1885..ad3a122c04 100644
--- a/api/migrations/env.py
+++ b/api/migrations/env.py
@@ -110,3 +110,4 @@ if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()
+
diff --git a/api/migrations/versions/03f98355ba0e_add_workflow_tool_label_and_tool_.py b/api/migrations/versions/03f98355ba0e_add_workflow_tool_label_and_tool_.py
new file mode 100644
index 0000000000..0fba6a87eb
--- /dev/null
+++ b/api/migrations/versions/03f98355ba0e_add_workflow_tool_label_and_tool_.py
@@ -0,0 +1,32 @@
+"""add workflow tool label and tool bindings idx
+
+Revision ID: 03f98355ba0e
+Revises: 9e98fbaffb88
+Create Date: 2024-05-25 07:17:00.539125
+
+"""
+import sqlalchemy as sa
+from alembic import op
+
+import models as models
+
+# revision identifiers, used by Alembic.
+revision = '03f98355ba0e'
+down_revision = '9e98fbaffb88'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+ with op.batch_alter_table('tool_label_bindings', schema=None) as batch_op:
+ batch_op.create_unique_constraint('unique_tool_label_bind', ['tool_id', 'label_name'])
+
+ with op.batch_alter_table('tool_workflow_providers', schema=None) as batch_op:
+ batch_op.add_column(sa.Column('label', sa.String(length=255), server_default='', nullable=False))
+
+def downgrade():
+ with op.batch_alter_table('tool_workflow_providers', schema=None) as batch_op:
+ batch_op.drop_column('label')
+
+ with op.batch_alter_table('tool_label_bindings', schema=None) as batch_op:
+ batch_op.drop_constraint('unique_tool_label_bind', type_='unique')
diff --git a/api/migrations/versions/3b18fea55204_add_tool_label_bings.py b/api/migrations/versions/3b18fea55204_add_tool_label_bings.py
new file mode 100644
index 0000000000..db3119badf
--- /dev/null
+++ b/api/migrations/versions/3b18fea55204_add_tool_label_bings.py
@@ -0,0 +1,42 @@
+"""add tool label bings
+
+Revision ID: 3b18fea55204
+Revises: 7bdef072e63a
+Create Date: 2024-05-14 09:27:18.857890
+
+"""
+import sqlalchemy as sa
+from alembic import op
+
+import models as models
+
+# revision identifiers, used by Alembic.
+revision = '3b18fea55204'
+down_revision = '7bdef072e63a'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.create_table('tool_label_bindings',
+ sa.Column('id', models.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
+ sa.Column('tool_id', sa.String(length=64), nullable=False),
+ sa.Column('tool_type', sa.String(length=40), nullable=False),
+ sa.Column('label_name', sa.String(length=40), nullable=False),
+ sa.PrimaryKeyConstraint('id', name='tool_label_bind_pkey')
+ )
+
+ with op.batch_alter_table('tool_workflow_providers', schema=None) as batch_op:
+ batch_op.add_column(sa.Column('privacy_policy', sa.String(length=255), server_default='', nullable=True))
+
+ # ### end Alembic commands ###
+
+
+def downgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ with op.batch_alter_table('tool_workflow_providers', schema=None) as batch_op:
+ batch_op.drop_column('privacy_policy')
+
+ op.drop_table('tool_label_bindings')
+ # ### end Alembic commands ###
diff --git a/api/migrations/versions/47cc7df8c4f3_modify_default_model_name_length.py b/api/migrations/versions/47cc7df8c4f3_modify_default_model_name_length.py
new file mode 100644
index 0000000000..b37928d3c0
--- /dev/null
+++ b/api/migrations/versions/47cc7df8c4f3_modify_default_model_name_length.py
@@ -0,0 +1,39 @@
+"""modify default model name length
+
+Revision ID: 47cc7df8c4f3
+Revises: 3c7cac9521c6
+Create Date: 2024-05-10 09:48:09.046298
+
+"""
+import sqlalchemy as sa
+from alembic import op
+
+import models as models
+
+# revision identifiers, used by Alembic.
+revision = '47cc7df8c4f3'
+down_revision = '3c7cac9521c6'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ with op.batch_alter_table('tenant_default_models', schema=None) as batch_op:
+ batch_op.alter_column('model_name',
+ existing_type=sa.VARCHAR(length=40),
+ type_=sa.String(length=255),
+ existing_nullable=False)
+
+ # ### end Alembic commands ###
+
+
+def downgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ with op.batch_alter_table('tenant_default_models', schema=None) as batch_op:
+ batch_op.alter_column('model_name',
+ existing_type=sa.String(length=255),
+ type_=sa.VARCHAR(length=40),
+ existing_nullable=False)
+
+ # ### end Alembic commands ###
diff --git a/api/migrations/versions/4e99a8df00ff_add_load_balancing.py b/api/migrations/versions/4e99a8df00ff_add_load_balancing.py
new file mode 100644
index 0000000000..67d7b9fbf5
--- /dev/null
+++ b/api/migrations/versions/4e99a8df00ff_add_load_balancing.py
@@ -0,0 +1,126 @@
+"""add load balancing
+
+Revision ID: 4e99a8df00ff
+Revises: 64a70a7aab8b
+Create Date: 2024-05-10 12:08:09.812736
+
+"""
+import sqlalchemy as sa
+from alembic import op
+
+import models as models
+
+# revision identifiers, used by Alembic.
+revision = '4e99a8df00ff'
+down_revision = '64a70a7aab8b'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.create_table('load_balancing_model_configs',
+ sa.Column('id', models.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
+ sa.Column('tenant_id', models.StringUUID(), nullable=False),
+ sa.Column('provider_name', sa.String(length=255), nullable=False),
+ sa.Column('model_name', sa.String(length=255), nullable=False),
+ sa.Column('model_type', sa.String(length=40), nullable=False),
+ sa.Column('name', sa.String(length=255), nullable=False),
+ sa.Column('encrypted_config', sa.Text(), nullable=True),
+ sa.Column('enabled', sa.Boolean(), server_default=sa.text('true'), nullable=False),
+ sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
+ sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
+ sa.PrimaryKeyConstraint('id', name='load_balancing_model_config_pkey')
+ )
+ with op.batch_alter_table('load_balancing_model_configs', schema=None) as batch_op:
+ batch_op.create_index('load_balancing_model_config_tenant_provider_model_idx', ['tenant_id', 'provider_name', 'model_type'], unique=False)
+
+ op.create_table('provider_model_settings',
+ sa.Column('id', models.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
+ sa.Column('tenant_id', models.StringUUID(), nullable=False),
+ sa.Column('provider_name', sa.String(length=255), nullable=False),
+ sa.Column('model_name', sa.String(length=255), nullable=False),
+ sa.Column('model_type', sa.String(length=40), nullable=False),
+ sa.Column('enabled', sa.Boolean(), server_default=sa.text('true'), nullable=False),
+ sa.Column('load_balancing_enabled', sa.Boolean(), server_default=sa.text('false'), nullable=False),
+ sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
+ sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
+ sa.PrimaryKeyConstraint('id', name='provider_model_setting_pkey')
+ )
+ with op.batch_alter_table('provider_model_settings', schema=None) as batch_op:
+ batch_op.create_index('provider_model_setting_tenant_provider_model_idx', ['tenant_id', 'provider_name', 'model_type'], unique=False)
+
+ with op.batch_alter_table('provider_models', schema=None) as batch_op:
+ batch_op.alter_column('provider_name',
+ existing_type=sa.VARCHAR(length=40),
+ type_=sa.String(length=255),
+ existing_nullable=False)
+
+ with op.batch_alter_table('provider_orders', schema=None) as batch_op:
+ batch_op.alter_column('provider_name',
+ existing_type=sa.VARCHAR(length=40),
+ type_=sa.String(length=255),
+ existing_nullable=False)
+
+ with op.batch_alter_table('providers', schema=None) as batch_op:
+ batch_op.alter_column('provider_name',
+ existing_type=sa.VARCHAR(length=40),
+ type_=sa.String(length=255),
+ existing_nullable=False)
+
+ with op.batch_alter_table('tenant_default_models', schema=None) as batch_op:
+ batch_op.alter_column('provider_name',
+ existing_type=sa.VARCHAR(length=40),
+ type_=sa.String(length=255),
+ existing_nullable=False)
+
+ with op.batch_alter_table('tenant_preferred_model_providers', schema=None) as batch_op:
+ batch_op.alter_column('provider_name',
+ existing_type=sa.VARCHAR(length=40),
+ type_=sa.String(length=255),
+ existing_nullable=False)
+
+ # ### end Alembic commands ###
+
+
+def downgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ with op.batch_alter_table('tenant_preferred_model_providers', schema=None) as batch_op:
+ batch_op.alter_column('provider_name',
+ existing_type=sa.String(length=255),
+ type_=sa.VARCHAR(length=40),
+ existing_nullable=False)
+
+ with op.batch_alter_table('tenant_default_models', schema=None) as batch_op:
+ batch_op.alter_column('provider_name',
+ existing_type=sa.String(length=255),
+ type_=sa.VARCHAR(length=40),
+ existing_nullable=False)
+
+ with op.batch_alter_table('providers', schema=None) as batch_op:
+ batch_op.alter_column('provider_name',
+ existing_type=sa.String(length=255),
+ type_=sa.VARCHAR(length=40),
+ existing_nullable=False)
+
+ with op.batch_alter_table('provider_orders', schema=None) as batch_op:
+ batch_op.alter_column('provider_name',
+ existing_type=sa.String(length=255),
+ type_=sa.VARCHAR(length=40),
+ existing_nullable=False)
+
+ with op.batch_alter_table('provider_models', schema=None) as batch_op:
+ batch_op.alter_column('provider_name',
+ existing_type=sa.String(length=255),
+ type_=sa.VARCHAR(length=40),
+ existing_nullable=False)
+
+ with op.batch_alter_table('provider_model_settings', schema=None) as batch_op:
+ batch_op.drop_index('provider_model_setting_tenant_provider_model_idx')
+
+ op.drop_table('provider_model_settings')
+ with op.batch_alter_table('load_balancing_model_configs', schema=None) as batch_op:
+ batch_op.drop_index('load_balancing_model_config_tenant_provider_model_idx')
+
+ op.drop_table('load_balancing_model_configs')
+ # ### end Alembic commands ###
diff --git a/api/migrations/versions/5fda94355fce_custom_disclaimer.py b/api/migrations/versions/5fda94355fce_custom_disclaimer.py
new file mode 100644
index 0000000000..73bcdc4500
--- /dev/null
+++ b/api/migrations/versions/5fda94355fce_custom_disclaimer.py
@@ -0,0 +1,45 @@
+"""Custom Disclaimer
+
+Revision ID: 5fda94355fce
+Revises: 47cc7df8c4f3
+Create Date: 2024-05-10 20:04:45.806549
+
+"""
+import sqlalchemy as sa
+from alembic import op
+
+import models as models
+
+# revision identifiers, used by Alembic.
+revision = '5fda94355fce'
+down_revision = '47cc7df8c4f3'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ with op.batch_alter_table('recommended_apps', schema=None) as batch_op:
+ batch_op.add_column(sa.Column('custom_disclaimer', sa.String(length=255), nullable=True))
+
+ with op.batch_alter_table('sites', schema=None) as batch_op:
+ batch_op.add_column(sa.Column('custom_disclaimer', sa.String(length=255), nullable=True))
+
+ with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
+ batch_op.add_column(sa.Column('custom_disclaimer', sa.String(length=255), nullable=True))
+
+ # ### end Alembic commands ###
+
+
+def downgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
+ batch_op.drop_column('custom_disclaimer')
+
+ with op.batch_alter_table('sites', schema=None) as batch_op:
+ batch_op.drop_column('custom_disclaimer')
+
+ with op.batch_alter_table('recommended_apps', schema=None) as batch_op:
+ batch_op.drop_column('custom_disclaimer')
+
+ # ### end Alembic commands ###
diff --git a/api/migrations/versions/64a70a7aab8b_add_workflow_run_index.py b/api/migrations/versions/64a70a7aab8b_add_workflow_run_index.py
new file mode 100644
index 0000000000..73242908f4
--- /dev/null
+++ b/api/migrations/versions/64a70a7aab8b_add_workflow_run_index.py
@@ -0,0 +1,32 @@
+"""add workflow run index
+
+Revision ID: 64a70a7aab8b
+Revises: 03f98355ba0e
+Create Date: 2024-05-28 12:32:00.276061
+
+"""
+from alembic import op
+
+import models as models
+
+# revision identifiers, used by Alembic.
+revision = '64a70a7aab8b'
+down_revision = '03f98355ba0e'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ with op.batch_alter_table('workflow_runs', schema=None) as batch_op:
+ batch_op.create_index('workflow_run_tenant_app_sequence_idx', ['tenant_id', 'app_id', 'sequence_number'], unique=False)
+
+ # ### end Alembic commands ###
+
+
+def downgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ with op.batch_alter_table('workflow_runs', schema=None) as batch_op:
+ batch_op.drop_index('workflow_run_tenant_app_sequence_idx')
+
+ # ### end Alembic commands ###
diff --git a/api/migrations/versions/7bdef072e63a_add_workflow_tool.py b/api/migrations/versions/7bdef072e63a_add_workflow_tool.py
new file mode 100644
index 0000000000..67b61e5c76
--- /dev/null
+++ b/api/migrations/versions/7bdef072e63a_add_workflow_tool.py
@@ -0,0 +1,42 @@
+"""add workflow tool
+
+Revision ID: 7bdef072e63a
+Revises: 5fda94355fce
+Create Date: 2024-05-04 09:47:19.366961
+
+"""
+import sqlalchemy as sa
+from alembic import op
+
+import models as models
+
+# revision identifiers, used by Alembic.
+revision = '7bdef072e63a'
+down_revision = '5fda94355fce'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.create_table('tool_workflow_providers',
+ sa.Column('id', models.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
+ sa.Column('name', sa.String(length=40), nullable=False),
+ sa.Column('icon', sa.String(length=255), nullable=False),
+ sa.Column('app_id', models.StringUUID(), nullable=False),
+ sa.Column('user_id', models.StringUUID(), nullable=False),
+ sa.Column('tenant_id', models.StringUUID(), nullable=False),
+ sa.Column('description', sa.Text(), nullable=False),
+ sa.Column('parameter_configuration', sa.Text(), server_default='[]', nullable=False),
+ sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
+ sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
+ sa.PrimaryKeyConstraint('id', name='tool_workflow_provider_pkey'),
+ sa.UniqueConstraint('name', 'tenant_id', name='unique_workflow_tool_provider'),
+ sa.UniqueConstraint('tenant_id', 'app_id', name='unique_workflow_tool_provider_app_id')
+ )
+ # ### end Alembic commands ###
+
+
+def downgrade():
+ op.drop_table('tool_workflow_providers')
+ # ### end Alembic commands ###
diff --git a/api/migrations/versions/9e98fbaffb88_add_workflow_tool_version.py b/api/migrations/versions/9e98fbaffb88_add_workflow_tool_version.py
new file mode 100644
index 0000000000..bfda7d619c
--- /dev/null
+++ b/api/migrations/versions/9e98fbaffb88_add_workflow_tool_version.py
@@ -0,0 +1,26 @@
+"""add workflow tool version
+
+Revision ID: 9e98fbaffb88
+Revises: 3b18fea55204
+Create Date: 2024-05-21 10:25:40.434162
+
+"""
+import sqlalchemy as sa
+from alembic import op
+
+import models as models
+
+# revision identifiers, used by Alembic.
+revision = '9e98fbaffb88'
+down_revision = '3b18fea55204'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+ with op.batch_alter_table('tool_workflow_providers', schema=None) as batch_op:
+ batch_op.add_column(sa.Column('version', sa.String(length=255), server_default='', nullable=False))
+
+def downgrade():
+ with op.batch_alter_table('tool_workflow_providers', schema=None) as batch_op:
+ batch_op.drop_column('version')
diff --git a/api/models/dataset.py b/api/models/dataset.py
index 01b068fa2a..7f98bbde15 100644
--- a/api/models/dataset.py
+++ b/api/models/dataset.py
@@ -1,8 +1,15 @@
+import base64
+import hashlib
+import hmac
import json
import logging
+import os
import pickle
+import re
+import time
from json import JSONDecodeError
+from flask import current_app
from sqlalchemy import func
from sqlalchemy.dialects.postgresql import JSONB
@@ -69,7 +76,8 @@ class Dataset(db.Model):
@property
def app_count(self):
- return db.session.query(func.count(AppDatasetJoin.id)).filter(AppDatasetJoin.dataset_id == self.id).scalar()
+ return db.session.query(func.count(AppDatasetJoin.id)).filter(AppDatasetJoin.dataset_id == self.id,
+ App.id == AppDatasetJoin.app_id).scalar()
@property
def document_count(self):
@@ -414,6 +422,26 @@ class DocumentSegment(db.Model):
DocumentSegment.position == self.position + 1
).first()
+ def get_sign_content(self):
+ pattern = r"/files/([a-f0-9\-]+)/image-preview"
+ text = self.content
+ match = re.search(pattern, text)
+
+ if match:
+ upload_file_id = match.group(1)
+ nonce = os.urandom(16).hex()
+ timestamp = str(int(time.time()))
+ data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
+ secret_key = current_app.config['SECRET_KEY'].encode()
+ sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
+ encoded_sign = base64.urlsafe_b64encode(sign).decode()
+
+ params = f"timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
+ replacement = r"\g<0>?{params}".format(params=params)
+ text = re.sub(pattern, replacement, text)
+ return text
+
+
class AppDatasetJoin(db.Model):
__tablename__ = 'app_dataset_joins'
diff --git a/api/models/model.py b/api/models/model.py
index 59b88eb3b1..657db5a5c2 100644
--- a/api/models/model.py
+++ b/api/models/model.py
@@ -100,7 +100,7 @@ class App(db.Model):
return None
@property
- def workflow(self):
+ def workflow(self) -> Optional['Workflow']:
if self.workflow_id:
from .workflow import Workflow
return db.session.query(Workflow).filter(Workflow.id == self.workflow_id).first()
@@ -435,6 +435,7 @@ class RecommendedApp(db.Model):
description = db.Column(db.JSON, nullable=False)
copyright = db.Column(db.String(255), nullable=False)
privacy_policy = db.Column(db.String(255), nullable=False)
+ custom_disclaimer = db.Column(db.String(255), nullable=True)
category = db.Column(db.String(255), nullable=False)
position = db.Column(db.Integer, nullable=False, default=0)
is_listed = db.Column(db.Boolean, nullable=False, default=True)
@@ -796,7 +797,7 @@ class Message(db.Model):
if message_file.transfer_method == 'local_file':
upload_file = (db.session.query(UploadFile)
.filter(
- UploadFile.id == message_file.related_id
+ UploadFile.id == message_file.upload_file_id
).first())
url = UploadFileParser.get_image_data(
@@ -1042,6 +1043,7 @@ class Site(db.Model):
default_language = db.Column(db.String(255), nullable=False)
copyright = db.Column(db.String(255))
privacy_policy = db.Column(db.String(255))
+ custom_disclaimer = db.Column(db.String(255), nullable=True)
customize_domain = db.Column(db.String(255))
customize_token_strategy = db.Column(db.String(255), nullable=False)
prompt_public = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
diff --git a/api/models/provider.py b/api/models/provider.py
index 413e8f9d67..4c14c33f09 100644
--- a/api/models/provider.py
+++ b/api/models/provider.py
@@ -47,7 +47,7 @@ class Provider(db.Model):
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(StringUUID, nullable=False)
- provider_name = db.Column(db.String(40), nullable=False)
+ provider_name = db.Column(db.String(255), nullable=False)
provider_type = db.Column(db.String(40), nullable=False, server_default=db.text("'custom'::character varying"))
encrypted_config = db.Column(db.Text, nullable=True)
is_valid = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
@@ -94,7 +94,7 @@ class ProviderModel(db.Model):
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(StringUUID, nullable=False)
- provider_name = db.Column(db.String(40), nullable=False)
+ provider_name = db.Column(db.String(255), nullable=False)
model_name = db.Column(db.String(255), nullable=False)
model_type = db.Column(db.String(40), nullable=False)
encrypted_config = db.Column(db.Text, nullable=True)
@@ -112,8 +112,8 @@ class TenantDefaultModel(db.Model):
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(StringUUID, nullable=False)
- provider_name = db.Column(db.String(40), nullable=False)
- model_name = db.Column(db.String(40), nullable=False)
+ provider_name = db.Column(db.String(255), nullable=False)
+ model_name = db.Column(db.String(255), nullable=False)
model_type = db.Column(db.String(40), nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
@@ -128,7 +128,7 @@ class TenantPreferredModelProvider(db.Model):
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(StringUUID, nullable=False)
- provider_name = db.Column(db.String(40), nullable=False)
+ provider_name = db.Column(db.String(255), nullable=False)
preferred_provider_type = db.Column(db.String(40), nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
@@ -143,7 +143,7 @@ class ProviderOrder(db.Model):
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(StringUUID, nullable=False)
- provider_name = db.Column(db.String(40), nullable=False)
+ provider_name = db.Column(db.String(255), nullable=False)
account_id = db.Column(StringUUID, nullable=False)
payment_product_id = db.Column(db.String(191), nullable=False)
payment_id = db.Column(db.String(191))
@@ -157,3 +157,46 @@ class ProviderOrder(db.Model):
refunded_at = db.Column(db.DateTime)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
+
+
+class ProviderModelSetting(db.Model):
+ """
+ Provider model settings for recording the model enabled status and load balancing status.
+ """
+ __tablename__ = 'provider_model_settings'
+ __table_args__ = (
+ db.PrimaryKeyConstraint('id', name='provider_model_setting_pkey'),
+ db.Index('provider_model_setting_tenant_provider_model_idx', 'tenant_id', 'provider_name', 'model_type'),
+ )
+
+ id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
+ tenant_id = db.Column(StringUUID, nullable=False)
+ provider_name = db.Column(db.String(255), nullable=False)
+ model_name = db.Column(db.String(255), nullable=False)
+ model_type = db.Column(db.String(40), nullable=False)
+ enabled = db.Column(db.Boolean, nullable=False, server_default=db.text('true'))
+ load_balancing_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
+ created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
+ updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
+
+
+class LoadBalancingModelConfig(db.Model):
+ """
+ Configurations for load balancing models.
+ """
+ __tablename__ = 'load_balancing_model_configs'
+ __table_args__ = (
+ db.PrimaryKeyConstraint('id', name='load_balancing_model_config_pkey'),
+ db.Index('load_balancing_model_config_tenant_provider_model_idx', 'tenant_id', 'provider_name', 'model_type'),
+ )
+
+ id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
+ tenant_id = db.Column(StringUUID, nullable=False)
+ provider_name = db.Column(db.String(255), nullable=False)
+ model_name = db.Column(db.String(255), nullable=False)
+ model_type = db.Column(db.String(40), nullable=False)
+ name = db.Column(db.String(255), nullable=False)
+ encrypted_config = db.Column(db.Text, nullable=True)
+ enabled = db.Column(db.Boolean, nullable=False, server_default=db.text('true'))
+ created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
+ updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
diff --git a/api/models/tools.py b/api/models/tools.py
index 8a133679e0..49212916ec 100644
--- a/api/models/tools.py
+++ b/api/models/tools.py
@@ -3,8 +3,8 @@ import json
from sqlalchemy import ForeignKey
from core.tools.entities.common_entities import I18nObject
-from core.tools.entities.tool_bundle import ApiBasedToolBundle
-from core.tools.entities.tool_entities import ApiProviderSchemaType
+from core.tools.entities.tool_bundle import ApiToolBundle
+from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration
from extensions.ext_database import db
from models import StringUUID
from models.model import Account, App, Tenant
@@ -107,6 +107,8 @@ class ApiToolProvider(db.Model):
credentials_str = db.Column(db.Text, nullable=False)
# privacy policy
privacy_policy = db.Column(db.String(255), nullable=True)
+ # custom_disclaimer
+ custom_disclaimer = db.Column(db.String(255), nullable=True)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
@@ -116,8 +118,8 @@ class ApiToolProvider(db.Model):
return ApiProviderSchemaType.value_of(self.schema_type_str)
@property
- def tools(self) -> list[ApiBasedToolBundle]:
- return [ApiBasedToolBundle(**tool) for tool in json.loads(self.tools_str)]
+ def tools(self) -> list[ApiToolBundle]:
+ return [ApiToolBundle(**tool) for tool in json.loads(self.tools_str)]
@property
def credentials(self) -> dict:
@@ -130,7 +132,84 @@ class ApiToolProvider(db.Model):
@property
def tenant(self) -> Tenant:
return db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first()
+
+class ToolLabelBinding(db.Model):
+ """
+ The table stores the labels for tools.
+ """
+ __tablename__ = 'tool_label_bindings'
+ __table_args__ = (
+ db.PrimaryKeyConstraint('id', name='tool_label_bind_pkey'),
+ db.UniqueConstraint('tool_id', 'label_name', name='unique_tool_label_bind'),
+ )
+
+ id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
+ # tool id
+ tool_id = db.Column(db.String(64), nullable=False)
+ # tool type
+ tool_type = db.Column(db.String(40), nullable=False)
+ # label name
+ label_name = db.Column(db.String(40), nullable=False)
+
+class WorkflowToolProvider(db.Model):
+ """
+ The table stores the workflow providers.
+ """
+ __tablename__ = 'tool_workflow_providers'
+ __table_args__ = (
+ db.PrimaryKeyConstraint('id', name='tool_workflow_provider_pkey'),
+ db.UniqueConstraint('name', 'tenant_id', name='unique_workflow_tool_provider'),
+ db.UniqueConstraint('tenant_id', 'app_id', name='unique_workflow_tool_provider_app_id'),
+ )
+
+ id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
+ # name of the workflow provider
+ name = db.Column(db.String(40), nullable=False)
+ # label of the workflow provider
+ label = db.Column(db.String(255), nullable=False, server_default='')
+ # icon
+ icon = db.Column(db.String(255), nullable=False)
+ # app id of the workflow provider
+ app_id = db.Column(StringUUID, nullable=False)
+ # version of the workflow provider
+ version = db.Column(db.String(255), nullable=False, server_default='')
+ # who created this tool
+ user_id = db.Column(StringUUID, nullable=False)
+ # tenant id
+ tenant_id = db.Column(StringUUID, nullable=False)
+ # description of the provider
+ description = db.Column(db.Text, nullable=False)
+ # parameter configuration
+ parameter_configuration = db.Column(db.Text, nullable=False, server_default='[]')
+ # privacy policy
+ privacy_policy = db.Column(db.String(255), nullable=True, server_default='')
+
+ created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
+ updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
+
+ @property
+ def user(self) -> Account:
+ return db.session.query(Account).filter(Account.id == self.user_id).first()
+
+ @property
+ def tenant(self) -> Tenant:
+ return db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first()
+
+ @property
+ def parameter_configurations(self) -> list[WorkflowToolParameterConfiguration]:
+ return [
+ WorkflowToolParameterConfiguration(**config)
+ for config in json.loads(self.parameter_configuration)
+ ]
+
+ @property
+ def app(self) -> App:
+ return db.session.query(App).filter(App.id == self.app_id).first()
+
class ToolModelInvoke(db.Model):
"""
store the invoke logs from tool invoke
diff --git a/api/models/workflow.py b/api/models/workflow.py
index f261c67c77..d9bc784878 100644
--- a/api/models/workflow.py
+++ b/api/models/workflow.py
@@ -2,8 +2,8 @@ import json
from enum import Enum
from typing import Optional, Union
-from core.tools.tool_manager import ToolManager
from extensions.ext_database import db
+from libs import helper
from models import StringUUID
from models.account import Account
@@ -156,6 +156,27 @@ class Workflow(db.Model):
return variables
+ @property
+ def unique_hash(self) -> str:
+ """
+ Get hash of workflow.
+
+ :return: hash
+ """
+ entity = {
+ 'graph': self.graph_dict,
+ 'features': self.features_dict
+ }
+
+ return helper.generate_text_hash(json.dumps(entity, sort_keys=True))
+
+ @property
+ def tool_published(self) -> bool:
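+ """
+ Whether a workflow tool provider has been published for this workflow's app
+ """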
+ from models.tools import WorkflowToolProvider
+ return db.session.query(WorkflowToolProvider).filter(
+ WorkflowToolProvider.app_id == self.app_id
+ ).first() is not None
+
class WorkflowRunTriggeredFrom(Enum):
"""
Workflow Run Triggered From Enum
@@ -242,6 +263,7 @@ class WorkflowRun(db.Model):
__table_args__ = (
db.PrimaryKeyConstraint('id', name='workflow_run_pkey'),
db.Index('workflow_run_triggerd_from_idx', 'tenant_id', 'app_id', 'triggered_from'),
+ db.Index('workflow_run_tenant_app_sequence_idx', 'tenant_id', 'app_id', 'sequence_number'),
)
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
@@ -457,6 +479,7 @@ class WorkflowNodeExecution(db.Model):
@property
def extras(self):
+ from core.tools.tool_manager import ToolManager
extras = {}
if self.execution_metadata_dict:
from core.workflow.entities.node_entities import NodeType
diff --git a/api/pyproject.toml b/api/pyproject.toml
index 1e9f53cdb3..a8920139c6 100644
--- a/api/pyproject.toml
+++ b/api/pyproject.toml
@@ -7,12 +7,24 @@ exclude = [
line-length = 120
[tool.ruff.lint]
-ignore-init-module-imports = true
select = [
+ "B", # flake8-bugbear rules
"F", # pyflakes rules
- "I001", # unsorted-imports
- "I002", # missing-required-import
+ "I", # isort rules
"UP", # pyupgrade rules
+ "E101", # mixed-spaces-and-tabs
+ "E111", # indentation-with-invalid-multiple
+ "E112", # no-indented-block
+ "E113", # unexpected-indentation
+ "E115", # no-indented-block-comment
+ "E116", # unexpected-indentation-comment
+ "E117", # over-indented
+ "RUF019", # unnecessary-key-check
+ "RUF100", # unused-noqa
+ "RUF101", # redirected-noqa
+ "S506", # unsafe-yaml-load
+ "W191", # tab-indentation
+ "W605", # invalid-escape-sequence
]
ignore = [
"F403", # undefined-local-with-import-star
@@ -21,6 +33,13 @@ ignore = [
"F841", # unused-variable
"UP007", # non-pep604-annotation
"UP032", # f-string
+ "B005", # strip-with-multi-characters
+ "B006", # mutable-argument-default
+ "B007", # unused-loop-control-variable
+ "B026", # star-arg-unpacking-after-keyword-arg
+ "B901", # return-in-generator
+ "B904", # raise-without-from-inside-except
+ "B905", # zip-without-explicit-strict
]
[tool.ruff.lint.per-file-ignores]
diff --git a/api/requirements.txt b/api/requirements.txt
index 9d79afa4ec..1749b4a2df 100644
--- a/api/requirements.txt
+++ b/api/requirements.txt
@@ -9,8 +9,8 @@ flask-restful~=0.3.10
flask-cors~=4.0.0
gunicorn~=22.0.0
gevent~=23.9.1
-openai~=1.13.3
-tiktoken~=0.6.0
+openai~=1.29.0
+tiktoken~=0.7.0
psycopg2-binary~=2.9.6
pycryptodome==3.19.1
python-dotenv==1.0.0
@@ -26,7 +26,6 @@ sympy==1.12
jieba==0.42.1
celery~=5.3.6
redis[hiredis]~=5.0.3
-openpyxl==3.1.2
chardet~=5.1.0
python-docx~=1.1.0
pypdfium2~=4.17.0
@@ -42,7 +41,6 @@ google-api-python-client==2.90.0
google-auth==2.29.0
google-auth-httplib2==0.2.0
google-generativeai==0.5.0
-google-search-results==2.4.2
googleapis-common-protos==1.63.0
google-cloud-storage==2.16.0
replicate~=0.22.0
@@ -51,7 +49,7 @@ dashscope[tokenizer]~=1.17.0
huggingface_hub~=0.16.4
transformers~=4.35.0
tokenizers~=0.15.0
-pandas==1.5.3
+pandas[performance,excel]~=2.2.2
xinference-client==0.9.4
safetensors~=0.4.3
zhipuai==1.0.7
@@ -66,20 +64,24 @@ bs4~=0.0.1
markdown~=3.5.1
httpx[socks]~=0.24.1
matplotlib~=3.8.2
-yfinance~=0.2.35
+yfinance~=0.2.40
pydub~=0.25.1
gmpy2~=2.1.5
numexpr~=2.9.0
-duckduckgo-search==5.2.2
+duckduckgo-search~=6.1.5
arxiv==2.1.0
yarl~=1.9.4
twilio~=9.0.4
qrcode~=7.4.2
-azure-storage-blob==12.9.0
+azure-storage-blob==12.13.0
azure-identity==1.15.0
lxml==5.1.0
-xlrd~=2.0.1
pydantic~=1.10.0
pgvecto-rs==0.1.4
firecrawl-py==0.0.5
-oss2==2.15.0
+oss2==2.18.5
+pgvector==0.2.5
+pymysql==1.1.1
+tidb-vector==0.0.9
+google-cloud-aiplatform==1.49.0
+vanna[postgres,mysql,clickhouse,duckdb]==0.5.5
diff --git a/api/services/annotation_service.py b/api/services/annotation_service.py
index 8b00b28c4f..addcde44ed 100644
--- a/api/services/annotation_service.py
+++ b/api/services/annotation_service.py
@@ -31,7 +31,7 @@ class AppAnnotationService:
if not app:
raise NotFound("App not found")
- if 'message_id' in args and args['message_id']:
+ if args.get('message_id'):
message_id = str(args['message_id'])
# get message info
message = db.session.query(Message).filter(
diff --git a/api/services/app_generate_service.py b/api/services/app_generate_service.py
index 185d9ba89f..f73a6dcbb6 100644
--- a/api/services/app_generate_service.py
+++ b/api/services/app_generate_service.py
@@ -75,6 +75,35 @@ class AppGenerateService:
else:
raise ValueError(f'Invalid app mode {app_model.mode}')
+ @classmethod
+ def generate_single_iteration(cls, app_model: App,
+ user: Union[Account, EndUser],
+ node_id: str,
+ args: Any,
+ streaming: bool = True):
+ if app_model.mode == AppMode.ADVANCED_CHAT.value:
+ workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
+ return AdvancedChatAppGenerator().single_iteration_generate(
+ app_model=app_model,
+ workflow=workflow,
+ node_id=node_id,
+ user=user,
+ args=args,
+ stream=streaming
+ )
+ elif app_model.mode == AppMode.WORKFLOW.value:
+ workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
+ return WorkflowAppGenerator().single_iteration_generate(
+ app_model=app_model,
+ workflow=workflow,
+ node_id=node_id,
+ user=user,
+ args=args,
+ stream=streaming
+ )
+ else:
+ raise ValueError(f'Invalid app mode {app_model.mode}')
+
@classmethod
def generate_more_like_this(cls, app_model: App, user: Union[Account, EndUser],
message_id: str, invoke_from: InvokeFrom, streaming: bool = True) \
diff --git a/api/services/app_service.py b/api/services/app_service.py
index 11073af09e..23c00740c8 100644
--- a/api/services/app_service.py
+++ b/api/services/app_service.py
@@ -47,10 +47,10 @@ class AppService:
elif args['mode'] == 'channel':
filters.append(App.mode == AppMode.CHANNEL.value)
- if 'name' in args and args['name']:
+ if args.get('name'):
name = args['name'][:30]
filters.append(App.name.ilike(f'%{name}%'))
- if 'tag_ids' in args and args['tag_ids']:
+ if args.get('tag_ids'):
target_ids = TagService.get_target_ids_by_tag_ids('app',
tenant_id,
args['tag_ids'])
@@ -196,6 +196,7 @@ class AppService:
app_model=app,
graph=workflow.get('graph'),
features=workflow.get('features'),
+ unique_hash=None,
account=account
)
workflow_service.publish_workflow(
diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py
index 17399c8ac8..06d3e9ec40 100644
--- a/api/services/dataset_service.py
+++ b/api/services/dataset_service.py
@@ -4,7 +4,7 @@ import logging
import random
import time
import uuid
-from typing import Optional, cast
+from typing import Optional
from flask import current_app
from flask_login import current_user
@@ -13,7 +13,6 @@ from sqlalchemy import func
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
-from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.models.document import Document as RAGDocument
from events.dataset_event import dataset_was_deleted
@@ -43,6 +42,7 @@ from services.vector_service import VectorService
from tasks.clean_notion_document_task import clean_notion_document_task
from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task
from tasks.delete_segment_from_index_task import delete_segment_from_index_task
+from tasks.disable_segment_from_index_task import disable_segment_from_index_task
from tasks.document_indexing_task import document_indexing_task
from tasks.document_indexing_update_task import document_indexing_update_task
from tasks.duplicate_document_indexing_task import duplicate_document_indexing_task
@@ -450,6 +450,27 @@ class DocumentService:
db.session.delete(document)
db.session.commit()
+ @staticmethod
+ def rename_document(dataset_id: str, document_id: str, name: str) -> Document:
+ dataset = DatasetService.get_dataset(dataset_id)
+ if not dataset:
+ raise ValueError('Dataset not found.')
+
+ document = DocumentService.get_document(dataset_id, document_id)
+
+ if not document:
+ raise ValueError('Document not found.')
+
+ if document.tenant_id != current_user.current_tenant_id:
+ raise ValueError('No permission.')
+
+ document.name = name
+
+ db.session.add(document)
+ db.session.commit()
+
+ return document
+
@staticmethod
def pause_document(document):
if document.indexing_status not in ["waiting", "parsing", "cleaning", "splitting", "indexing"]:
@@ -568,7 +589,7 @@ class DocumentService:
documents = []
batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999))
- if 'original_document_id' in document_data and document_data["original_document_id"]:
+ if document_data.get("original_document_id"):
document = DocumentService.update_document_with_dataset_id(dataset, document_data, account)
documents.append(document)
else:
@@ -749,10 +770,10 @@ class DocumentService:
if document.display_status != 'available':
raise ValueError("Document is not available")
# update document name
- if 'name' in document_data and document_data['name']:
+ if document_data.get('name'):
document.name = document_data['name']
# save process rule
- if 'process_rule' in document_data and document_data['process_rule']:
+ if document_data.get('process_rule'):
process_rule = document_data["process_rule"]
if process_rule["mode"] == "custom":
dataset_process_rule = DatasetProcessRule(
@@ -772,7 +793,7 @@ class DocumentService:
db.session.commit()
document.dataset_process_rule_id = dataset_process_rule.id
# update document data source
- if 'data_source' in document_data and document_data['data_source']:
+ if document_data.get('data_source'):
file_name = ''
data_source_info = {}
if document_data["data_source"]["type"] == "upload_file":
@@ -870,7 +891,7 @@ class DocumentService:
embedding_model.model
)
dataset_collection_binding_id = dataset_collection_binding.id
- if 'retrieval_model' in document_data and document_data['retrieval_model']:
+ if document_data.get('retrieval_model'):
retrieval_model = document_data['retrieval_model']
else:
default_retrieval_model = {
@@ -920,9 +941,9 @@ class DocumentService:
and ('process_rule' not in args and not args['process_rule']):
raise ValueError("Data source or Process rule is required")
else:
- if 'data_source' in args and args['data_source']:
+ if args.get('data_source'):
DocumentService.data_source_args_validate(args)
- if 'process_rule' in args and args['process_rule']:
+ if args.get('process_rule'):
DocumentService.process_rule_args_validate(args)
@classmethod
@@ -1122,10 +1143,7 @@ class SegmentService:
model=dataset.embedding_model
)
# calc embedding use tokens
- model_type_instance = cast(TextEmbeddingModel, embedding_model.model_type_instance)
- tokens = model_type_instance.get_num_tokens(
- model=embedding_model.model,
- credentials=embedding_model.credentials,
+ tokens = embedding_model.get_text_embedding_num_tokens(
texts=[content]
)
lock_name = 'add_segment_lock_document_id_{}'.format(document.id)
@@ -1193,10 +1211,7 @@ class SegmentService:
tokens = 0
if dataset.indexing_technique == 'high_quality' and embedding_model:
# calc embedding use tokens
- model_type_instance = cast(TextEmbeddingModel, embedding_model.model_type_instance)
- tokens = model_type_instance.get_num_tokens(
- model=embedding_model.model,
- credentials=embedding_model.credentials,
+ tokens = embedding_model.get_text_embedding_num_tokens(
texts=[content]
)
segment_document = DocumentSegment(
@@ -1241,15 +1256,35 @@ class SegmentService:
cache_result = redis_client.get(indexing_cache_key)
if cache_result is not None:
raise ValueError("Segment is indexing, please try again later")
+ if 'enabled' in args and args['enabled'] is not None:
+ action = args['enabled']
+ if segment.enabled != action:
+ if not action:
+ segment.enabled = action
+ segment.disabled_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+ segment.disabled_by = current_user.id
+ db.session.add(segment)
+ db.session.commit()
+ # Set cache to prevent indexing the same segment multiple times
+ redis_client.setex(indexing_cache_key, 600, 1)
+ disable_segment_from_index_task.delay(segment.id)
+ return segment
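+ # a disabled segment cannot be edited unless this request re-enables it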
+ if not segment.enabled:
+ if 'enabled' in args and args['enabled'] is not None:
+ if not args['enabled']:
+ raise ValueError("Can't update disabled segment")
+ else:
+ raise ValueError("Can't update disabled segment")
try:
content = args['content']
if segment.content == content:
if document.doc_form == 'qa_model':
segment.answer = args['answer']
- if 'keywords' in args and args['keywords']:
+ if args.get('keywords'):
segment.keywords = args['keywords']
- if 'enabled' in args and args['enabled'] is not None:
- segment.enabled = args['enabled']
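+ # an updated segment is always left enabled, so clear the disable markers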
+ segment.enabled = True
+ segment.disabled_at = None
+ segment.disabled_by = None
db.session.add(segment)
db.session.commit()
# update segment index task
@@ -1279,10 +1314,7 @@ class SegmentService:
)
# calc embedding use tokens
- model_type_instance = cast(TextEmbeddingModel, embedding_model.model_type_instance)
- tokens = model_type_instance.get_num_tokens(
- model=embedding_model.model,
- credentials=embedding_model.credentials,
+ tokens = embedding_model.get_text_embedding_num_tokens(
texts=[content]
)
segment.content = content
@@ -1294,12 +1326,16 @@ class SegmentService:
segment.completed_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
segment.updated_by = current_user.id
segment.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+ segment.enabled = True
+ segment.disabled_at = None
+ segment.disabled_by = None
if document.doc_form == 'qa_model':
segment.answer = args['answer']
db.session.add(segment)
db.session.commit()
# update segment vector index
VectorService.update_segment_vector(args['keywords'], segment, dataset)
+
except Exception as e:
logging.exception("update segment index failed")
segment.enabled = False
diff --git a/api/services/enterprise/enterprise_feature_service.py b/api/services/enterprise/enterprise_feature_service.py
deleted file mode 100644
index fe33349aa8..0000000000
--- a/api/services/enterprise/enterprise_feature_service.py
+++ /dev/null
@@ -1,28 +0,0 @@
-from flask import current_app
-from pydantic import BaseModel
-
-from services.enterprise.enterprise_service import EnterpriseService
-
-
-class EnterpriseFeatureModel(BaseModel):
- sso_enforced_for_signin: bool = False
- sso_enforced_for_signin_protocol: str = ''
-
-
-class EnterpriseFeatureService:
-
- @classmethod
- def get_enterprise_features(cls) -> EnterpriseFeatureModel:
- features = EnterpriseFeatureModel()
-
- if current_app.config['ENTERPRISE_ENABLED']:
- cls._fulfill_params_from_enterprise(features)
-
- return features
-
- @classmethod
- def _fulfill_params_from_enterprise(cls, features):
- enterprise_info = EnterpriseService.get_info()
-
- features.sso_enforced_for_signin = enterprise_info['sso_enforced_for_signin']
- features.sso_enforced_for_signin_protocol = enterprise_info['sso_enforced_for_signin_protocol']
diff --git a/api/services/enterprise/enterprise_sso_service.py b/api/services/enterprise/enterprise_sso_service.py
deleted file mode 100644
index d8e19f23bf..0000000000
--- a/api/services/enterprise/enterprise_sso_service.py
+++ /dev/null
@@ -1,60 +0,0 @@
-import logging
-
-from models.account import Account, AccountStatus
-from services.account_service import AccountService, TenantService
-from services.enterprise.base import EnterpriseRequest
-
-logger = logging.getLogger(__name__)
-
-
-class EnterpriseSSOService:
-
- @classmethod
- def get_sso_saml_login(cls) -> str:
- return EnterpriseRequest.send_request('GET', '/sso/saml/login')
-
- @classmethod
- def post_sso_saml_acs(cls, saml_response: str) -> str:
- response = EnterpriseRequest.send_request('POST', '/sso/saml/acs', json={'SAMLResponse': saml_response})
- if 'email' not in response or response['email'] is None:
- logger.exception(response)
- raise Exception('Saml response is invalid')
-
- return cls.login_with_email(response.get('email'))
-
- @classmethod
- def get_sso_oidc_login(cls):
- return EnterpriseRequest.send_request('GET', '/sso/oidc/login')
-
- @classmethod
- def get_sso_oidc_callback(cls, args: dict):
- state_from_query = args['state']
- code_from_query = args['code']
- state_from_cookies = args['oidc-state']
-
- if state_from_cookies != state_from_query:
- raise Exception('invalid state or code')
-
- response = EnterpriseRequest.send_request('GET', '/sso/oidc/callback', params={'code': code_from_query})
- if 'email' not in response or response['email'] is None:
- logger.exception(response)
- raise Exception('OIDC response is invalid')
-
- return cls.login_with_email(response.get('email'))
-
- @classmethod
- def login_with_email(cls, email: str) -> str:
- account = Account.query.filter_by(email=email).first()
- if account is None:
- raise Exception('account not found, please contact system admin to invite you to join in a workspace')
-
- if account.status == AccountStatus.BANNED:
- raise Exception('account is banned, please contact system admin')
-
- tenants = TenantService.get_join_tenants(account)
- if len(tenants) == 0:
- raise Exception("workspace not found, please contact system admin to invite you to join in a workspace")
-
- token = AccountService.get_account_jwt_token(account)
-
- return token
diff --git a/api/services/entities/model_provider_entities.py b/api/services/entities/model_provider_entities.py
index 6cdd5090ae..77bb5e08c3 100644
--- a/api/services/entities/model_provider_entities.py
+++ b/api/services/entities/model_provider_entities.py
@@ -4,10 +4,10 @@ from typing import Optional
from flask import current_app
from pydantic import BaseModel
-from core.entities.model_entities import ModelStatus, ModelWithProviderEntity
+from core.entities.model_entities import ModelWithProviderEntity, ProviderModelWithStatusEntity
from core.entities.provider_entities import QuotaConfiguration
from core.model_runtime.entities.common_entities import I18nObject
-from core.model_runtime.entities.model_entities import ModelType, ProviderModel
+from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.provider_entities import (
ConfigurateMethod,
ModelCredentialSchema,
@@ -79,13 +79,6 @@ class ProviderResponse(BaseModel):
)
-class ModelResponse(ProviderModel):
- """
- Model class for model response.
- """
- status: ModelStatus
-
-
class ProviderWithModelsResponse(BaseModel):
"""
Model class for provider with models response.
@@ -95,7 +88,7 @@ class ProviderWithModelsResponse(BaseModel):
icon_small: Optional[I18nObject] = None
icon_large: Optional[I18nObject] = None
status: CustomConfigurationStatus
- models: list[ModelResponse]
+ models: list[ProviderModelWithStatusEntity]
def __init__(self, **data) -> None:
super().__init__(**data)
diff --git a/api/services/errors/app.py b/api/services/errors/app.py
index 7c4ca99c2a..87e9e9247d 100644
--- a/api/services/errors/app.py
+++ b/api/services/errors/app.py
@@ -1,2 +1,6 @@
class MoreLikeThisDisabledError(Exception):
pass
+
+
+class WorkflowHashNotEqualError(Exception):
+ pass
diff --git a/api/services/feature_service.py b/api/services/feature_service.py
index 3cf51d11a0..36cbc3902b 100644
--- a/api/services/feature_service.py
+++ b/api/services/feature_service.py
@@ -2,6 +2,7 @@ from flask import current_app
from pydantic import BaseModel
from services.billing_service import BillingService
+from services.enterprise.enterprise_service import EnterpriseService
class SubscriptionModel(BaseModel):
@@ -28,6 +29,14 @@ class FeatureModel(BaseModel):
documents_upload_quota: LimitationModel = LimitationModel(size=0, limit=50)
docs_processing: str = 'standard'
can_replace_logo: bool = False
+ model_load_balancing_enabled: bool = False
+
+
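+# system-level features (enforced SSO) resolved from the enterprise service rather than billing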
+class SystemFeatureModel(BaseModel):
+ sso_enforced_for_signin: bool = False
+ sso_enforced_for_signin_protocol: str = ''
+ sso_enforced_for_web: bool = False
+ sso_enforced_for_web_protocol: str = ''
class FeatureService:
@@ -43,9 +52,19 @@ class FeatureService:
return features
+ @classmethod
+ def get_system_features(cls) -> SystemFeatureModel:
+ system_features = SystemFeatureModel()
+
+ if current_app.config['ENTERPRISE_ENABLED']:
+ cls._fulfill_params_from_enterprise(system_features)
+
+ return system_features
+
@classmethod
def _fulfill_params_from_env(cls, features: FeatureModel):
features.can_replace_logo = current_app.config['CAN_REPLACE_LOGO']
+ features.model_load_balancing_enabled = current_app.config['MODEL_LB_ENABLED']
@classmethod
def _fulfill_params_from_billing_api(cls, features: FeatureModel, tenant_id: str):
@@ -55,21 +74,40 @@ class FeatureService:
features.billing.subscription.plan = billing_info['subscription']['plan']
features.billing.subscription.interval = billing_info['subscription']['interval']
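+ # every quota field below is optional in the billing payload; defaults are kept when a key is absent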
- features.members.size = billing_info['members']['size']
- features.members.limit = billing_info['members']['limit']
+ if 'members' in billing_info:
+ features.members.size = billing_info['members']['size']
+ features.members.limit = billing_info['members']['limit']
- features.apps.size = billing_info['apps']['size']
- features.apps.limit = billing_info['apps']['limit']
+ if 'apps' in billing_info:
+ features.apps.size = billing_info['apps']['size']
+ features.apps.limit = billing_info['apps']['limit']
- features.vector_space.size = billing_info['vector_space']['size']
- features.vector_space.limit = billing_info['vector_space']['limit']
+ if 'vector_space' in billing_info:
+ features.vector_space.size = billing_info['vector_space']['size']
+ features.vector_space.limit = billing_info['vector_space']['limit']
- features.documents_upload_quota.size = billing_info['documents_upload_quota']['size']
- features.documents_upload_quota.limit = billing_info['documents_upload_quota']['limit']
+ if 'documents_upload_quota' in billing_info:
+ features.documents_upload_quota.size = billing_info['documents_upload_quota']['size']
+ features.documents_upload_quota.limit = billing_info['documents_upload_quota']['limit']
- features.annotation_quota_limit.size = billing_info['annotation_quota_limit']['size']
- features.annotation_quota_limit.limit = billing_info['annotation_quota_limit']['limit']
+ if 'annotation_quota_limit' in billing_info:
+ features.annotation_quota_limit.size = billing_info['annotation_quota_limit']['size']
+ features.annotation_quota_limit.limit = billing_info['annotation_quota_limit']['limit']
- features.docs_processing = billing_info['docs_processing']
- features.can_replace_logo = billing_info['can_replace_logo']
+ if 'docs_processing' in billing_info:
+ features.docs_processing = billing_info['docs_processing']
+ if 'can_replace_logo' in billing_info:
+ features.can_replace_logo = billing_info['can_replace_logo']
+
+ if 'model_load_balancing_enabled' in billing_info:
+ features.model_load_balancing_enabled = billing_info['model_load_balancing_enabled']
+
+ @classmethod
+ def _fulfill_params_from_enterprise(cls, features):
+ enterprise_info = EnterpriseService.get_info()
+
+ features.sso_enforced_for_signin = enterprise_info['sso_enforced_for_signin']
+ features.sso_enforced_for_signin_protocol = enterprise_info['sso_enforced_for_signin_protocol']
+ features.sso_enforced_for_web = enterprise_info['sso_enforced_for_web']
+ features.sso_enforced_for_web_protocol = enterprise_info['sso_enforced_for_web_protocol']
diff --git a/api/services/model_load_balancing_service.py b/api/services/model_load_balancing_service.py
new file mode 100644
index 0000000000..c684c2862b
--- /dev/null
+++ b/api/services/model_load_balancing_service.py
@@ -0,0 +1,565 @@
+import datetime
+import json
+import logging
+from json import JSONDecodeError
+from typing import Optional
+
+from core.entities.provider_configuration import ProviderConfiguration
+from core.helper import encrypter
+from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
+from core.model_manager import LBModelManager
+from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.entities.provider_entities import (
+ ModelCredentialSchema,
+ ProviderCredentialSchema,
+)
+from core.model_runtime.model_providers import model_provider_factory
+from core.provider_manager import ProviderManager
+from extensions.ext_database import db
+from models.provider import LoadBalancingModelConfig
+
+logger = logging.getLogger(__name__)
+
+
+class ModelLoadBalancingService:
+
+ def __init__(self) -> None:
+ self.provider_manager = ProviderManager()
+
+ def enable_model_load_balancing(self, tenant_id: str, provider: str, model: str, model_type: str) -> None:
+ """
+ Enable model load balancing.
+
+ :param tenant_id: workspace id
+ :param provider: provider name
+ :param model: model name
+ :param model_type: model type
+ :return:
+ """
+ # Get all provider configurations of the current workspace
+ provider_configurations = self.provider_manager.get_configurations(tenant_id)
+
+ # Get provider configuration
+ provider_configuration = provider_configurations.get(provider)
+ if not provider_configuration:
+ raise ValueError(f"Provider {provider} does not exist.")
+
+ # Enable model load balancing
+ provider_configuration.enable_model_load_balancing(
+ model=model,
+ model_type=ModelType.value_of(model_type)
+ )
+
+ def disable_model_load_balancing(self, tenant_id: str, provider: str, model: str, model_type: str) -> None:
+ """
+ Disable model load balancing.
+
+ :param tenant_id: workspace id
+ :param provider: provider name
+ :param model: model name
+ :param model_type: model type
+ :return:
+ """
+ # Get all provider configurations of the current workspace
+ provider_configurations = self.provider_manager.get_configurations(tenant_id)
+
+ # Get provider configuration
+ provider_configuration = provider_configurations.get(provider)
+ if not provider_configuration:
+ raise ValueError(f"Provider {provider} does not exist.")
+
+ # disable model load balancing
+ provider_configuration.disable_model_load_balancing(
+ model=model,
+ model_type=ModelType.value_of(model_type)
+ )
+
+ def get_load_balancing_configs(self, tenant_id: str, provider: str, model: str, model_type: str) \
+ -> tuple[bool, list[dict]]:
+ """
+ Get load balancing configurations.
+ :param tenant_id: workspace id
+ :param provider: provider name
+ :param model: model name
+ :param model_type: model type
+ :return:
+ """
+ # Get all provider configurations of the current workspace
+ provider_configurations = self.provider_manager.get_configurations(tenant_id)
+
+ # Get provider configuration
+ provider_configuration = provider_configurations.get(provider)
+ if not provider_configuration:
+ raise ValueError(f"Provider {provider} does not exist.")
+
+ # Convert model type to ModelType
+ model_type = ModelType.value_of(model_type)
+
+ # Get provider model setting
+ provider_model_setting = provider_configuration.get_provider_model_setting(
+ model_type=model_type,
+ model=model,
+ )
+
+ is_load_balancing_enabled = False
+ if provider_model_setting and provider_model_setting.load_balancing_enabled:
+ is_load_balancing_enabled = True
+
+ # Get load balancing configurations
+ load_balancing_configs = db.session.query(LoadBalancingModelConfig) \
+ .filter(
+ LoadBalancingModelConfig.tenant_id == tenant_id,
+ LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
+ LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
+ LoadBalancingModelConfig.model_name == model
+ ).order_by(LoadBalancingModelConfig.created_at).all()
+
+ if provider_configuration.custom_configuration.provider:
+ # check if the inherit configuration exists;
+ # "inherit" means the provider's or model's own custom credentials are used
+ inherit_config_exists = False
+ for load_balancing_config in load_balancing_configs:
+ if load_balancing_config.name == '__inherit__':
+ inherit_config_exists = True
+ break
+
+ if not inherit_config_exists:
+ # Initialize the inherit configuration
+ inherit_config = self._init_inherit_config(tenant_id, provider, model, model_type)
+
+ # prepend the inherit configuration
+ load_balancing_configs.insert(0, inherit_config)
+ else:
+ # move the inherit configuration to the front
+ for i, load_balancing_config in enumerate(load_balancing_configs):
+ if load_balancing_config.name == '__inherit__':
+ inherit_config = load_balancing_configs.pop(i)
+ load_balancing_configs.insert(0, inherit_config)
+
+ # Get credential form schemas from model credential schema or provider credential schema
+ credential_schemas = self._get_credential_schema(provider_configuration)
+
+ # Get decoding rsa key and cipher for decrypting credentials
+ decoding_rsa_key, decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)
+
+ # fetch status and ttl for each config
+ datas = []
+ for load_balancing_config in load_balancing_configs:
+ in_cooldown, ttl = LBModelManager.get_config_in_cooldown_and_ttl(
+ tenant_id=tenant_id,
+ provider=provider,
+ model=model,
+ model_type=model_type,
+ config_id=load_balancing_config.id
+ )
+
+ try:
+ if load_balancing_config.encrypted_config:
+ credentials = json.loads(load_balancing_config.encrypted_config)
+ else:
+ credentials = {}
+ except JSONDecodeError:
+ credentials = {}
+
+ # Get provider credential secret variables
+ credential_secret_variables = provider_configuration.extract_secret_variables(
+ credential_schemas.credential_form_schemas
+ )
+
+ # decrypt credentials
+ for variable in credential_secret_variables:
+ if variable in credentials:
+ try:
+ credentials[variable] = encrypter.decrypt_token_with_decoding(
+ credentials.get(variable),
+ decoding_rsa_key,
+ decoding_cipher_rsa
+ )
+ except ValueError:
+ pass
+
+ # Obfuscate credentials
+ credentials = provider_configuration.obfuscated_credentials(
+ credentials=credentials,
+ credential_form_schemas=credential_schemas.credential_form_schemas
+ )
+
+ datas.append({
+ 'id': load_balancing_config.id,
+ 'name': load_balancing_config.name,
+ 'credentials': credentials,
+ 'enabled': load_balancing_config.enabled,
+ 'in_cooldown': in_cooldown,
+ 'ttl': ttl
+ })
+
+ return is_load_balancing_enabled, datas
+
+ def get_load_balancing_config(self, tenant_id: str, provider: str, model: str, model_type: str, config_id: str) \
+ -> Optional[dict]:
+ """
+ Get load balancing configuration.
+ :param tenant_id: workspace id
+ :param provider: provider name
+ :param model: model name
+ :param model_type: model type
+ :param config_id: load balancing config id
+ :return:
+ """
+ # Get all provider configurations of the current workspace
+ provider_configurations = self.provider_manager.get_configurations(tenant_id)
+
+ # Get provider configuration
+ provider_configuration = provider_configurations.get(provider)
+ if not provider_configuration:
+ raise ValueError(f"Provider {provider} does not exist.")
+
+ # Convert model type to ModelType
+ model_type = ModelType.value_of(model_type)
+
+ # Get load balancing configurations
+ load_balancing_model_config = db.session.query(LoadBalancingModelConfig) \
+ .filter(
+ LoadBalancingModelConfig.tenant_id == tenant_id,
+ LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
+ LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
+ LoadBalancingModelConfig.model_name == model,
+ LoadBalancingModelConfig.id == config_id
+ ).first()
+
+ if not load_balancing_model_config:
+ return None
+
+ try:
+ if load_balancing_model_config.encrypted_config:
+ credentials = json.loads(load_balancing_model_config.encrypted_config)
+ else:
+ credentials = {}
+ except JSONDecodeError:
+ credentials = {}
+
+ # Get credential form schemas from model credential schema or provider credential schema
+ credential_schemas = self._get_credential_schema(provider_configuration)
+
+ # Obfuscate credentials
+ credentials = provider_configuration.obfuscated_credentials(
+ credentials=credentials,
+ credential_form_schemas=credential_schemas.credential_form_schemas
+ )
+
+ return {
+ 'id': load_balancing_model_config.id,
+ 'name': load_balancing_model_config.name,
+ 'credentials': credentials,
+ 'enabled': load_balancing_model_config.enabled
+ }
+
+ def _init_inherit_config(self, tenant_id: str, provider: str, model: str, model_type: ModelType) \
+ -> LoadBalancingModelConfig:
+ """
+ Initialize the inherit configuration.
+ :param tenant_id: workspace id
+ :param provider: provider name
+ :param model: model name
+ :param model_type: model type
+ :return:
+ """
+ # Initialize the inherit configuration
+ inherit_config = LoadBalancingModelConfig(
+ tenant_id=tenant_id,
+ provider_name=provider,
+ model_type=model_type.to_origin_model_type(),
+ model_name=model,
+ name='__inherit__'
+ )
+ db.session.add(inherit_config)
+ db.session.commit()
+
+ return inherit_config
+
+ def update_load_balancing_configs(self, tenant_id: str,
+ provider: str,
+ model: str,
+ model_type: str,
+ configs: list[dict]) -> None:
+ """
+ Update load balancing configurations.
+ :param tenant_id: workspace id
+ :param provider: provider name
+ :param model: model name
+ :param model_type: model type
+ :param configs: load balancing configs
+ :return:
+ """
+ # Get all provider configurations of the current workspace
+ provider_configurations = self.provider_manager.get_configurations(tenant_id)
+
+ # Get provider configuration
+ provider_configuration = provider_configurations.get(provider)
+ if not provider_configuration:
+ raise ValueError(f"Provider {provider} does not exist.")
+
+ # Convert model type to ModelType
+ model_type = ModelType.value_of(model_type)
+
+ if not isinstance(configs, list):
+ raise ValueError('Invalid load balancing configs')
+
+ current_load_balancing_configs = db.session.query(LoadBalancingModelConfig) \
+ .filter(
+ LoadBalancingModelConfig.tenant_id == tenant_id,
+ LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
+ LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
+ LoadBalancingModelConfig.model_name == model
+ ).all()
+
+ # id as key, config as value
+ current_load_balancing_configs_dict = {config.id: config for config in current_load_balancing_configs}
+ updated_config_ids = set()
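+ # configs absent from the incoming list are treated as deleted afterwards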
+
+ for config in configs:
+ if not isinstance(config, dict):
+ raise ValueError('Invalid load balancing config')
+
+ config_id = config.get('id')
+ name = config.get('name')
+ credentials = config.get('credentials')
+ enabled = config.get('enabled')
+
+ if not name:
+ raise ValueError('Invalid load balancing config name')
+
+ if enabled is None:
+ raise ValueError('Invalid load balancing config enabled')
+
+ # check whether this config already exists
+ if config_id:
+ config_id = str(config_id)
+
+ if config_id not in current_load_balancing_configs_dict:
+ raise ValueError('Invalid load balancing config id: {}'.format(config_id))
+
+ updated_config_ids.add(config_id)
+
+ load_balancing_config = current_load_balancing_configs_dict[config_id]
+
+ # check duplicate name
+ for current_load_balancing_config in current_load_balancing_configs:
+ if current_load_balancing_config.id != config_id and current_load_balancing_config.name == name:
+ raise ValueError('Load balancing config name {} already exists'.format(name))
+
+ if credentials:
+ if not isinstance(credentials, dict):
+ raise ValueError('Invalid load balancing config credentials')
+
+ # validate custom provider config
+ credentials = self._custom_credentials_validate(
+ tenant_id=tenant_id,
+ provider_configuration=provider_configuration,
+ model_type=model_type,
+ model=model,
+ credentials=credentials,
+ load_balancing_model_config=load_balancing_config,
+ validate=False
+ )
+
+ # update load balancing config
+ load_balancing_config.encrypted_config = json.dumps(credentials)
+
+ load_balancing_config.name = name
+ load_balancing_config.enabled = enabled
+ load_balancing_config.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+ db.session.commit()
+
+ self._clear_credentials_cache(tenant_id, config_id)
+ else:
+ # create load balancing config
+ if name == '__inherit__':
+ raise ValueError('Invalid load balancing config name')
+
+ # check duplicate name
+ for current_load_balancing_config in current_load_balancing_configs:
+ if current_load_balancing_config.name == name:
+ raise ValueError('Load balancing config name {} already exists'.format(name))
+
+ if not credentials:
+ raise ValueError('Invalid load balancing config credentials')
+
+ if not isinstance(credentials, dict):
+ raise ValueError('Invalid load balancing config credentials')
+
+ # validate custom provider config
+ credentials = self._custom_credentials_validate(
+ tenant_id=tenant_id,
+ provider_configuration=provider_configuration,
+ model_type=model_type,
+ model=model,
+ credentials=credentials,
+ validate=False
+ )
+
+ # create load balancing config
+ load_balancing_model_config = LoadBalancingModelConfig(
+ tenant_id=tenant_id,
+ provider_name=provider_configuration.provider.provider,
+ model_type=model_type.to_origin_model_type(),
+ model_name=model,
+ name=name,
+ encrypted_config=json.dumps(credentials)
+ )
+
+ db.session.add(load_balancing_model_config)
+ db.session.commit()
+
+ # get deleted config ids
+ deleted_config_ids = set(current_load_balancing_configs_dict.keys()) - updated_config_ids
+ for config_id in deleted_config_ids:
+ db.session.delete(current_load_balancing_configs_dict[config_id])
+ db.session.commit()
+
+ self._clear_credentials_cache(tenant_id, config_id)
+
+ def validate_load_balancing_credentials(self, tenant_id: str,
+ provider: str,
+ model: str,
+ model_type: str,
+ credentials: dict,
+ config_id: Optional[str] = None) -> None:
+ """
+ Validate load balancing credentials.
+ :param tenant_id: workspace id
+ :param provider: provider name
+ :param model_type: model type
+ :param model: model name
+ :param credentials: credentials
+ :param config_id: load balancing config id
+ :return:
+ """
+ # Get all provider configurations of the current workspace
+ provider_configurations = self.provider_manager.get_configurations(tenant_id)
+
+ # Get provider configuration
+ provider_configuration = provider_configurations.get(provider)
+ if not provider_configuration:
+ raise ValueError(f"Provider {provider} does not exist.")
+
+ # Convert model type to ModelType
+ model_type = ModelType.value_of(model_type)
+
+ load_balancing_model_config = None
+ if config_id:
+ # Get load balancing config
+ load_balancing_model_config = db.session.query(LoadBalancingModelConfig) \
+ .filter(
+ LoadBalancingModelConfig.tenant_id == tenant_id,
+ LoadBalancingModelConfig.provider_name == provider,
+ LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
+ LoadBalancingModelConfig.model_name == model,
+ LoadBalancingModelConfig.id == config_id
+ ).first()
+
+ if not load_balancing_model_config:
+ raise ValueError(f"Load balancing config {config_id} does not exist.")
+
+ # Validate custom provider config
+ self._custom_credentials_validate(
+ tenant_id=tenant_id,
+ provider_configuration=provider_configuration,
+ model_type=model_type,
+ model=model,
+ credentials=credentials,
+ load_balancing_model_config=load_balancing_model_config
+ )
+
+ def _custom_credentials_validate(self, tenant_id: str,
+ provider_configuration: ProviderConfiguration,
+ model_type: ModelType,
+ model: str,
+ credentials: dict,
+ load_balancing_model_config: Optional[LoadBalancingModelConfig] = None,
+ validate: bool = True) -> dict:
+ """
+ Validate custom credentials.
+ :param tenant_id: workspace id
+ :param provider_configuration: provider configuration
+ :param model_type: model type
+ :param model: model name
+ :param credentials: credentials
+ :param load_balancing_model_config: load balancing model config
+ :param validate: validate credentials
+ :return:
+ """
+ # Get credential form schemas from model credential schema or provider credential schema
+ credential_schemas = self._get_credential_schema(provider_configuration)
+
+ # Get provider credential secret variables
+ provider_credential_secret_variables = provider_configuration.extract_secret_variables(
+ credential_schemas.credential_form_schemas
+ )
+
+ if load_balancing_model_config:
+ try:
+ # load the original credentials so hidden values can be restored
+ if load_balancing_model_config.encrypted_config:
+ original_credentials = json.loads(load_balancing_model_config.encrypted_config)
+ else:
+ original_credentials = {}
+ except JSONDecodeError:
+ original_credentials = {}
+
+ # encrypt credentials
+ for key, value in credentials.items():
+ if key in provider_credential_secret_variables:
+ # if [__HIDDEN__] is sent for a secret input, keep the original value
+ if value == '[__HIDDEN__]' and key in original_credentials:
+ credentials[key] = encrypter.decrypt_token(tenant_id, original_credentials[key])
+
+ if validate:
+ if isinstance(credential_schemas, ModelCredentialSchema):
+ credentials = model_provider_factory.model_credentials_validate(
+ provider=provider_configuration.provider.provider,
+ model_type=model_type,
+ model=model,
+ credentials=credentials
+ )
+ else:
+ credentials = model_provider_factory.provider_credentials_validate(
+ provider=provider_configuration.provider.provider,
+ credentials=credentials
+ )
+
+ for key, value in credentials.items():
+ if key in provider_credential_secret_variables:
+ credentials[key] = encrypter.encrypt_token(tenant_id, value)
+
+ return credentials
+
+ def _get_credential_schema(self, provider_configuration: ProviderConfiguration) \
+ -> ModelCredentialSchema | ProviderCredentialSchema:
+ """
+ Get the credential form schema.
+ :param provider_configuration: provider configuration
+ :return:
+ """
+ # Get credential form schemas from model credential schema or provider credential schema
+ if provider_configuration.provider.model_credential_schema:
+ credential_schema = provider_configuration.provider.model_credential_schema
+ else:
+ credential_schema = provider_configuration.provider.provider_credential_schema
+
+ return credential_schema
+
+ def _clear_credentials_cache(self, tenant_id: str, config_id: str) -> None:
+ """
+ Clear credentials cache.
+ :param tenant_id: workspace id
+ :param config_id: load balancing config id
+ :return:
+ """
+ provider_model_credentials_cache = ProviderCredentialsCache(
+ tenant_id=tenant_id,
+ identity_id=config_id,
+ cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL
+ )
+
+ provider_model_credentials_cache.delete()
diff --git a/api/services/model_provider_service.py b/api/services/model_provider_service.py
index 5a4342ae03..385af685f9 100644
--- a/api/services/model_provider_service.py
+++ b/api/services/model_provider_service.py
@@ -6,7 +6,7 @@ from typing import Optional, cast
import requests
from flask import current_app
-from core.entities.model_entities import ModelStatus
+from core.entities.model_entities import ModelStatus, ProviderModelWithStatusEntity
from core.model_runtime.entities.model_entities import ModelType, ParameterRule
from core.model_runtime.model_providers import model_provider_factory
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
@@ -16,7 +16,6 @@ from services.entities.model_provider_entities import (
CustomConfigurationResponse,
CustomConfigurationStatus,
DefaultModelResponse,
- ModelResponse,
ModelWithProviderEntityResponse,
ProviderResponse,
ProviderWithModelsResponse,
@@ -303,6 +302,9 @@ class ModelProviderService:
if model.deprecated:
continue
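+ # also skip inactive models, so providers are listed with usable models only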
+ if model.status != ModelStatus.ACTIVE:
+ continue
+
provider_models[model.provider.provider].append(model)
# convert to ProviderWithModelsResponse list
@@ -313,24 +315,22 @@ class ModelProviderService:
first_model = models[0]
- has_active_models = any([model.status == ModelStatus.ACTIVE for model in models])
-
providers_with_models.append(
ProviderWithModelsResponse(
provider=provider,
label=first_model.provider.label,
icon_small=first_model.provider.icon_small,
icon_large=first_model.provider.icon_large,
- status=CustomConfigurationStatus.ACTIVE
- if has_active_models else CustomConfigurationStatus.NO_CONFIGURE,
- models=[ModelResponse(
+ status=CustomConfigurationStatus.ACTIVE,
+ models=[ProviderModelWithStatusEntity(
model=model.model,
label=model.label,
model_type=model.model_type,
features=model.features,
fetch_from=model.fetch_from,
model_properties=model.model_properties,
- status=model.status
+ status=model.status,
+ load_balancing_enabled=model.load_balancing_enabled
) for model in models]
)
)
@@ -486,6 +486,54 @@ class ModelProviderService:
# Switch preferred provider type
provider_configuration.switch_preferred_provider_type(preferred_provider_type_enum)
+ def enable_model(self, tenant_id: str, provider: str, model: str, model_type: str) -> None:
+ """
+ Enable model.
+
+ :param tenant_id: workspace id
+ :param provider: provider name
+ :param model: model name
+ :param model_type: model type
+ :return:
+ """
+ # Get all provider configurations of the current workspace
+ provider_configurations = self.provider_manager.get_configurations(tenant_id)
+
+ # Get provider configuration
+ provider_configuration = provider_configurations.get(provider)
+ if not provider_configuration:
+ raise ValueError(f"Provider {provider} does not exist.")
+
+ # Enable model
+ provider_configuration.enable_model(
+ model=model,
+ model_type=ModelType.value_of(model_type)
+ )
+
+ def disable_model(self, tenant_id: str, provider: str, model: str, model_type: str) -> None:
+ """
+ Disable model.
+
+ :param tenant_id: workspace id
+ :param provider: provider name
+ :param model: model name
+ :param model_type: model type
+ :return:
+ """
+ # Get all provider configurations of the current workspace
+ provider_configurations = self.provider_manager.get_configurations(tenant_id)
+
+ # Get provider configuration
+ provider_configuration = provider_configurations.get(provider)
+ if not provider_configuration:
+ raise ValueError(f"Provider {provider} does not exist.")
+
+ # Disable model
+ provider_configuration.disable_model(
+ model=model,
+ model_type=ModelType.value_of(model_type)
+ )
+
def free_quota_submit(self, tenant_id: str, provider: str):
api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY")
api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL")
diff --git a/api/services/recommended_app_service.py b/api/services/recommended_app_service.py
index 3d36fb80af..6a155922b4 100644
--- a/api/services/recommended_app_service.py
+++ b/api/services/recommended_app_service.py
@@ -86,6 +86,7 @@ class RecommendedAppService:
'description': site.description,
'copyright': site.copyright,
'privacy_policy': site.privacy_policy,
+ 'custom_disclaimer': site.custom_disclaimer,
'category': recommended_app.category,
'position': recommended_app.position,
'is_listed': recommended_app.is_listed
@@ -94,7 +95,7 @@ class RecommendedAppService:
categories.add(recommended_app.category) # add category to categories
- return {'recommended_apps': recommended_apps_result, 'categories': list(categories)}
+ return {'recommended_apps': recommended_apps_result, 'categories': sorted(list(categories))}
@classmethod
def _fetch_recommended_apps_from_dify_official(cls, language: str) -> dict:
diff --git a/api/services/tools_manage_service.py b/api/services/tools/api_tools_manage_service.py
similarity index 57%
rename from api/services/tools_manage_service.py
rename to api/services/tools/api_tools_manage_service.py
index ec4e89bd14..9a0d6ca8d9 100644
--- a/api/services/tools_manage_service.py
+++ b/api/services/tools/api_tools_manage_service.py
@@ -4,97 +4,30 @@ import logging
from httpx import get
from core.model_runtime.utils.encoders import jsonable_encoder
+from core.tools.entities.api_entities import UserTool, UserToolProvider
from core.tools.entities.common_entities import I18nObject
-from core.tools.entities.tool_bundle import ApiBasedToolBundle
+from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import (
ApiProviderAuthType,
ApiProviderSchemaType,
ToolCredentialsOption,
ToolProviderCredentials,
)
-from core.tools.entities.user_entities import UserTool, UserToolProvider
-from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError
-from core.tools.provider.api_tool_provider import ApiBasedToolProviderController
-from core.tools.provider.builtin._positions import BuiltinToolProviderSort
-from core.tools.provider.tool_provider import ToolProviderController
+from core.tools.provider.api_tool_provider import ApiToolProviderController
+from core.tools.tool_label_manager import ToolLabelManager
from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ToolConfigurationManager
from core.tools.utils.parser import ApiBasedToolSchemaParser
from extensions.ext_database import db
-from models.tools import ApiToolProvider, BuiltinToolProvider
-from services.model_provider_service import ModelProviderService
-from services.tools_transform_service import ToolTransformService
+from models.tools import ApiToolProvider
+from services.tools.tools_transform_service import ToolTransformService
logger = logging.getLogger(__name__)
-class ToolManageService:
+class ApiToolManageService:
@staticmethod
- def list_tool_providers(user_id: str, tenant_id: str):
- """
- list tool providers
-
- :return: the list of tool providers
- """
- providers = ToolManager.user_list_providers(
- user_id, tenant_id
- )
-
- # add icon
- for provider in providers:
- ToolTransformService.repack_provider(provider)
-
- result = [provider.to_dict() for provider in providers]
-
- return result
-
- @staticmethod
- def list_builtin_tool_provider_tools(
- user_id: str, tenant_id: str, provider: str
- ) -> list[UserTool]:
- """
- list builtin tool provider tools
- """
- provider_controller: ToolProviderController = ToolManager.get_builtin_provider(provider)
- tools = provider_controller.get_tools()
-
- tool_provider_configurations = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
- # check if user has added the provider
- builtin_provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
- BuiltinToolProvider.tenant_id == tenant_id,
- BuiltinToolProvider.provider == provider,
- ).first()
-
- credentials = {}
- if builtin_provider is not None:
- # get credentials
- credentials = builtin_provider.credentials
- credentials = tool_provider_configurations.decrypt_tool_credentials(credentials)
-
- result = []
- for tool in tools:
- result.append(ToolTransformService.tool_to_user_tool(
- tool=tool, credentials=credentials, tenant_id=tenant_id
- ))
-
- return result
-
- @staticmethod
- def list_builtin_provider_credentials_schema(
- provider_name
- ):
- """
- list builtin provider credentials schema
-
- :return: the list of tool providers
- """
- provider = ToolManager.get_builtin_provider(provider_name)
- return jsonable_encoder([
- v for _, v in (provider.credentials_schema or {}).items()
- ])
-
- @staticmethod
- def parser_api_schema(schema: str) -> list[ApiBasedToolBundle]:
+ def parser_api_schema(schema: str) -> list[ApiToolBundle]:
"""
parse api schema to tool bundle
"""
@@ -162,7 +95,7 @@ class ToolManageService:
raise ValueError(f'invalid schema: {str(e)}')
@staticmethod
- def convert_schema_to_tool_bundles(schema: str, extra_info: dict = None) -> list[ApiBasedToolBundle]:
+ def convert_schema_to_tool_bundles(schema: str, extra_info: dict = None) -> list[ApiToolBundle]:
"""
convert schema to tool bundles
@@ -177,7 +110,7 @@ class ToolManageService:
@staticmethod
def create_api_tool_provider(
user_id: str, tenant_id: str, provider_name: str, icon: dict, credentials: dict,
- schema_type: str, schema: str, privacy_policy: str
+ schema_type: str, schema: str, privacy_policy: str, custom_disclaimer: str, labels: list[str]
):
"""
create api tool provider
@@ -197,7 +130,7 @@ class ToolManageService:
# parse openapi to tool bundle
extra_info = {}
# extra info like description will be set here
- tool_bundles, schema_type = ToolManageService.convert_schema_to_tool_bundles(schema, extra_info)
+ tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info)
if len(tool_bundles) > 100:
raise ValueError('the number of apis should be less than 100')
@@ -213,7 +146,8 @@ class ToolManageService:
schema_type_str=schema_type,
tools_str=json.dumps(jsonable_encoder(tool_bundles)),
credentials_str={},
- privacy_policy=privacy_policy
+ privacy_policy=privacy_policy,
+ custom_disclaimer=custom_disclaimer
)
if 'auth_type' not in credentials:
@@ -223,7 +157,7 @@ class ToolManageService:
auth_type = ApiProviderAuthType.value_of(credentials['auth_type'])
# create provider entity
- provider_controller = ApiBasedToolProviderController.from_db(db_provider, auth_type)
+ provider_controller = ApiToolProviderController.from_db(db_provider, auth_type)
# load tools into provider entity
provider_controller.load_bundled_tools(tool_bundles)
@@ -235,6 +169,9 @@ class ToolManageService:
db.session.add(db_provider)
db.session.commit()
+ # update labels
+ ToolLabelManager.update_tool_labels(provider_controller, labels)
+
return { 'result': 'success' }
@staticmethod
@@ -256,7 +193,7 @@ class ToolManageService:
schema = response.text
# try to parse schema, avoid SSRF attack
- ToolManageService.parser_api_schema(schema)
+ ApiToolManageService.parser_api_schema(schema)
except Exception as e:
logger.error(f"parse api schema error: {str(e)}")
raise ValueError('invalid schema, please check the url you provided')
@@ -280,91 +217,20 @@ class ToolManageService:
if provider is None:
raise ValueError(f'you have not added provider {provider}')
- return [
- ToolTransformService.tool_to_user_tool(tool_bundle) for tool_bundle in provider.tools
- ]
-
- @staticmethod
- def update_builtin_tool_provider(
- user_id: str, tenant_id: str, provider_name: str, credentials: dict
- ):
- """
- update builtin tool provider
- """
- # get if the provider exists
- provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
- BuiltinToolProvider.tenant_id == tenant_id,
- BuiltinToolProvider.provider == provider_name,
- ).first()
-
- try:
- # get provider
- provider_controller = ToolManager.get_builtin_provider(provider_name)
- if not provider_controller.need_credentials:
- raise ValueError(f'provider {provider_name} does not need credentials')
- tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
- # get original credentials if exists
- if provider is not None:
- original_credentials = tool_configuration.decrypt_tool_credentials(provider.credentials)
- masked_credentials = tool_configuration.mask_tool_credentials(original_credentials)
- # check if the credential has changed, save the original credential
- for name, value in credentials.items():
- if name in masked_credentials and value == masked_credentials[name]:
- credentials[name] = original_credentials[name]
- # validate credentials
- provider_controller.validate_credentials(credentials)
- # encrypt credentials
- credentials = tool_configuration.encrypt_tool_credentials(credentials)
- except (ToolProviderNotFoundError, ToolNotFoundError, ToolProviderCredentialValidationError) as e:
- raise ValueError(str(e))
-
- if provider is None:
- # create provider
- provider = BuiltinToolProvider(
- tenant_id=tenant_id,
- user_id=user_id,
- provider=provider_name,
- encrypted_credentials=json.dumps(credentials),
- )
-
- db.session.add(provider)
- db.session.commit()
-
- else:
- provider.encrypted_credentials = json.dumps(credentials)
- db.session.add(provider)
- db.session.commit()
-
- # delete cache
- tool_configuration.delete_tool_credentials_cache()
-
- return { 'result': 'success' }
-
- @staticmethod
- def get_builtin_tool_provider_credentials(
- user_id: str, tenant_id: str, provider: str
- ):
- """
- get builtin tool provider credentials
- """
- provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
- BuiltinToolProvider.tenant_id == tenant_id,
- BuiltinToolProvider.provider == provider,
- ).first()
-
- if provider is None:
- return {}
+ controller = ToolTransformService.api_provider_to_controller(db_provider=provider)
+ labels = ToolLabelManager.get_tool_labels(controller)
- provider_controller = ToolManager.get_builtin_provider(provider.provider)
- tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
- credentials = tool_configuration.decrypt_tool_credentials(provider.credentials)
- credentials = tool_configuration.mask_tool_credentials(credentials)
- return credentials
+ return [
+ ToolTransformService.tool_to_user_tool(
+ tool_bundle,
+ labels=labels,
+ ) for tool_bundle in provider.tools
+ ]
@staticmethod
def update_api_tool_provider(
user_id: str, tenant_id: str, provider_name: str, original_provider: str, icon: dict, credentials: dict,
- schema_type: str, schema: str, privacy_policy: str
+ schema_type: str, schema: str, privacy_policy: str, custom_disclaimer: str, labels: list[str]
):
"""
update api tool provider
@@ -384,7 +250,7 @@ class ToolManageService:
# parse openapi to tool bundle
extra_info = {}
# extra info like description will be set here
- tool_bundles, schema_type = ToolManageService.convert_schema_to_tool_bundles(schema, extra_info)
+ tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info)
# update db provider
provider.name = provider_name
@@ -394,6 +260,7 @@ class ToolManageService:
provider.schema_type_str = ApiProviderSchemaType.OPENAPI.value
provider.tools_str = json.dumps(jsonable_encoder(tool_bundles))
provider.privacy_policy = privacy_policy
+ provider.custom_disclaimer = custom_disclaimer
if 'auth_type' not in credentials:
raise ValueError('auth_type is required')
@@ -402,7 +269,7 @@ class ToolManageService:
auth_type = ApiProviderAuthType.value_of(credentials['auth_type'])
# create provider entity
- provider_controller = ApiBasedToolProviderController.from_db(provider, auth_type)
+ provider_controller = ApiToolProviderController.from_db(provider, auth_type)
# load tools into provider entity
provider_controller.load_bundled_tools(tool_bundles)
@@ -425,84 +292,11 @@ class ToolManageService:
# delete cache
tool_configuration.delete_tool_credentials_cache()
- return { 'result': 'success' }
-
- @staticmethod
- def delete_builtin_tool_provider(
- user_id: str, tenant_id: str, provider_name: str
- ):
- """
- delete tool provider
- """
- provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
- BuiltinToolProvider.tenant_id == tenant_id,
- BuiltinToolProvider.provider == provider_name,
- ).first()
-
- if provider is None:
- raise ValueError(f'you have not added provider {provider_name}')
-
- db.session.delete(provider)
- db.session.commit()
-
- # delete cache
- provider_controller = ToolManager.get_builtin_provider(provider_name)
- tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
- tool_configuration.delete_tool_credentials_cache()
+ # update labels
+ ToolLabelManager.update_tool_labels(provider_controller, labels)
return { 'result': 'success' }
- @staticmethod
- def get_builtin_tool_provider_icon(
- provider: str
- ):
- """
- get tool provider icon and it's mimetype
- """
- icon_path, mime_type = ToolManager.get_builtin_provider_icon(provider)
- with open(icon_path, 'rb') as f:
- icon_bytes = f.read()
-
- return icon_bytes, mime_type
-
- @staticmethod
- def get_model_tool_provider_icon(
- provider: str
- ):
- """
- get tool provider icon and it's mimetype
- """
-
- service = ModelProviderService()
- icon_bytes, mime_type = service.get_model_provider_icon(provider=provider, icon_type='icon_small', lang='en_US')
-
- if icon_bytes is None:
- raise ValueError(f'provider {provider} does not exists')
-
- return icon_bytes, mime_type
-
- @staticmethod
- def list_model_tool_provider_tools(
- user_id: str, tenant_id: str, provider: str
- ) -> list[UserTool]:
- """
- list model tool provider tools
- """
- provider_controller = ToolManager.get_model_provider(tenant_id=tenant_id, provider_name=provider)
- tools = provider_controller.get_tools(user_id=user_id, tenant_id=tenant_id)
-
- result = [
- UserTool(
- author=tool.identity.author,
- name=tool.identity.name,
- label=tool.identity.label,
- description=tool.description.human,
- parameters=tool.parameters or []
- ) for tool in tools
- ]
-
- return jsonable_encoder(result)
-
@staticmethod
def delete_api_tool_provider(
user_id: str, tenant_id: str, provider_name: str
@@ -581,7 +375,7 @@ class ToolManageService:
auth_type = ApiProviderAuthType.value_of(credentials['auth_type'])
# create provider entity
- provider_controller = ApiBasedToolProviderController.from_db(db_provider, auth_type)
+ provider_controller = ApiToolProviderController.from_db(db_provider, auth_type)
# load tools into provider entity
provider_controller.load_bundled_tools(tool_bundles)
@@ -602,7 +396,7 @@ class ToolManageService:
provider_controller.validate_credentials_format(credentials)
# get tool
tool = provider_controller.get_tool(tool_name)
- tool = tool.fork_tool_runtime(meta={
+ tool = tool.fork_tool_runtime(runtime={
'credentials': credentials,
'tenant_id': tenant_id,
})
@@ -612,49 +406,6 @@ class ToolManageService:
return { 'result': result or 'empty response' }
- @staticmethod
- def list_builtin_tools(
- user_id: str, tenant_id: str
- ) -> list[UserToolProvider]:
- """
- list builtin tools
- """
- # get all builtin providers
- provider_controllers = ToolManager.list_builtin_providers()
-
- # get all user added providers
- db_providers: list[BuiltinToolProvider] = db.session.query(BuiltinToolProvider).filter(
- BuiltinToolProvider.tenant_id == tenant_id
- ).all() or []
-
- # find provider
- find_provider = lambda provider: next(filter(lambda db_provider: db_provider.provider == provider, db_providers), None)
-
- result: list[UserToolProvider] = []
-
- for provider_controller in provider_controllers:
- # convert provider controller to user provider
- user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider(
- provider_controller=provider_controller,
- db_provider=find_provider(provider_controller.identity.name),
- decrypt_credentials=True
- )
-
- # add icon
- ToolTransformService.repack_provider(user_builtin_provider)
-
- tools = provider_controller.get_tools()
- for tool in tools:
- user_builtin_provider.tools.append(ToolTransformService.tool_to_user_tool(
- tenant_id=tenant_id,
- tool=tool,
- credentials=user_builtin_provider.original_credentials,
- ))
-
- result.append(user_builtin_provider)
-
- return BuiltinToolProviderSort.sort(result)
-
@staticmethod
def list_api_tools(
user_id: str, tenant_id: str
@@ -672,6 +423,7 @@ class ToolManageService:
for provider in db_providers:
# convert provider controller to user provider
provider_controller = ToolTransformService.api_provider_to_controller(db_provider=provider)
+ labels = ToolLabelManager.get_tool_labels(provider_controller)
user_provider = ToolTransformService.api_provider_to_user_provider(
provider_controller,
db_provider=provider,
@@ -690,6 +442,7 @@ class ToolManageService:
tenant_id=tenant_id,
tool=tool,
credentials=user_provider.original_credentials,
+ labels=labels
))
result.append(user_provider)
diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py
new file mode 100644
index 0000000000..2503191b63
--- /dev/null
+++ b/api/services/tools/builtin_tools_manage_service.py
@@ -0,0 +1,227 @@
+import json
+import logging
+
+from core.model_runtime.utils.encoders import jsonable_encoder
+from core.tools.entities.api_entities import UserTool, UserToolProvider
+from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError
+from core.tools.provider.builtin._positions import BuiltinToolProviderSort
+from core.tools.provider.tool_provider import ToolProviderController
+from core.tools.tool_label_manager import ToolLabelManager
+from core.tools.tool_manager import ToolManager
+from core.tools.utils.configuration import ToolConfigurationManager
+from extensions.ext_database import db
+from models.tools import BuiltinToolProvider
+from services.tools.tools_transform_service import ToolTransformService
+
+logger = logging.getLogger(__name__)
+
+
+class BuiltinToolManageService:
+ @staticmethod
+ def list_builtin_tool_provider_tools(
+ user_id: str, tenant_id: str, provider: str
+ ) -> list[UserTool]:
+ """
+ list builtin tool provider tools
+ """
+ provider_controller: ToolProviderController = ToolManager.get_builtin_provider(provider)
+ tools = provider_controller.get_tools()
+
+ tool_provider_configurations = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
+ # check if user has added the provider
+ builtin_provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
+ BuiltinToolProvider.tenant_id == tenant_id,
+ BuiltinToolProvider.provider == provider,
+ ).first()
+
+ credentials = {}
+ if builtin_provider is not None:
+ # get credentials
+ credentials = builtin_provider.credentials
+ credentials = tool_provider_configurations.decrypt_tool_credentials(credentials)
+
+ result = []
+ for tool in tools:
+ result.append(ToolTransformService.tool_to_user_tool(
+ tool=tool,
+ credentials=credentials,
+ tenant_id=tenant_id,
+ labels=ToolLabelManager.get_tool_labels(provider_controller)
+ ))
+
+ return result
+
+ @staticmethod
+ def list_builtin_provider_credentials_schema(
+ provider_name
+ ):
+ """
+ list builtin provider credentials schema
+
+        :return: the list of credentials schema
+ """
+ provider = ToolManager.get_builtin_provider(provider_name)
+ return jsonable_encoder([
+ v for _, v in (provider.credentials_schema or {}).items()
+ ])
+
+ @staticmethod
+ def update_builtin_tool_provider(
+ user_id: str, tenant_id: str, provider_name: str, credentials: dict
+ ):
+ """
+ update builtin tool provider
+ """
+ # get if the provider exists
+ provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
+ BuiltinToolProvider.tenant_id == tenant_id,
+ BuiltinToolProvider.provider == provider_name,
+ ).first()
+
+ try:
+ # get provider
+ provider_controller = ToolManager.get_builtin_provider(provider_name)
+ if not provider_controller.need_credentials:
+ raise ValueError(f'provider {provider_name} does not need credentials')
+ tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
+ # get original credentials if exists
+ if provider is not None:
+ original_credentials = tool_configuration.decrypt_tool_credentials(provider.credentials)
+ masked_credentials = tool_configuration.mask_tool_credentials(original_credentials)
+ # check if the credential has changed, save the original credential
+ for name, value in credentials.items():
+ if name in masked_credentials and value == masked_credentials[name]:
+ credentials[name] = original_credentials[name]
+ # validate credentials
+ provider_controller.validate_credentials(credentials)
+ # encrypt credentials
+ credentials = tool_configuration.encrypt_tool_credentials(credentials)
+ except (ToolProviderNotFoundError, ToolNotFoundError, ToolProviderCredentialValidationError) as e:
+ raise ValueError(str(e))
+
+ if provider is None:
+ # create provider
+ provider = BuiltinToolProvider(
+ tenant_id=tenant_id,
+ user_id=user_id,
+ provider=provider_name,
+ encrypted_credentials=json.dumps(credentials),
+ )
+
+ db.session.add(provider)
+ db.session.commit()
+
+ else:
+ provider.encrypted_credentials = json.dumps(credentials)
+ db.session.add(provider)
+ db.session.commit()
+
+ # delete cache
+ tool_configuration.delete_tool_credentials_cache()
+
+ return { 'result': 'success' }
+
+ @staticmethod
+ def get_builtin_tool_provider_credentials(
+ user_id: str, tenant_id: str, provider: str
+ ):
+ """
+ get builtin tool provider credentials
+ """
+ provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
+ BuiltinToolProvider.tenant_id == tenant_id,
+ BuiltinToolProvider.provider == provider,
+ ).first()
+
+ if provider is None:
+ return {}
+
+ provider_controller = ToolManager.get_builtin_provider(provider.provider)
+ tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
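+        # decrypt then re-mask, so the console only ever receives redacted secret values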
+ credentials = tool_configuration.decrypt_tool_credentials(provider.credentials)
+ credentials = tool_configuration.mask_tool_credentials(credentials)
+ return credentials
+
+ @staticmethod
+ def delete_builtin_tool_provider(
+ user_id: str, tenant_id: str, provider_name: str
+ ):
+ """
+ delete tool provider
+ """
+ provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
+ BuiltinToolProvider.tenant_id == tenant_id,
+ BuiltinToolProvider.provider == provider_name,
+ ).first()
+
+ if provider is None:
+ raise ValueError(f'you have not added provider {provider_name}')
+
+ db.session.delete(provider)
+ db.session.commit()
+
+ # delete cache
+ provider_controller = ToolManager.get_builtin_provider(provider_name)
+ tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
+ tool_configuration.delete_tool_credentials_cache()
+
+ return { 'result': 'success' }
+
+ @staticmethod
+ def get_builtin_tool_provider_icon(
+ provider: str
+ ):
+ """
+    get tool provider icon and its mimetype
+ """
+ icon_path, mime_type = ToolManager.get_builtin_provider_icon(provider)
+ with open(icon_path, 'rb') as f:
+ icon_bytes = f.read()
+
+ return icon_bytes, mime_type
+
+ @staticmethod
+ def list_builtin_tools(
+ user_id: str, tenant_id: str
+ ) -> list[UserToolProvider]:
+ """
+ list builtin tools
+ """
+ # get all builtin providers
+ provider_controllers = ToolManager.list_builtin_providers()
+
+ # get all user added providers
+ db_providers: list[BuiltinToolProvider] = db.session.query(BuiltinToolProvider).filter(
+ BuiltinToolProvider.tenant_id == tenant_id
+ ).all() or []
+
+ # find provider
+ find_provider = lambda provider: next(filter(lambda db_provider: db_provider.provider == provider, db_providers), None)
+
+ result: list[UserToolProvider] = []
+
+ for provider_controller in provider_controllers:
+ # convert provider controller to user provider
+ user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider(
+ provider_controller=provider_controller,
+ db_provider=find_provider(provider_controller.identity.name),
+ decrypt_credentials=True
+ )
+
+ # add icon
+ ToolTransformService.repack_provider(user_builtin_provider)
+
+ tools = provider_controller.get_tools()
+ for tool in tools:
+ user_builtin_provider.tools.append(ToolTransformService.tool_to_user_tool(
+ tenant_id=tenant_id,
+ tool=tool,
+ credentials=user_builtin_provider.original_credentials,
+ labels=ToolLabelManager.get_tool_labels(provider_controller)
+ ))
+
+ result.append(user_builtin_provider)
+
+ return BuiltinToolProviderSort.sort(result)
+
\ No newline at end of file
diff --git a/api/services/tools/tool_labels_service.py b/api/services/tools/tool_labels_service.py
new file mode 100644
index 0000000000..8a6aa025f2
--- /dev/null
+++ b/api/services/tools/tool_labels_service.py
@@ -0,0 +1,8 @@
+from core.tools.entities.tool_entities import ToolLabel
+from core.tools.entities.values import default_tool_labels
+
+
+class ToolLabelsService:
+ @classmethod
+ def list_tool_labels(cls) -> list[ToolLabel]:
+ return default_tool_labels
\ No newline at end of file
diff --git a/api/services/tools/tools_manage_service.py b/api/services/tools/tools_manage_service.py
new file mode 100644
index 0000000000..76d2f53ae8
--- /dev/null
+++ b/api/services/tools/tools_manage_service.py
@@ -0,0 +1,29 @@
+import logging
+
+from core.tools.entities.api_entities import UserToolProviderTypeLiteral
+from core.tools.tool_manager import ToolManager
+from services.tools.tools_transform_service import ToolTransformService
+
+logger = logging.getLogger(__name__)
+
+
+class ToolCommonService:
+ @staticmethod
+ def list_tool_providers(user_id: str, tenant_id: str, typ: UserToolProviderTypeLiteral = None):
+ """
+ list tool providers
+
+ :return: the list of tool providers
+ """
+ providers = ToolManager.user_list_providers(
+ user_id, tenant_id, typ
+ )
+
+ # add icon
+ for provider in providers:
+ ToolTransformService.repack_provider(provider)
+
+ result = [provider.to_dict() for provider in providers]
+
+ return result
+
\ No newline at end of file
diff --git a/api/services/tools_transform_service.py b/api/services/tools/tools_transform_service.py
similarity index 75%
rename from api/services/tools_transform_service.py
rename to api/services/tools/tools_transform_service.py
index 3ef9f52e62..ba8c20d79b 100644
--- a/api/services/tools_transform_service.py
+++ b/api/services/tools/tools_transform_service.py
@@ -5,15 +5,21 @@ from typing import Optional, Union
from flask import current_app
from core.model_runtime.entities.common_entities import I18nObject
-from core.tools.entities.tool_bundle import ApiBasedToolBundle
-from core.tools.entities.tool_entities import ApiProviderAuthType, ToolParameter, ToolProviderCredentials
-from core.tools.entities.user_entities import UserTool, UserToolProvider
-from core.tools.provider.api_tool_provider import ApiBasedToolProviderController
+from core.tools.entities.api_entities import UserTool, UserToolProvider
+from core.tools.entities.tool_bundle import ApiToolBundle
+from core.tools.entities.tool_entities import (
+ ApiProviderAuthType,
+ ToolParameter,
+ ToolProviderCredentials,
+ ToolProviderType,
+)
+from core.tools.provider.api_tool_provider import ApiToolProviderController
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
-from core.tools.provider.model_tool_provider import ModelToolProviderController
+from core.tools.provider.workflow_tool_provider import WorkflowToolProviderController
from core.tools.tool.tool import Tool
+from core.tools.tool.workflow_tool import WorkflowTool
from core.tools.utils.configuration import ToolConfigurationManager
-from models.tools import ApiToolProvider, BuiltinToolProvider
+from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
logger = logging.getLogger(__name__)
@@ -26,11 +32,9 @@ class ToolTransformService:
url_prefix = (current_app.config.get("CONSOLE_API_URL")
+ "/console/api/workspaces/current/tool-provider/")
- if provider_type == UserToolProvider.ProviderType.BUILTIN.value:
+ if provider_type == ToolProviderType.BUILT_IN.value:
return url_prefix + 'builtin/' + provider_name + '/icon'
- elif provider_type == UserToolProvider.ProviderType.MODEL.value:
- return url_prefix + 'model/' + provider_name + '/icon'
- elif provider_type == UserToolProvider.ProviderType.API.value:
+ elif provider_type in [ToolProviderType.API.value, ToolProviderType.WORKFLOW.value]:
try:
return json.loads(icon)
except:
@@ -65,7 +69,7 @@ class ToolTransformService:
def builtin_provider_to_user_provider(
provider_controller: BuiltinToolProviderController,
db_provider: Optional[BuiltinToolProvider],
- decrypt_credentials: bool = True
+ decrypt_credentials: bool = True,
) -> UserToolProvider:
"""
convert provider controller to user provider
@@ -83,10 +87,11 @@ class ToolTransformService:
en_US=provider_controller.identity.label.en_US,
zh_Hans=provider_controller.identity.label.zh_Hans,
),
- type=UserToolProvider.ProviderType.BUILTIN,
+ type=ToolProviderType.BUILT_IN,
masked_credentials={},
is_team_authorization=False,
- tools=[]
+ tools=[],
+ labels=provider_controller.tool_labels
)
# get credentials schema
@@ -122,24 +127,62 @@ class ToolTransformService:
@staticmethod
def api_provider_to_controller(
db_provider: ApiToolProvider,
- ) -> ApiBasedToolProviderController:
+ ) -> ApiToolProviderController:
"""
convert provider controller to user provider
"""
# package tool provider controller
- controller = ApiBasedToolProviderController.from_db(
+ controller = ApiToolProviderController.from_db(
db_provider=db_provider,
auth_type=ApiProviderAuthType.API_KEY if db_provider.credentials['auth_type'] == 'api_key' else
ApiProviderAuthType.NONE
)
return controller
+
+ @staticmethod
+ def workflow_provider_to_controller(
+ db_provider: WorkflowToolProvider
+ ) -> WorkflowToolProviderController:
+ """
+        convert db provider to provider controller
+ """
+ return WorkflowToolProviderController.from_db(db_provider)
+
+ @staticmethod
+ def workflow_provider_to_user_provider(
+ provider_controller: WorkflowToolProviderController,
+ labels: list[str] = None
+ ):
+ """
+ convert provider controller to user provider
+ """
+ return UserToolProvider(
+ id=provider_controller.provider_id,
+ author=provider_controller.identity.author,
+ name=provider_controller.identity.name,
+ description=I18nObject(
+ en_US=provider_controller.identity.description.en_US,
+ zh_Hans=provider_controller.identity.description.zh_Hans,
+ ),
+ icon=provider_controller.identity.icon,
+ label=I18nObject(
+ en_US=provider_controller.identity.label.en_US,
+ zh_Hans=provider_controller.identity.label.zh_Hans,
+ ),
+ type=ToolProviderType.WORKFLOW,
+ masked_credentials={},
+ is_team_authorization=True,
+ tools=[],
+ labels=labels or []
+ )
@staticmethod
def api_provider_to_user_provider(
- provider_controller: ApiBasedToolProviderController,
+ provider_controller: ApiToolProviderController,
db_provider: ApiToolProvider,
- decrypt_credentials: bool = True
+ decrypt_credentials: bool = True,
+ labels: list[str] = None
) -> UserToolProvider:
"""
convert provider controller to user provider
@@ -164,10 +207,11 @@ class ToolTransformService:
en_US=db_provider.name,
zh_Hans=db_provider.name,
),
- type=UserToolProvider.ProviderType.API,
+ type=ToolProviderType.API,
masked_credentials={},
is_team_authorization=True,
- tools=[]
+ tools=[],
+ labels=labels or []
)
if decrypt_credentials:
@@ -185,41 +229,19 @@ class ToolTransformService:
return result
- @staticmethod
- def model_provider_to_user_provider(
- db_provider: ModelToolProviderController,
- ) -> UserToolProvider:
- """
- convert provider controller to user provider
- """
- return UserToolProvider(
- id=db_provider.identity.name,
- author=db_provider.identity.author,
- name=db_provider.identity.name,
- description=I18nObject(
- en_US=db_provider.identity.description.en_US,
- zh_Hans=db_provider.identity.description.zh_Hans,
- ),
- icon=db_provider.identity.icon,
- label=I18nObject(
- en_US=db_provider.identity.label.en_US,
- zh_Hans=db_provider.identity.label.zh_Hans,
- ),
- type=UserToolProvider.ProviderType.MODEL,
- masked_credentials={},
- is_team_authorization=db_provider.is_active,
- )
-
@staticmethod
def tool_to_user_tool(
- tool: Union[ApiBasedToolBundle, Tool], credentials: dict = None, tenant_id: str = None
+ tool: Union[ApiToolBundle, WorkflowTool, Tool],
+ credentials: dict = None,
+ tenant_id: str = None,
+ labels: list[str] = None
) -> UserTool:
"""
convert tool to user tool
"""
if isinstance(tool, Tool):
# fork tool runtime
- tool = tool.fork_tool_runtime(meta={
+ tool = tool.fork_tool_runtime(runtime={
'credentials': credentials,
'tenant_id': tenant_id,
})
@@ -241,17 +263,15 @@ class ToolTransformService:
if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM:
current_parameters.append(runtime_parameter)
- user_tool = UserTool(
+ return UserTool(
author=tool.identity.author,
name=tool.identity.name,
label=tool.identity.label,
description=tool.description.human,
- parameters=current_parameters
+ parameters=current_parameters,
+ labels=labels
)
-
- return user_tool
-
- if isinstance(tool, ApiBasedToolBundle):
+ if isinstance(tool, ApiToolBundle):
return UserTool(
author=tool.author,
name=tool.operation_id,
@@ -263,5 +283,6 @@ class ToolTransformService:
en_US=tool.summary or '',
zh_Hans=tool.summary or ''
),
- parameters=tool.parameters
+ parameters=tool.parameters,
+ labels=labels
)
\ No newline at end of file
diff --git a/api/services/tools/workflow_tools_manage_service.py b/api/services/tools/workflow_tools_manage_service.py
new file mode 100644
index 0000000000..e89d94160c
--- /dev/null
+++ b/api/services/tools/workflow_tools_manage_service.py
@@ -0,0 +1,334 @@
+import json
+from datetime import datetime
+
+from sqlalchemy import or_
+
+from core.model_runtime.utils.encoders import jsonable_encoder
+from core.tools.entities.api_entities import UserToolProvider
+from core.tools.provider.workflow_tool_provider import WorkflowToolProviderController
+from core.tools.tool_label_manager import ToolLabelManager
+from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
+from extensions.ext_database import db
+from models.model import App
+from models.tools import WorkflowToolProvider
+from models.workflow import Workflow
+from services.tools.tools_transform_service import ToolTransformService
+
+
+class WorkflowToolManageService:
+ """
+ Service class for managing workflow tools.
+ """
+ @classmethod
+ def create_workflow_tool(cls, user_id: str, tenant_id: str, workflow_app_id: str, name: str,
+ label: str, icon: dict, description: str,
+ parameters: list[dict], privacy_policy: str = '', labels: list[str] = None) -> dict:
+ """
+ Create a workflow tool.
+ :param user_id: the user id
+ :param tenant_id: the tenant id
+ :param name: the name
+ :param icon: the icon
+ :param description: the description
+ :param parameters: the parameters
+ :param privacy_policy: the privacy policy
+        :return: the creation result
+ """
+ WorkflowToolConfigurationUtils.check_parameter_configurations(parameters)
+
+ # check if the name is unique
+ existing_workflow_tool_provider = db.session.query(WorkflowToolProvider).filter(
+ WorkflowToolProvider.tenant_id == tenant_id,
+ # name or app_id
+ or_(WorkflowToolProvider.name == name, WorkflowToolProvider.app_id == workflow_app_id)
+ ).first()
+
+ if existing_workflow_tool_provider is not None:
+ raise ValueError(f'Tool with name {name} or app_id {workflow_app_id} already exists')
+
+ app: App = db.session.query(App).filter(
+ App.id == workflow_app_id,
+ App.tenant_id == tenant_id
+ ).first()
+
+ if app is None:
+ raise ValueError(f'App {workflow_app_id} not found')
+
+ workflow: Workflow = app.workflow
+ if workflow is None:
+ raise ValueError(f'Workflow not found for app {workflow_app_id}')
+
+ workflow_tool_provider = WorkflowToolProvider(
+ tenant_id=tenant_id,
+ user_id=user_id,
+ app_id=workflow_app_id,
+ name=name,
+ label=label,
+ icon=json.dumps(icon),
+ description=description,
+ parameter_configuration=json.dumps(parameters),
+ privacy_policy=privacy_policy,
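+            # pin the current workflow version so later edits can be detected as out of sync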
+ version=workflow.version,
+ )
+
+ try:
+ WorkflowToolProviderController.from_db(workflow_tool_provider)
+ except Exception as e:
+ raise ValueError(str(e))
+
+        db.session.add(workflow_tool_provider)
+        db.session.commit()
+
+        if labels is not None:
+            # keep tool labels in sync, mirroring update_workflow_tool
+            ToolLabelManager.update_tool_labels(
+                ToolTransformService.workflow_provider_to_controller(workflow_tool_provider),
+                labels
+            )
+
+        return {
+            'result': 'success'
+        }
+
+ @classmethod
+ def update_workflow_tool(cls, user_id: str, tenant_id: str, workflow_tool_id: str,
+ name: str, label: str, icon: dict, description: str,
+ parameters: list[dict], privacy_policy: str = '', labels: list[str] = None) -> dict:
+ """
+ Update a workflow tool.
+ :param user_id: the user id
+ :param tenant_id: the tenant id
+        :param workflow_tool_id: the workflow tool id
+        :return: the update result
+ """
+ WorkflowToolConfigurationUtils.check_parameter_configurations(parameters)
+
+ # check if the name is unique
+ existing_workflow_tool_provider = db.session.query(WorkflowToolProvider).filter(
+ WorkflowToolProvider.tenant_id == tenant_id,
+ WorkflowToolProvider.name == name,
+ WorkflowToolProvider.id != workflow_tool_id
+ ).first()
+
+ if existing_workflow_tool_provider is not None:
+ raise ValueError(f'Tool with name {name} already exists')
+
+ workflow_tool_provider: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter(
+ WorkflowToolProvider.tenant_id == tenant_id,
+ WorkflowToolProvider.id == workflow_tool_id
+ ).first()
+
+ if workflow_tool_provider is None:
+ raise ValueError(f'Tool {workflow_tool_id} not found')
+
+ app: App = db.session.query(App).filter(
+ App.id == workflow_tool_provider.app_id,
+ App.tenant_id == tenant_id
+ ).first()
+
+ if app is None:
+ raise ValueError(f'App {workflow_tool_provider.app_id} not found')
+
+ workflow: Workflow = app.workflow
+ if workflow is None:
+ raise ValueError(f'Workflow not found for app {workflow_tool_provider.app_id}')
+
+ workflow_tool_provider.name = name
+ workflow_tool_provider.label = label
+ workflow_tool_provider.icon = json.dumps(icon)
+ workflow_tool_provider.description = description
+ workflow_tool_provider.parameter_configuration = json.dumps(parameters)
+ workflow_tool_provider.privacy_policy = privacy_policy
+ workflow_tool_provider.version = workflow.version
+ workflow_tool_provider.updated_at = datetime.now()
+
+ try:
+ WorkflowToolProviderController.from_db(workflow_tool_provider)
+ except Exception as e:
+ raise ValueError(str(e))
+
+ db.session.add(workflow_tool_provider)
+ db.session.commit()
+
+ if labels is not None:
+ ToolLabelManager.update_tool_labels(
+ ToolTransformService.workflow_provider_to_controller(workflow_tool_provider),
+ labels
+ )
+
+ return {
+ 'result': 'success'
+ }
+
+ @classmethod
+ def list_tenant_workflow_tools(cls, user_id: str, tenant_id: str) -> list[UserToolProvider]:
+ """
+ List workflow tools.
+ :param user_id: the user id
+ :param tenant_id: the tenant id
+ :return: the list of tools
+ """
+ db_tools = db.session.query(WorkflowToolProvider).filter(
+ WorkflowToolProvider.tenant_id == tenant_id
+ ).all()
+
+ tools = []
+ for provider in db_tools:
+ try:
+ tools.append(ToolTransformService.workflow_provider_to_controller(provider))
+            except Exception:
+ # skip deleted tools
+ pass
+
+ labels = ToolLabelManager.get_tools_labels(tools)
+
+ result = []
+
+ for tool in tools:
+ user_tool_provider = ToolTransformService.workflow_provider_to_user_provider(
+ provider_controller=tool,
+ labels=labels.get(tool.provider_id, [])
+ )
+ ToolTransformService.repack_provider(user_tool_provider)
+ user_tool_provider.tools = [
+ ToolTransformService.tool_to_user_tool(
+ tool.get_tools(user_id, tenant_id)[0],
+ labels=labels.get(tool.provider_id, [])
+ )
+ ]
+ result.append(user_tool_provider)
+
+ return result
+
+ @classmethod
+ def delete_workflow_tool(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> dict:
+ """
+ Delete a workflow tool.
+ :param user_id: the user id
+ :param tenant_id: the tenant id
+        :param workflow_tool_id: the workflow tool id
+ """
+ db.session.query(WorkflowToolProvider).filter(
+ WorkflowToolProvider.tenant_id == tenant_id,
+ WorkflowToolProvider.id == workflow_tool_id
+ ).delete()
+
+ db.session.commit()
+
+ return {
+ 'result': 'success'
+ }
+
+ @classmethod
+ def get_workflow_tool_by_tool_id(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> dict:
+ """
+ Get a workflow tool.
+ :param user_id: the user id
+ :param tenant_id: the tenant id
+        :param workflow_tool_id: the workflow tool id
+ :return: the tool
+ """
+ db_tool: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter(
+ WorkflowToolProvider.tenant_id == tenant_id,
+ WorkflowToolProvider.id == workflow_tool_id
+ ).first()
+
+ if db_tool is None:
+ raise ValueError(f'Tool {workflow_tool_id} not found')
+
+ workflow_app: App = db.session.query(App).filter(
+ App.id == db_tool.app_id,
+ App.tenant_id == tenant_id
+ ).first()
+
+ if workflow_app is None:
+ raise ValueError(f'App {db_tool.app_id} not found')
+
+ tool = ToolTransformService.workflow_provider_to_controller(db_tool)
+
+ return {
+ 'name': db_tool.name,
+ 'label': db_tool.label,
+ 'workflow_tool_id': db_tool.id,
+ 'workflow_app_id': db_tool.app_id,
+ 'icon': json.loads(db_tool.icon),
+ 'description': db_tool.description,
+ 'parameters': jsonable_encoder(db_tool.parameter_configurations),
+ 'tool': ToolTransformService.tool_to_user_tool(
+ tool.get_tools(user_id, tenant_id)[0],
+ labels=ToolLabelManager.get_tool_labels(tool)
+ ),
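+            # synced is True while the app's current workflow version still matches the published tool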
+ 'synced': workflow_app.workflow.version == db_tool.version,
+ 'privacy_policy': db_tool.privacy_policy,
+ }
+
+ @classmethod
+ def get_workflow_tool_by_app_id(cls, user_id: str, tenant_id: str, workflow_app_id: str) -> dict:
+ """
+ Get a workflow tool.
+ :param user_id: the user id
+ :param tenant_id: the tenant id
+ :param workflow_app_id: the workflow app id
+ :return: the tool
+ """
+ db_tool: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter(
+ WorkflowToolProvider.tenant_id == tenant_id,
+ WorkflowToolProvider.app_id == workflow_app_id
+ ).first()
+
+ if db_tool is None:
+ raise ValueError(f'Tool {workflow_app_id} not found')
+
+ workflow_app: App = db.session.query(App).filter(
+ App.id == db_tool.app_id,
+ App.tenant_id == tenant_id
+ ).first()
+
+ if workflow_app is None:
+ raise ValueError(f'App {db_tool.app_id} not found')
+
+ tool = ToolTransformService.workflow_provider_to_controller(db_tool)
+
+ return {
+ 'name': db_tool.name,
+ 'label': db_tool.label,
+ 'workflow_tool_id': db_tool.id,
+ 'workflow_app_id': db_tool.app_id,
+ 'icon': json.loads(db_tool.icon),
+ 'description': db_tool.description,
+ 'parameters': jsonable_encoder(db_tool.parameter_configurations),
+ 'tool': ToolTransformService.tool_to_user_tool(
+ tool.get_tools(user_id, tenant_id)[0],
+ labels=ToolLabelManager.get_tool_labels(tool)
+ ),
+ 'synced': workflow_app.workflow.version == db_tool.version,
+ 'privacy_policy': db_tool.privacy_policy
+ }
+
+ @classmethod
+ def list_single_workflow_tools(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> list[dict]:
+ """
+ List workflow tool provider tools.
+ :param user_id: the user id
+ :param tenant_id: the tenant id
+        :param workflow_tool_id: the workflow tool id
+ :return: the list of tools
+ """
+ db_tool: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter(
+ WorkflowToolProvider.tenant_id == tenant_id,
+ WorkflowToolProvider.id == workflow_tool_id
+ ).first()
+
+ if db_tool is None:
+ raise ValueError(f'Tool {workflow_tool_id} not found')
+
+ tool = ToolTransformService.workflow_provider_to_controller(db_tool)
+
+ return [
+ ToolTransformService.tool_to_user_tool(
+ tool.get_tools(user_id, tenant_id)[0],
+ labels=ToolLabelManager.get_tool_labels(tool)
+ )
+ ]
\ No newline at end of file
diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py
index 138a5d5786..d76cd4c7ff 100644
--- a/api/services/workflow/workflow_converter.py
+++ b/api/services/workflow/workflow_converter.py
@@ -305,7 +305,7 @@ class WorkflowConverter:
}
request_body_json = json.dumps(request_body)
- request_body_json = request_body_json.replace('\{\{', '{{').replace('\}\}', '}}')
+ request_body_json = request_body_json.replace(r'\{\{', '{{').replace(r'\}\}', '}}')
http_request_node = {
"id": f"http_request_{index}",
diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py
index 01fd3aa4a1..6235ecf0a3 100644
--- a/api/services/workflow_service.py
+++ b/api/services/workflow_service.py
@@ -21,6 +21,7 @@ from models.workflow import (
WorkflowNodeExecutionTriggeredFrom,
WorkflowType,
)
+from services.errors.app import WorkflowHashNotEqualError
from services.workflow.workflow_converter import WorkflowConverter
@@ -63,13 +64,20 @@ class WorkflowService:
def sync_draft_workflow(self, app_model: App,
graph: dict,
features: dict,
+ unique_hash: Optional[str],
account: Account) -> Workflow:
"""
Sync draft workflow
+ :raises WorkflowHashNotEqualError
"""
# fetch draft workflow by app_model
workflow = self.get_draft_workflow(app_model=app_model)
+ if workflow:
+            # validate unique hash as an optimistic lock: reject saves made from a stale draft
+ if workflow.unique_hash != unique_hash:
+ raise WorkflowHashNotEqualError()
+
# validate features structure
self.validate_features_structure(
app_model=app_model,
diff --git a/api/tasks/batch_create_segment_to_index_task.py b/api/tasks/batch_create_segment_to_index_task.py
index d6dc970477..67cc03bdeb 100644
--- a/api/tasks/batch_create_segment_to_index_task.py
+++ b/api/tasks/batch_create_segment_to_index_task.py
@@ -2,7 +2,6 @@ import datetime
import logging
import time
import uuid
-from typing import cast
import click
from celery import shared_task
@@ -11,7 +10,6 @@ from sqlalchemy import func
from core.indexing_runner import IndexingRunner
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
-from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs import helper
@@ -59,16 +57,12 @@ def batch_create_segment_to_index_task(job_id: str, content: list, dataset_id: s
model=dataset.embedding_model
)
- model_type_instance = embedding_model.model_type_instance
- model_type_instance = cast(TextEmbeddingModel, model_type_instance)
for segment in content:
content = segment['content']
doc_id = str(uuid.uuid4())
segment_hash = helper.generate_text_hash(content)
# calc embedding use tokens
- tokens = model_type_instance.get_num_tokens(
- model=embedding_model.model,
- credentials=embedding_model.credentials,
+ tokens = embedding_model.get_text_embedding_num_tokens(
texts=[content]
) if embedding_model else 0
max_position = db.session.query(func.max(DocumentSegment.position)).filter(
diff --git a/api/tests/integration_tests/.env.example b/api/tests/integration_tests/.env.example
index 9cd04b4764..f29e5ef4d6 100644
--- a/api/tests/integration_tests/.env.example
+++ b/api/tests/integration_tests/.env.example
@@ -73,4 +73,10 @@ MOCK_SWITCH=false
# CODE EXECUTION CONFIGURATION
CODE_EXECUTION_ENDPOINT=
-CODE_EXECUTION_API_KEY=
\ No newline at end of file
+CODE_EXECUTION_API_KEY=
+
+# Volcengine MaaS Credentials
+VOLC_API_KEY=
+VOLC_SECRET_KEY=
+VOLC_MODEL_ENDPOINT_ID=
+VOLC_EMBEDDING_ENDPOINT_ID=
\ No newline at end of file
diff --git a/api/tests/integration_tests/model_runtime/localai/test_rerank.py b/api/tests/integration_tests/model_runtime/localai/test_rerank.py
new file mode 100644
index 0000000000..a75439337e
--- /dev/null
+++ b/api/tests/integration_tests/model_runtime/localai/test_rerank.py
@@ -0,0 +1,102 @@
+import os
+
+import pytest
+from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
+
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.localai.rerank.rerank import LocalaiRerankModel
+
+
+def test_validate_credentials_for_chat_model():
+ model = LocalaiRerankModel()
+
+ with pytest.raises(CredentialsValidateFailedError):
+ model.validate_credentials(
+ model='bge-reranker-v2-m3',
+ credentials={
+ 'server_url': 'hahahaha',
+ 'completion_type': 'completion',
+ }
+ )
+
+ model.validate_credentials(
+ model='bge-reranker-base',
+ credentials={
+ 'server_url': os.environ.get('LOCALAI_SERVER_URL'),
+ 'completion_type': 'completion',
+ }
+ )
+
+def test_invoke_rerank_model():
+ model = LocalaiRerankModel()
+
+ response = model.invoke(
+ model='bge-reranker-base',
+ credentials={
+ 'server_url': os.environ.get('LOCALAI_SERVER_URL')
+ },
+ query='Organic skincare products for sensitive skin',
+ docs=[
+ "Eco-friendly kitchenware for modern homes",
+ "Biodegradable cleaning supplies for eco-conscious consumers",
+ "Organic cotton baby clothes for sensitive skin",
+ "Natural organic skincare range for sensitive skin",
+ "Tech gadgets for smart homes: 2024 edition",
+ "Sustainable gardening tools and compost solutions",
+ "Sensitive skin-friendly facial cleansers and toners",
+ "Organic food wraps and storage solutions",
+ "Yoga mats made from recycled materials"
+ ],
+ top_n=3,
+ score_threshold=0.75,
+ user="abc-123"
+ )
+
+ assert isinstance(response, RerankResult)
+ assert len(response.docs) == 3
+
+def test__invoke():
+ model = LocalaiRerankModel()
+
+ # Test case 1: Empty docs
+ result = model._invoke(
+ model='bge-reranker-base',
+ credentials={
+ 'server_url': 'https://example.com',
+ 'api_key': '1234567890'
+ },
+ query='Organic skincare products for sensitive skin',
+ docs=[],
+ top_n=3,
+ score_threshold=0.75,
+ user="abc-123"
+ )
+ assert isinstance(result, RerankResult)
+ assert len(result.docs) == 0
+
+ # Test case 2: Valid invocation
+ result = model._invoke(
+ model='bge-reranker-base',
+ credentials={
+ 'server_url': 'https://example.com',
+ 'api_key': '1234567890'
+ },
+ query='Organic skincare products for sensitive skin',
+ docs=[
+ "Eco-friendly kitchenware for modern homes",
+ "Biodegradable cleaning supplies for eco-conscious consumers",
+ "Organic cotton baby clothes for sensitive skin",
+ "Natural organic skincare range for sensitive skin",
+ "Tech gadgets for smart homes: 2024 edition",
+ "Sustainable gardening tools and compost solutions",
+ "Sensitive skin-friendly facial cleansers and toners",
+ "Organic food wraps and storage solutions",
+ "Yoga mats made from recycled materials"
+ ],
+ top_n=3,
+ score_threshold=0.75,
+ user="abc-123"
+ )
+ assert isinstance(result, RerankResult)
+ assert len(result.docs) == 3
+ assert all(isinstance(doc, RerankDocument) for doc in result.docs)
\ No newline at end of file
diff --git a/api/tests/integration_tests/model_runtime/localai/test_speech2text.py b/api/tests/integration_tests/model_runtime/localai/test_speech2text.py
new file mode 100644
index 0000000000..3fd2ebed4f
--- /dev/null
+++ b/api/tests/integration_tests/model_runtime/localai/test_speech2text.py
@@ -0,0 +1,52 @@
+import os
+
+import pytest
+
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.localai.speech2text.speech2text import LocalAISpeech2text
+
+
+def test_validate_credentials():
+ model = LocalAISpeech2text()
+
+ with pytest.raises(CredentialsValidateFailedError):
+ model.validate_credentials(
+ model='whisper-1',
+ credentials={
+ 'server_url': 'invalid_url'
+ }
+ )
+
+ model.validate_credentials(
+ model='whisper-1',
+ credentials={
+ 'server_url': os.environ.get('LOCALAI_SERVER_URL')
+ }
+ )
+
+
+def test_invoke_model():
+ model = LocalAISpeech2text()
+
+ # Get the directory of the current file
+ current_dir = os.path.dirname(os.path.abspath(__file__))
+
+ # Get assets directory
+ assets_dir = os.path.join(os.path.dirname(current_dir), 'assets')
+
+ # Construct the path to the audio file
+ audio_file_path = os.path.join(assets_dir, 'audio.mp3')
+
+    # Invoke the model while the file object is still open
+    with open(audio_file_path, 'rb') as audio_file:
+        result = model.invoke(
+            model='whisper-1',
+            credentials={
+                'server_url': os.environ.get('LOCALAI_SERVER_URL')
+            },
+            file=audio_file,
+            user="abc-123"
+        )
+
+ assert isinstance(result, str)
+ assert result == '1, 2, 3, 4, 5, 6, 7, 8, 9, 10'
\ No newline at end of file
diff --git a/api/tests/integration_tests/model_runtime/volcengine_maas/__init__.py b/api/tests/integration_tests/model_runtime/volcengine_maas/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/tests/integration_tests/model_runtime/volcengine_maas/test_embedding.py b/api/tests/integration_tests/model_runtime/volcengine_maas/test_embedding.py
new file mode 100644
index 0000000000..3b399d604e
--- /dev/null
+++ b/api/tests/integration_tests/model_runtime/volcengine_maas/test_embedding.py
@@ -0,0 +1,85 @@
+import os
+
+import pytest
+
+from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.volcengine_maas.text_embedding.text_embedding import (
+ VolcengineMaaSTextEmbeddingModel,
+)
+
+
+def test_validate_credentials():
+ model = VolcengineMaaSTextEmbeddingModel()
+
+ with pytest.raises(CredentialsValidateFailedError):
+ model.validate_credentials(
+ model='NOT IMPORTANT',
+ credentials={
+ 'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com',
+ 'volc_region': 'cn-beijing',
+ 'volc_access_key_id': 'INVALID',
+ 'volc_secret_access_key': 'INVALID',
+ 'endpoint_id': 'INVALID',
+ 'base_model_name': 'Doubao-embedding',
+ }
+ )
+
+ model.validate_credentials(
+ model='NOT IMPORTANT',
+ credentials={
+ 'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com',
+ 'volc_region': 'cn-beijing',
+ 'volc_access_key_id': os.environ.get('VOLC_API_KEY'),
+ 'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'),
+ 'endpoint_id': os.environ.get('VOLC_EMBEDDING_ENDPOINT_ID'),
+ 'base_model_name': 'Doubao-embedding',
+ },
+ )
+
+
+def test_invoke_model():
+ model = VolcengineMaaSTextEmbeddingModel()
+
+ result = model.invoke(
+ model='NOT IMPORTANT',
+ credentials={
+ 'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com',
+ 'volc_region': 'cn-beijing',
+ 'volc_access_key_id': os.environ.get('VOLC_API_KEY'),
+ 'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'),
+ 'endpoint_id': os.environ.get('VOLC_EMBEDDING_ENDPOINT_ID'),
+ 'base_model_name': 'Doubao-embedding',
+ },
+ texts=[
+ "hello",
+ "world"
+ ],
+ user="abc-123"
+ )
+
+ assert isinstance(result, TextEmbeddingResult)
+ assert len(result.embeddings) == 2
+ assert result.usage.total_tokens > 0
+
+
+def test_get_num_tokens():
+ model = VolcengineMaaSTextEmbeddingModel()
+
+ num_tokens = model.get_num_tokens(
+ model='NOT IMPORTANT',
+ credentials={
+ 'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com',
+ 'volc_region': 'cn-beijing',
+ 'volc_access_key_id': os.environ.get('VOLC_API_KEY'),
+ 'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'),
+ 'endpoint_id': os.environ.get('VOLC_EMBEDDING_ENDPOINT_ID'),
+ 'base_model_name': 'Doubao-embedding',
+ },
+ texts=[
+ "hello",
+ "world"
+ ]
+ )
+
+ assert num_tokens == 2
diff --git a/api/tests/integration_tests/model_runtime/volcengine_maas/test_llm.py b/api/tests/integration_tests/model_runtime/volcengine_maas/test_llm.py
new file mode 100644
index 0000000000..63835d0263
--- /dev/null
+++ b/api/tests/integration_tests/model_runtime/volcengine_maas/test_llm.py
@@ -0,0 +1,131 @@
+import os
+from collections.abc import Generator
+
+import pytest
+
+from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
+from core.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.volcengine_maas.llm.llm import VolcengineMaaSLargeLanguageModel
+
+
+def test_validate_credentials_for_chat_model():
+ model = VolcengineMaaSLargeLanguageModel()
+
+ with pytest.raises(CredentialsValidateFailedError):
+ model.validate_credentials(
+ model='NOT IMPORTANT',
+ credentials={
+ 'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com',
+ 'volc_region': 'cn-beijing',
+ 'volc_access_key_id': 'INVALID',
+ 'volc_secret_access_key': 'INVALID',
+ 'endpoint_id': 'INVALID',
+ }
+ )
+
+ model.validate_credentials(
+ model='NOT IMPORTANT',
+ credentials={
+ 'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com',
+ 'volc_region': 'cn-beijing',
+ 'volc_access_key_id': os.environ.get('VOLC_API_KEY'),
+ 'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'),
+ 'endpoint_id': os.environ.get('VOLC_MODEL_ENDPOINT_ID'),
+ }
+ )
+
+
+def test_invoke_model():
+ model = VolcengineMaaSLargeLanguageModel()
+
+ response = model.invoke(
+ model='NOT IMPORTANT',
+ credentials={
+ 'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com',
+ 'volc_region': 'cn-beijing',
+ 'volc_access_key_id': os.environ.get('VOLC_API_KEY'),
+ 'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'),
+ 'endpoint_id': os.environ.get('VOLC_MODEL_ENDPOINT_ID'),
+ 'base_model_name': 'Skylark2-pro-4k',
+ },
+ prompt_messages=[
+ UserPromptMessage(
+ content='Hello World!'
+ )
+ ],
+ model_parameters={
+ 'temperature': 0.7,
+ 'top_p': 1.0,
+ 'top_k': 1,
+ },
+ stop=['you'],
+ user="abc-123",
+ stream=False
+ )
+
+ assert isinstance(response, LLMResult)
+ assert len(response.message.content) > 0
+ assert response.usage.total_tokens > 0
+
+
+def test_invoke_stream_model():
+ model = VolcengineMaaSLargeLanguageModel()
+
+ response = model.invoke(
+ model='NOT IMPORTANT',
+ credentials={
+ 'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com',
+ 'volc_region': 'cn-beijing',
+ 'volc_access_key_id': os.environ.get('VOLC_API_KEY'),
+ 'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'),
+ 'endpoint_id': os.environ.get('VOLC_MODEL_ENDPOINT_ID'),
+ 'base_model_name': 'Skylark2-pro-4k',
+ },
+ prompt_messages=[
+ UserPromptMessage(
+ content='Hello World!'
+ )
+ ],
+ model_parameters={
+ 'temperature': 0.7,
+ 'top_p': 1.0,
+ 'top_k': 1,
+ },
+ stop=['you'],
+ stream=True,
+ user="abc-123"
+ )
+
+ assert isinstance(response, Generator)
+ for chunk in response:
+ assert isinstance(chunk, LLMResultChunk)
+ assert isinstance(chunk.delta, LLMResultChunkDelta)
+ assert isinstance(chunk.delta.message, AssistantPromptMessage)
+        if chunk.delta.finish_reason is None:
+            assert len(chunk.delta.message.content) > 0
+
+
+def test_get_num_tokens():
+ model = VolcengineMaaSLargeLanguageModel()
+
+ response = model.get_num_tokens(
+ model='NOT IMPORTANT',
+ credentials={
+ 'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com',
+ 'volc_region': 'cn-beijing',
+ 'volc_access_key_id': os.environ.get('VOLC_API_KEY'),
+ 'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'),
+ 'endpoint_id': os.environ.get('VOLC_MODEL_ENDPOINT_ID'),
+ 'base_model_name': 'Skylark2-pro-4k',
+ },
+ prompt_messages=[
+ UserPromptMessage(
+ content='Hello World!'
+ )
+ ],
+ tools=[]
+ )
+
+ assert isinstance(response, int)
+ assert response == 6
diff --git a/api/tests/integration_tests/utils/test_module_import_helper.py b/api/tests/integration_tests/utils/test_module_import_helper.py
index 39ac41b648..256c9a911f 100644
--- a/api/tests/integration_tests/utils/test_module_import_helper.py
+++ b/api/tests/integration_tests/utils/test_module_import_helper.py
@@ -1,6 +1,6 @@
import os
-from core.utils.module_import_helper import import_module_from_source, load_single_subclass_from_source
+from core.helper.module_import_helper import import_module_from_source, load_single_subclass_from_source
from tests.integration_tests.utils.parent_class import ParentClass
diff --git a/api/tests/integration_tests/vdb/pgvecto_rs/test_pgvecto_rs.py b/api/tests/integration_tests/vdb/pgvecto_rs/test_pgvecto_rs.py
index a2b9669b90..89a40a92be 100644
--- a/api/tests/integration_tests/vdb/pgvecto_rs/test_pgvecto_rs.py
+++ b/api/tests/integration_tests/vdb/pgvecto_rs/test_pgvecto_rs.py
@@ -6,7 +6,7 @@ from tests.integration_tests.vdb.test_vector_store import (
)
-class TestPgvectoRSVector(AbstractVectorTest):
+class PGVectoRSVectorTest(AbstractVectorTest):
def __init__(self):
super().__init__()
self.vector = PGVectoRS(
@@ -34,4 +34,4 @@ class TestPgvectoRSVector(AbstractVectorTest):
assert len(ids) == 1
def test_pgvecot_rs(setup_mock_redis):
- TestPgvectoRSVector().run_all_tests()
+ PGVectoRSVectorTest().run_all_tests()
diff --git a/api/tests/integration_tests/vdb/pgvector/__init__.py b/api/tests/integration_tests/vdb/pgvector/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/tests/integration_tests/vdb/pgvector/test_pgvector.py b/api/tests/integration_tests/vdb/pgvector/test_pgvector.py
new file mode 100644
index 0000000000..851599c7ce
--- /dev/null
+++ b/api/tests/integration_tests/vdb/pgvector/test_pgvector.py
@@ -0,0 +1,31 @@
+from core.rag.datasource.vdb.pgvector.pgvector import PGVector, PGVectorConfig
+from core.rag.models.document import Document
+from tests.integration_tests.vdb.test_vector_store import (
+ AbstractVectorTest,
+ get_example_text,
+ setup_mock_redis,
+)
+
+
+class PGVectorTest(AbstractVectorTest):
+ def __init__(self):
+ super().__init__()
+ self.vector = PGVector(
+ collection_name=self.collection_name,
+ config=PGVectorConfig(
+ host="localhost",
+ port=5433,
+ user="postgres",
+ password="difyai123456",
+ database="dify",
+ ),
+ )
+
+ def search_by_full_text(self):
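+        # full-text search is not implemented for PGVector, so no hits are expected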
+ hits_by_full_text: list[Document] = self.vector.search_by_full_text(query=get_example_text())
+ assert len(hits_by_full_text) == 0
+
+
+def test_pgvector(setup_mock_redis):
+ PGVectorTest().run_all_tests()
diff --git a/api/tests/integration_tests/vdb/tidb_vector/__init__.py b/api/tests/integration_tests/vdb/tidb_vector/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/tests/integration_tests/vdb/tidb_vector/test_tidb_vector.py b/api/tests/integration_tests/vdb/tidb_vector/test_tidb_vector.py
new file mode 100644
index 0000000000..837a228a55
--- /dev/null
+++ b/api/tests/integration_tests/vdb/tidb_vector/test_tidb_vector.py
@@ -0,0 +1,64 @@
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVector, TiDBVectorConfig
+from models.dataset import Document
+from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text, setup_mock_redis
+
+
+@pytest.fixture
+def tidb_vector():
+ return TiDBVector(
+ collection_name='test_collection',
+ config=TiDBVectorConfig(
+ host="xxx.eu-central-1.xxx.aws.tidbcloud.com",
+ port="4000",
+ user="xxx.root",
+ password="xxxxxx",
+ database="dify"
+ )
+ )
+
+
+class TiDBVectorTest(AbstractVectorTest):
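+    # the engine and session are mocked, so every lookup is expected to return empty results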
+ def __init__(self, vector):
+ super().__init__()
+ self.vector = vector
+
+ def text_exists(self):
+ exist = self.vector.text_exists(self.example_doc_id)
+        assert exist is False
+
+ def search_by_vector(self):
+ hits_by_vector: list[Document] = self.vector.search_by_vector(query_vector=self.example_embedding)
+ assert len(hits_by_vector) == 0
+
+ def search_by_full_text(self):
+ hits_by_full_text: list[Document] = self.vector.search_by_full_text(query=get_example_text())
+ assert len(hits_by_full_text) == 0
+
+ def get_ids_by_metadata_field(self):
+ ids = self.vector.get_ids_by_metadata_field(key='document_id', value=self.example_doc_id)
+ assert len(ids) == 0
+
+ def delete_by_document_id(self):
+ self.vector.delete_by_document_id(document_id=self.example_doc_id)
+
+
+def test_tidb_vector(setup_mock_redis, setup_tidbvector_mock, tidb_vector, mock_session):
+ TiDBVectorTest(vector=tidb_vector).run_all_tests()
+
+
+@pytest.fixture
+def mock_session():
+ with patch('core.rag.datasource.vdb.tidb_vector.tidb_vector.Session', new_callable=MagicMock) as mock_session:
+ yield mock_session
+
+
+@pytest.fixture
+def setup_tidbvector_mock(tidb_vector, mock_session):
+ with patch('core.rag.datasource.vdb.tidb_vector.tidb_vector.create_engine'):
+ with patch.object(tidb_vector._engine, 'connect'):
+ yield tidb_vector
diff --git a/api/tests/integration_tests/workflow/nodes/__mock/code_executor.py b/api/tests/integration_tests/workflow/nodes/__mock/code_executor.py
index 38517cf448..13f992136e 100644
--- a/api/tests/integration_tests/workflow/nodes/__mock/code_executor.py
+++ b/api/tests/integration_tests/workflow/nodes/__mock/code_executor.py
@@ -1,25 +1,29 @@
import os
-from typing import Literal
+from typing import Literal, Optional
import pytest
from _pytest.monkeypatch import MonkeyPatch
+from jinja2 import Template
-from core.helper.code_executor.code_executor import CodeExecutor
+from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage
+from core.helper.code_executor.entities import CodeDependency
MOCK = os.getenv('MOCK_SWITCH', 'false') == 'true'
class MockedCodeExecutor:
@classmethod
- def invoke(cls, language: Literal['python3', 'javascript', 'jinja2'], code: str, inputs: dict) -> dict:
+ def invoke(cls, language: Literal['python3', 'javascript', 'jinja2'],
+ code: str, inputs: dict, dependencies: Optional[list[CodeDependency]] = None) -> dict:
# invoke directly
- if language == 'python3':
- return {
- "result": 3
- }
- elif language == 'jinja2':
- return {
- "result": "3"
- }
+ match language:
+ case CodeLanguage.PYTHON3:
+ return {
+ "result": 3
+ }
+ case CodeLanguage.JINJA2:
+ return {
+ "result": Template(code).render(inputs)
+ }
@pytest.fixture
def setup_code_executor_mock(request, monkeypatch: MonkeyPatch):
diff --git a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_executor.py b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_executor.py
new file mode 100644
index 0000000000..ae6e7ceaa7
--- /dev/null
+++ b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_executor.py
@@ -0,0 +1,11 @@
+import pytest
+
+from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor
+
+CODE_LANGUAGE = 'unsupported_language'
+
+
+def test_unsupported_with_code_template():
+ with pytest.raises(CodeExecutionException) as e:
+ CodeExecutor.execute_workflow_code_template(language=CODE_LANGUAGE, code='', inputs={})
+ assert str(e.value) == f'Unsupported language {CODE_LANGUAGE}'
diff --git a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_javascript.py b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_javascript.py
index c794ae8e4b..2d798eb9c2 100644
--- a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_javascript.py
+++ b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_javascript.py
@@ -1,6 +1,10 @@
-from core.helper.code_executor.code_executor import CodeExecutor
+from textwrap import dedent
-CODE_LANGUAGE = 'javascript'
+from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage
+from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
+from core.helper.code_executor.javascript.javascript_transformer import NodeJsTemplateTransformer
+
+CODE_LANGUAGE = CodeLanguage.JAVASCRIPT
def test_javascript_plain():
@@ -10,9 +14,30 @@ def test_javascript_plain():
def test_javascript_json():
- code = """
-obj = {'Hello': 'World'}
-console.log(JSON.stringify(obj))
- """
+ code = dedent("""
+ obj = {'Hello': 'World'}
+ console.log(JSON.stringify(obj))
+ """)
result = CodeExecutor.execute_code(language=CODE_LANGUAGE, preload='', code=code)
assert result == '{"Hello":"World"}\n'
+
+
+def test_javascript_with_code_template():
+ result = CodeExecutor.execute_workflow_code_template(
+ language=CODE_LANGUAGE, code=JavascriptCodeProvider.get_default_code(),
+ inputs={'arg1': 'Hello', 'arg2': 'World'})
+ assert result == {'result': 'HelloWorld'}
+
+
+def test_javascript_list_default_available_packages():
+ packages = JavascriptCodeProvider.get_default_available_packages()
+
+ # no default packages available for javascript
+ assert len(packages) == 0
+
+
+def test_javascript_get_runner_script():
+ runner_script = NodeJsTemplateTransformer.get_runner_script()
+ assert runner_script.count(NodeJsTemplateTransformer._code_placeholder) == 1
+ assert runner_script.count(NodeJsTemplateTransformer._inputs_placeholder) == 1
+ assert runner_script.count(NodeJsTemplateTransformer._result_tag) == 2
diff --git a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_jina2.py b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_jina2.py
deleted file mode 100644
index aae3c7acec..0000000000
--- a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_jina2.py
+++ /dev/null
@@ -1,14 +0,0 @@
-import base64
-
-from core.helper.code_executor.code_executor import CodeExecutor
-from core.helper.code_executor.jinja2_transformer import JINJA2_PRELOAD, PYTHON_RUNNER
-
-CODE_LANGUAGE = 'jinja2'
-
-
-def test_jinja2():
- template = 'Hello {{template}}'
- inputs = base64.b64encode(b'{"template": "World"}').decode('utf-8')
- code = PYTHON_RUNNER.replace('{{code}}', template).replace('{{inputs}}', inputs)
- result = CodeExecutor.execute_code(language=CODE_LANGUAGE, preload=JINJA2_PRELOAD, code=code)
- assert result == '<>Hello World<>\n'
diff --git a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_jinja2.py b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_jinja2.py
new file mode 100644
index 0000000000..425f4cbdd4
--- /dev/null
+++ b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_jinja2.py
@@ -0,0 +1,32 @@
+import base64
+
+from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage
+from core.helper.code_executor.jinja2.jinja2_transformer import Jinja2TemplateTransformer
+
+CODE_LANGUAGE = CodeLanguage.JINJA2
+
+
+def test_jinja2():
+ template = 'Hello {{template}}'
+ inputs = base64.b64encode(b'{"template": "World"}').decode('utf-8')
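+    # the runner script's inputs placeholder is filled with base64-encoded JSON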
+ code = (Jinja2TemplateTransformer.get_runner_script()
+ .replace(Jinja2TemplateTransformer._code_placeholder, template)
+ .replace(Jinja2TemplateTransformer._inputs_placeholder, inputs))
+ result = CodeExecutor.execute_code(language=CODE_LANGUAGE,
+ preload=Jinja2TemplateTransformer.get_preload_script(),
+ code=code)
+ assert result == '<>Hello World<>\n'
+
+
+def test_jinja2_with_code_template():
+ result = CodeExecutor.execute_workflow_code_template(
+ language=CODE_LANGUAGE, code='Hello {{template}}', inputs={'template': 'World'})
+ assert result == {'result': 'Hello World'}
+
+
+def test_jinja2_get_runner_script():
+ runner_script = Jinja2TemplateTransformer.get_runner_script()
+ assert runner_script.count(Jinja2TemplateTransformer._code_placeholder) == 1
+ assert runner_script.count(Jinja2TemplateTransformer._inputs_placeholder) == 1
+ assert runner_script.count(Jinja2TemplateTransformer._result_tag) == 2
diff --git a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_python3.py b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_python3.py
index 1983bc5e6b..d265011d4c 100644
--- a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_python3.py
+++ b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_python3.py
@@ -1,6 +1,11 @@
-from core.helper.code_executor.code_executor import CodeExecutor
+import json
+from textwrap import dedent
-CODE_LANGUAGE = 'python3'
+from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage
+from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
+from core.helper.code_executor.python3.python3_transformer import Python3TemplateTransformer
+
+CODE_LANGUAGE = CodeLanguage.PYTHON3
def test_python3_plain():
@@ -10,9 +15,31 @@ def test_python3_plain():
def test_python3_json():
- code = """
-import json
-print(json.dumps({'Hello': 'World'}))
- """
+ code = dedent("""
+ import json
+ print(json.dumps({'Hello': 'World'}))
+ """)
result = CodeExecutor.execute_code(language=CODE_LANGUAGE, preload='', code=code)
assert result == '{"Hello": "World"}\n'
+
+
+def test_python3_with_code_template():
+ result = CodeExecutor.execute_workflow_code_template(
+ language=CODE_LANGUAGE, code=Python3CodeProvider.get_default_code(), inputs={'arg1': 'Hello', 'arg2': 'World'})
+ assert result == {'result': 'HelloWorld'}
+
+
+def test_python3_list_default_available_packages():
+ packages = Python3CodeProvider.get_default_available_packages()
+ assert len(packages) > 0
+ assert {'requests', 'httpx'}.issubset(p['name'] for p in packages)
+
+ # check JSON serializable
+ assert len(str(json.dumps(packages))) > 0
+
+
+def test_python3_get_runner_script():
+ runner_script = Python3TemplateTransformer.get_runner_script()
+ assert runner_script.count(Python3TemplateTransformer._code_placeholder) == 1
+ assert runner_script.count(Python3TemplateTransformer._inputs_placeholder) == 1
+ assert runner_script.count(Python3TemplateTransformer._result_tag) == 2
diff --git a/api/tests/integration_tests/workflow/nodes/test_code.py b/api/tests/integration_tests/workflow/nodes/test_code.py
index c41d51caf7..15cf5367d3 100644
--- a/api/tests/integration_tests/workflow/nodes/test_code.py
+++ b/api/tests/integration_tests/workflow/nodes/test_code.py
@@ -4,6 +4,7 @@ import pytest
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.nodes.base_node import UserFrom
from core.workflow.nodes.code.code_node import CodeNode
from models.workflow import WorkflowNodeExecutionStatus
from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock
@@ -25,7 +26,8 @@ def test_execute_code(setup_code_executor_mock):
app_id='1',
workflow_id='1',
user_id='1',
- user_from=InvokeFrom.WEB_APP,
+ user_from=UserFrom.ACCOUNT,
+ invoke_from=InvokeFrom.WEB_APP,
config={
'id': '1',
'data': {
@@ -78,7 +80,8 @@ def test_execute_code_output_validator(setup_code_executor_mock):
app_id='1',
workflow_id='1',
user_id='1',
- user_from=InvokeFrom.WEB_APP,
+ user_from=UserFrom.ACCOUNT,
+ invoke_from=InvokeFrom.WEB_APP,
config={
'id': '1',
'data': {
@@ -132,7 +135,8 @@ def test_execute_code_output_validator_depth():
app_id='1',
workflow_id='1',
user_id='1',
- user_from=InvokeFrom.WEB_APP,
+ user_from=UserFrom.ACCOUNT,
+ invoke_from=InvokeFrom.WEB_APP,
config={
'id': '1',
'data': {
@@ -285,7 +289,8 @@ def test_execute_code_output_object_list():
app_id='1',
workflow_id='1',
user_id='1',
- user_from=InvokeFrom.WEB_APP,
+ invoke_from=InvokeFrom.WEB_APP,
+ user_from=UserFrom.ACCOUNT,
config={
'id': '1',
'data': {
diff --git a/api/tests/integration_tests/workflow/nodes/test_http.py b/api/tests/integration_tests/workflow/nodes/test_http.py
index 63b6b7d962..ffa2741e55 100644
--- a/api/tests/integration_tests/workflow/nodes/test_http.py
+++ b/api/tests/integration_tests/workflow/nodes/test_http.py
@@ -2,6 +2,7 @@ import pytest
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.nodes.base_node import UserFrom
from core.workflow.nodes.http_request.http_request_node import HttpRequestNode
from tests.integration_tests.workflow.nodes.__mock.http import setup_http_mock
@@ -10,7 +11,8 @@ BASIC_NODE_DATA = {
'app_id': '1',
'workflow_id': '1',
'user_id': '1',
- 'user_from': InvokeFrom.WEB_APP,
+ 'user_from': UserFrom.ACCOUNT,
+ 'invoke_from': InvokeFrom.WEB_APP,
}
# construct variable pool
@@ -38,6 +40,7 @@ def test_get(setup_http_mock):
'headers': 'X-Header:123',
'params': 'A:b',
'body': None,
+ 'mask_authorization_header': False,
}
}, **BASIC_NODE_DATA)
@@ -95,6 +98,7 @@ def test_custom_authorization_header(setup_http_mock):
'headers': 'X-Header:123',
'params': 'A:b',
'body': None,
+ 'mask_authorization_header': False,
}
}, **BASIC_NODE_DATA)
@@ -126,6 +130,7 @@ def test_template(setup_http_mock):
'headers': 'X-Header:123\nX-Header2:{{#a.b123.args2#}}',
'params': 'A:b\nTemplate:{{#a.b123.args2#}}',
'body': None,
+ 'mask_authorization_header': False,
}
}, **BASIC_NODE_DATA)
@@ -161,6 +166,7 @@ def test_json(setup_http_mock):
'type': 'json',
'data': '{"a": "{{#a.b123.args1#}}"}'
},
+ 'mask_authorization_header': False,
}
}, **BASIC_NODE_DATA)
@@ -193,6 +199,7 @@ def test_x_www_form_urlencoded(setup_http_mock):
'type': 'x-www-form-urlencoded',
'data': 'a:{{#a.b123.args1#}}\nb:{{#a.b123.args2#}}'
},
+ 'mask_authorization_header': False,
}
}, **BASIC_NODE_DATA)
@@ -225,6 +232,7 @@ def test_form_data(setup_http_mock):
'type': 'form-data',
'data': 'a:{{#a.b123.args1#}}\nb:{{#a.b123.args2#}}'
},
+ 'mask_authorization_header': False,
}
}, **BASIC_NODE_DATA)
@@ -260,6 +268,7 @@ def test_none_data(setup_http_mock):
'type': 'none',
'data': '123123123'
},
+ 'mask_authorization_header': False,
}
}, **BASIC_NODE_DATA)
diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py
index 8a8a58d59f..a150be3c00 100644
--- a/api/tests/integration_tests/workflow/nodes/test_llm.py
+++ b/api/tests/integration_tests/workflow/nodes/test_llm.py
@@ -1,9 +1,10 @@
+import json
import os
from unittest.mock import MagicMock
import pytest
-from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
+from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
from core.entities.provider_entities import CustomConfiguration, CustomProviderConfiguration, SystemConfiguration
from core.model_manager import ModelInstance
@@ -19,6 +20,7 @@ from models.workflow import WorkflowNodeExecutionStatus
"""FOR MOCK FIXTURES, DO NOT REMOVE"""
from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
+from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock
@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
@@ -28,6 +30,7 @@ def test_execute_llm(setup_openai_mock):
app_id='1',
workflow_id='1',
user_id='1',
+ invoke_from=InvokeFrom.WEB_APP,
user_from=UserFrom.ACCOUNT,
config={
'id': 'llm',
@@ -89,7 +92,8 @@ def test_execute_llm(setup_openai_mock):
provider=CustomProviderConfiguration(
credentials=credentials
)
- )
+ ),
+ model_settings=[]
),
provider_instance=provider_instance,
model_type_instance=model_type_instance
@@ -116,3 +120,120 @@ def test_execute_llm(setup_openai_mock):
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs['text'] is not None
assert result.outputs['usage']['total_tokens'] > 0
+
+@pytest.mark.parametrize('setup_code_executor_mock', [['none']], indirect=True)
+@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
+def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_openai_mock):
+ """
+ Test execute LLM node with jinja2
+ """
+ node = LLMNode(
+ tenant_id='1',
+ app_id='1',
+ workflow_id='1',
+ user_id='1',
+ invoke_from=InvokeFrom.WEB_APP,
+ user_from=UserFrom.ACCOUNT,
+ config={
+ 'id': 'llm',
+ 'data': {
+ 'title': '123',
+ 'type': 'llm',
+ 'model': {
+ 'provider': 'openai',
+ 'name': 'gpt-3.5-turbo',
+ 'mode': 'chat',
+ 'completion_params': {}
+ },
+ 'prompt_config': {
+ 'jinja2_variables': [{
+ 'variable': 'sys_query',
+ 'value_selector': ['sys', 'query']
+ }, {
+ 'variable': 'output',
+ 'value_selector': ['abc', 'output']
+ }]
+ },
+ 'prompt_template': [
+ {
+ 'role': 'system',
+ 'text': 'you are a helpful assistant.\ntoday\'s weather is {{#abc.output#}}',
+ 'jinja2_text': 'you are a helpful assistant.\ntoday\'s weather is {{output}}.',
+ 'edition_type': 'jinja2'
+ },
+ {
+ 'role': 'user',
+ 'text': '{{#sys.query#}}',
+ 'jinja2_text': '{{sys_query}}',
+ 'edition_type': 'basic'
+ }
+ ],
+ 'memory': None,
+ 'context': {
+ 'enabled': False
+ },
+ 'vision': {
+ 'enabled': False
+ }
+ }
+ }
+ )
+
+ # construct variable pool
+ pool = VariablePool(system_variables={
+ SystemVariable.QUERY: 'what\'s the weather today?',
+ SystemVariable.FILES: [],
+ SystemVariable.CONVERSATION_ID: 'abababa',
+ SystemVariable.USER_ID: 'aaa'
+ }, user_inputs={})
+ pool.append_variable(node_id='abc', variable_key_list=['output'], value='sunny')
+
+ credentials = {
+ 'openai_api_key': os.environ.get('OPENAI_API_KEY')
+ }
+
+ provider_instance = ModelProviderFactory().get_provider_instance('openai')
+ model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
+ provider_model_bundle = ProviderModelBundle(
+ configuration=ProviderConfiguration(
+ tenant_id='1',
+ provider=provider_instance.get_provider_schema(),
+ preferred_provider_type=ProviderType.CUSTOM,
+ using_provider_type=ProviderType.CUSTOM,
+ system_configuration=SystemConfiguration(
+ enabled=False
+ ),
+ custom_configuration=CustomConfiguration(
+ provider=CustomProviderConfiguration(
+ credentials=credentials
+ )
+ ),
+ model_settings=[]
+ ),
+ provider_instance=provider_instance,
+ model_type_instance=model_type_instance,
+ )
+
+ model_instance = ModelInstance(provider_model_bundle=provider_model_bundle, model='gpt-3.5-turbo')
+
+ model_config = ModelConfigWithCredentialsEntity(
+ model='gpt-3.5-turbo',
+ provider='openai',
+ mode='chat',
+ credentials=credentials,
+ parameters={},
+ model_schema=model_type_instance.get_model_schema('gpt-3.5-turbo'),
+ provider_model_bundle=provider_model_bundle
+ )
+
+ # Mock db.session.close()
+ db.session.close = MagicMock()
+
+ node._fetch_model_config = MagicMock(return_value=tuple([model_instance, model_config]))
+
+ # execute node
+ result = node.run(pool)
+
+ assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
+ assert 'sunny' in json.dumps(result.process_data)
+ assert 'what\'s the weather today?' in json.dumps(result.process_data)
\ No newline at end of file
diff --git a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py
new file mode 100644
index 0000000000..056c78441d
--- /dev/null
+++ b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py
@@ -0,0 +1,357 @@
+import json
+import os
+from unittest.mock import MagicMock
+
+import pytest
+
+from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
+from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
+from core.entities.provider_entities import CustomConfiguration, CustomProviderConfiguration, SystemConfiguration
+from core.model_manager import ModelInstance
+from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
+from core.workflow.entities.node_entities import SystemVariable
+from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.nodes.base_node import UserFrom
+from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
+from extensions.ext_database import db
+from models.provider import ProviderType
+
+"""FOR MOCK FIXTURES, DO NOT REMOVE"""
+from models.workflow import WorkflowNodeExecutionStatus
+from tests.integration_tests.model_runtime.__mock.anthropic import setup_anthropic_mock
+from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
+
+
+def get_mocked_fetch_model_config(
+ provider: str, model: str, mode: str,
+ credentials: dict,
+):
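+    # build a ProviderModelBundle around the real provider/model schema but with
+    # fake credentials, so nodes can fetch a model config without a live provider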
+ provider_instance = ModelProviderFactory().get_provider_instance(provider)
+ model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
+ provider_model_bundle = ProviderModelBundle(
+ configuration=ProviderConfiguration(
+ tenant_id='1',
+ provider=provider_instance.get_provider_schema(),
+ preferred_provider_type=ProviderType.CUSTOM,
+ using_provider_type=ProviderType.CUSTOM,
+ system_configuration=SystemConfiguration(
+ enabled=False
+ ),
+ custom_configuration=CustomConfiguration(
+ provider=CustomProviderConfiguration(
+ credentials=credentials
+ )
+ ),
+ model_settings=[]
+ ),
+ provider_instance=provider_instance,
+ model_type_instance=model_type_instance
+ )
+ model_instance = ModelInstance(provider_model_bundle=provider_model_bundle, model=model)
+ model_config = ModelConfigWithCredentialsEntity(
+ model=model,
+ provider=provider,
+ mode=mode,
+ credentials=credentials,
+ parameters={},
+ model_schema=model_type_instance.get_model_schema(model),
+ provider_model_bundle=provider_model_bundle
+ )
+
+ return MagicMock(return_value=tuple([model_instance, model_config]))
+
+@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
+def test_function_calling_parameter_extractor(setup_openai_mock):
+ """
+ Test function calling for parameter extractor.
+ """
+ node = ParameterExtractorNode(
+ tenant_id='1',
+ app_id='1',
+ workflow_id='1',
+ user_id='1',
+ invoke_from=InvokeFrom.WEB_APP,
+ user_from=UserFrom.ACCOUNT,
+ config={
+ 'id': 'llm',
+ 'data': {
+ 'title': '123',
+ 'type': 'parameter-extractor',
+ 'model': {
+ 'provider': 'openai',
+ 'name': 'gpt-3.5-turbo',
+ 'mode': 'chat',
+ 'completion_params': {}
+ },
+ 'query': ['sys', 'query'],
+ 'parameters': [{
+ 'name': 'location',
+ 'type': 'string',
+ 'description': 'location',
+ 'required': True
+ }],
+ 'instruction': '',
+ 'reasoning_mode': 'function_call',
+ 'memory': None,
+ }
+ }
+ )
+
+ node._fetch_model_config = get_mocked_fetch_model_config(
+ provider='openai', model='gpt-3.5-turbo', mode='chat', credentials={
+ 'openai_api_key': os.environ.get('OPENAI_API_KEY')
+ }
+ )
+ db.session.close = MagicMock()
+
+ # construct variable pool
+ pool = VariablePool(system_variables={
+ SystemVariable.QUERY: 'what\'s the weather in SF',
+ SystemVariable.FILES: [],
+ SystemVariable.CONVERSATION_ID: 'abababa',
+ SystemVariable.USER_ID: 'aaa'
+ }, user_inputs={})
+
+ result = node.run(pool)
+
+ assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
+ assert result.outputs.get('location') == 'kawaii'
+    assert result.outputs.get('__reason') is None
+
+@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
+def test_instructions(setup_openai_mock):
+ """
+    Test that the parameter extractor renders the instruction into its prompts.
+ """
+ node = ParameterExtractorNode(
+ tenant_id='1',
+ app_id='1',
+ workflow_id='1',
+ user_id='1',
+ invoke_from=InvokeFrom.WEB_APP,
+ user_from=UserFrom.ACCOUNT,
+ config={
+ 'id': 'llm',
+ 'data': {
+ 'title': '123',
+ 'type': 'parameter-extractor',
+ 'model': {
+ 'provider': 'openai',
+ 'name': 'gpt-3.5-turbo',
+ 'mode': 'chat',
+ 'completion_params': {}
+ },
+ 'query': ['sys', 'query'],
+ 'parameters': [{
+ 'name': 'location',
+ 'type': 'string',
+ 'description': 'location',
+ 'required': True
+ }],
+ 'reasoning_mode': 'function_call',
+ 'instruction': '{{#sys.query#}}',
+ 'memory': None,
+ }
+ }
+ )
+
+ node._fetch_model_config = get_mocked_fetch_model_config(
+ provider='openai', model='gpt-3.5-turbo', mode='chat', credentials={
+ 'openai_api_key': os.environ.get('OPENAI_API_KEY')
+ }
+ )
+ db.session.close = MagicMock()
+
+ # construct variable pool
+ pool = VariablePool(system_variables={
+ SystemVariable.QUERY: 'what\'s the weather in SF',
+ SystemVariable.FILES: [],
+ SystemVariable.CONVERSATION_ID: 'abababa',
+ SystemVariable.USER_ID: 'aaa'
+ }, user_inputs={})
+
+ result = node.run(pool)
+
+ assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
+ assert result.outputs.get('location') == 'kawaii'
+    assert result.outputs.get('__reason') is None
+
+ process_data = result.process_data
+
+ for prompt in process_data.get('prompts'):
+ if prompt.get('role') == 'system':
+ assert 'what\'s the weather in SF' in prompt.get('text')
+
+@pytest.mark.parametrize('setup_anthropic_mock', [['none']], indirect=True)
+def test_chat_parameter_extractor(setup_anthropic_mock):
+ """
+ Test chat parameter extractor.
+ """
+ node = ParameterExtractorNode(
+ tenant_id='1',
+ app_id='1',
+ workflow_id='1',
+ user_id='1',
+ invoke_from=InvokeFrom.WEB_APP,
+ user_from=UserFrom.ACCOUNT,
+ config={
+ 'id': 'llm',
+ 'data': {
+ 'title': '123',
+ 'type': 'parameter-extractor',
+ 'model': {
+ 'provider': 'anthropic',
+ 'name': 'claude-2',
+ 'mode': 'chat',
+ 'completion_params': {}
+ },
+ 'query': ['sys', 'query'],
+ 'parameters': [{
+ 'name': 'location',
+ 'type': 'string',
+ 'description': 'location',
+ 'required': True
+ }],
+ 'reasoning_mode': 'prompt',
+ 'instruction': '',
+ 'memory': None,
+ }
+ }
+ )
+
+ node._fetch_model_config = get_mocked_fetch_model_config(
+ provider='anthropic', model='claude-2', mode='chat', credentials={
+ 'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY')
+ }
+ )
+ db.session.close = MagicMock()
+
+ # construct variable pool
+ pool = VariablePool(system_variables={
+ SystemVariable.QUERY: 'what\'s the weather in SF',
+ SystemVariable.FILES: [],
+ SystemVariable.CONVERSATION_ID: 'abababa',
+ SystemVariable.USER_ID: 'aaa'
+ }, user_inputs={})
+
+ result = node.run(pool)
+
+ assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
+ assert result.outputs.get('location') == ''
+ assert result.outputs.get('__reason') == 'Failed to extract result from function call or text response, using empty result.'
+ prompts = result.process_data.get('prompts')
+
+ for prompt in prompts:
+ if prompt.get('role') == 'user':
+            if '<structure>' in prompt.get('text'):
+ assert '\n{"type": "object"' in prompt.get('text')
+
+@pytest.mark.parametrize('setup_openai_mock', [['completion']], indirect=True)
+def test_completion_parameter_extractor(setup_openai_mock):
+ """
+ Test completion parameter extractor.
+ """
+ node = ParameterExtractorNode(
+ tenant_id='1',
+ app_id='1',
+ workflow_id='1',
+ user_id='1',
+ invoke_from=InvokeFrom.WEB_APP,
+ user_from=UserFrom.ACCOUNT,
+ config={
+ 'id': 'llm',
+ 'data': {
+ 'title': '123',
+ 'type': 'parameter-extractor',
+ 'model': {
+ 'provider': 'openai',
+ 'name': 'gpt-3.5-turbo-instruct',
+ 'mode': 'completion',
+ 'completion_params': {}
+ },
+ 'query': ['sys', 'query'],
+ 'parameters': [{
+ 'name': 'location',
+ 'type': 'string',
+ 'description': 'location',
+ 'required': True
+ }],
+ 'reasoning_mode': 'prompt',
+ 'instruction': '{{#sys.query#}}',
+ 'memory': None,
+ }
+ }
+ )
+
+ node._fetch_model_config = get_mocked_fetch_model_config(
+ provider='openai', model='gpt-3.5-turbo-instruct', mode='completion', credentials={
+ 'openai_api_key': os.environ.get('OPENAI_API_KEY')
+ }
+ )
+ db.session.close = MagicMock()
+
+ # construct variable pool
+ pool = VariablePool(system_variables={
+ SystemVariable.QUERY: 'what\'s the weather in SF',
+ SystemVariable.FILES: [],
+ SystemVariable.CONVERSATION_ID: 'abababa',
+ SystemVariable.USER_ID: 'aaa'
+ }, user_inputs={})
+
+ result = node.run(pool)
+
+ assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
+ assert result.outputs.get('location') == ''
+ assert result.outputs.get('__reason') == 'Failed to extract result from function call or text response, using empty result.'
+ assert len(result.process_data.get('prompts')) == 1
+ assert 'SF' in result.process_data.get('prompts')[0].get('text')
+
+def test_extract_json_response():
+ """
+    Test extracting a JSON object from a noisy text response.
+ """
+
+ node = ParameterExtractorNode(
+ tenant_id='1',
+ app_id='1',
+ workflow_id='1',
+ user_id='1',
+ invoke_from=InvokeFrom.WEB_APP,
+ user_from=UserFrom.ACCOUNT,
+ config={
+ 'id': 'llm',
+ 'data': {
+ 'title': '123',
+ 'type': 'parameter-extractor',
+ 'model': {
+ 'provider': 'openai',
+ 'name': 'gpt-3.5-turbo-instruct',
+ 'mode': 'completion',
+ 'completion_params': {}
+ },
+ 'query': ['sys', 'query'],
+ 'parameters': [{
+ 'name': 'location',
+ 'type': 'string',
+ 'description': 'location',
+ 'required': True
+ }],
+ 'reasoning_mode': 'prompt',
+ 'instruction': '{{#sys.query#}}',
+ 'memory': None,
+ }
+ }
+ )
+
+ result = node._extract_complete_json_response("""
+ uwu{ovo}
+ {
+ "location": "kawaii"
+ }
+ hello world.
+ """)
+
+ assert result['location'] == 'kawaii'
\ No newline at end of file
diff --git a/api/tests/integration_tests/workflow/nodes/test_template_transform.py b/api/tests/integration_tests/workflow/nodes/test_template_transform.py
index 4a31334056..02999bf0a2 100644
--- a/api/tests/integration_tests/workflow/nodes/test_template_transform.py
+++ b/api/tests/integration_tests/workflow/nodes/test_template_transform.py
@@ -1,5 +1,6 @@
import pytest
+from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import UserFrom
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
@@ -15,6 +16,7 @@ def test_execute_code(setup_code_executor_mock):
app_id='1',
workflow_id='1',
user_id='1',
+ invoke_from=InvokeFrom.WEB_APP,
user_from=UserFrom.END_USER,
config={
'id': '1',
diff --git a/api/tests/integration_tests/workflow/nodes/test_tool.py b/api/tests/integration_tests/workflow/nodes/test_tool.py
index 4bbd4ccee7..fffd074457 100644
--- a/api/tests/integration_tests/workflow/nodes/test_tool.py
+++ b/api/tests/integration_tests/workflow/nodes/test_tool.py
@@ -1,5 +1,6 @@
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.nodes.base_node import UserFrom
from core.workflow.nodes.tool.tool_node import ToolNode
from models.workflow import WorkflowNodeExecutionStatus
@@ -13,7 +14,8 @@ def test_tool_variable_invoke():
app_id='1',
workflow_id='1',
user_id='1',
- user_from=InvokeFrom.WEB_APP,
+ invoke_from=InvokeFrom.WEB_APP,
+ user_from=UserFrom.ACCOUNT,
config={
'id': '1',
'data': {
@@ -51,7 +53,8 @@ def test_tool_mixed_invoke():
app_id='1',
workflow_id='1',
user_id='1',
- user_from=InvokeFrom.WEB_APP,
+ invoke_from=InvokeFrom.WEB_APP,
+ user_from=UserFrom.ACCOUNT,
config={
'id': '1',
'data': {
diff --git a/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py
new file mode 100644
index 0000000000..9de268d762
--- /dev/null
+++ b/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py
@@ -0,0 +1,77 @@
+from unittest.mock import MagicMock
+
+from core.app.entities.app_invoke_entities import (
+ ModelConfigWithCredentialsEntity,
+)
+from core.entities.provider_configuration import ProviderModelBundle
+from core.memory.token_buffer_memory import TokenBufferMemory
+from core.model_runtime.entities.message_entities import (
+ AssistantPromptMessage,
+ SystemPromptMessage,
+ ToolPromptMessage,
+ UserPromptMessage,
+)
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
+from models.model import Conversation
+
+
+def test_get_prompt():
+ prompt_messages = [
+ SystemPromptMessage(content='System Template'),
+ UserPromptMessage(content='User Query'),
+ ]
+ history_messages = [
+ SystemPromptMessage(content='System Prompt 1'),
+ UserPromptMessage(content='User Prompt 1'),
+ AssistantPromptMessage(content='Assistant Thought 1'),
+ ToolPromptMessage(content='Tool 1-1', name='Tool 1-1', tool_call_id='1'),
+ ToolPromptMessage(content='Tool 1-2', name='Tool 1-2', tool_call_id='2'),
+ SystemPromptMessage(content='System Prompt 2'),
+ UserPromptMessage(content='User Prompt 2'),
+ AssistantPromptMessage(content='Assistant Thought 2'),
+ ToolPromptMessage(content='Tool 2-1', name='Tool 2-1', tool_call_id='3'),
+ ToolPromptMessage(content='Tool 2-2', name='Tool 2-2', tool_call_id='4'),
+ UserPromptMessage(content='User Prompt 3'),
+ AssistantPromptMessage(content='Assistant Thought 3'),
+ ]
+
+ # use message number instead of token for testing
+ def side_effect_get_num_tokens(*args):
+ return len(args[2])
+ large_language_model_mock = MagicMock(spec=LargeLanguageModel)
+ large_language_model_mock.get_num_tokens = MagicMock(side_effect=side_effect_get_num_tokens)
+
+ provider_model_bundle_mock = MagicMock(spec=ProviderModelBundle)
+ provider_model_bundle_mock.model_type_instance = large_language_model_mock
+
+ model_config_mock = MagicMock(spec=ModelConfigWithCredentialsEntity)
+ model_config_mock.model = 'openai'
+ model_config_mock.credentials = {}
+ model_config_mock.provider_model_bundle = provider_model_bundle_mock
+
+ memory = TokenBufferMemory(
+ conversation=Conversation(),
+ model_instance=model_config_mock
+ )
+
+ transform = AgentHistoryPromptTransform(
+ model_config=model_config_mock,
+ prompt_messages=prompt_messages,
+ history_messages=history_messages,
+ memory=memory
+ )
+
+ max_token_limit = 5
+ transform._calculate_rest_token = MagicMock(return_value=max_token_limit)
+ result = transform.get_prompt()
+
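+    # 4 = the 2 fixed prompt messages plus, presumably, the latest complete
+    # user/assistant round from the history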
+ assert len(result) <= max_token_limit
+ assert len(result) == 4
+
+ max_token_limit = 20
+ transform._calculate_rest_token = MagicMock(return_value=max_token_limit)
+ result = transform.get_prompt()
+
+ assert len(result) <= max_token_limit
+ assert len(result) == 12
diff --git a/api/tests/unit_tests/core/prompt/test_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_prompt_transform.py
index 40f5be8af9..2bcc6f4292 100644
--- a/api/tests/unit_tests/core/prompt/test_prompt_transform.py
+++ b/api/tests/unit_tests/core/prompt/test_prompt_transform.py
@@ -1,9 +1,10 @@
from unittest.mock import MagicMock
from core.app.app_config.entities import ModelConfigEntity
-from core.entities.provider_configuration import ProviderModelBundle
+from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
from core.model_runtime.entities.message_entities import UserPromptMessage
from core.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey, ParameterRule
+from core.model_runtime.entities.provider_entities import ProviderEntity
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.prompt_transform import PromptTransform
@@ -22,8 +23,16 @@ def test__calculate_rest_token():
large_language_model_mock = MagicMock(spec=LargeLanguageModel)
large_language_model_mock.get_num_tokens.return_value = 6
+ provider_mock = MagicMock(spec=ProviderEntity)
+ provider_mock.provider = 'openai'
+
+ provider_configuration_mock = MagicMock(spec=ProviderConfiguration)
+ provider_configuration_mock.provider = provider_mock
+ provider_configuration_mock.model_settings = None
+
provider_model_bundle_mock = MagicMock(spec=ProviderModelBundle)
provider_model_bundle_mock.model_type_instance = large_language_model_mock
+ provider_model_bundle_mock.configuration = provider_configuration_mock
model_config_mock = MagicMock(spec=ModelConfigEntity)
model_config_mock.model = 'gpt-4'
diff --git a/api/tests/unit_tests/core/rag/extractor/__init__.py b/api/tests/unit_tests/core/rag/extractor/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/tests/unit_tests/core/rag/extractor/test_notion_extractor.py b/api/tests/unit_tests/core/rag/extractor/test_notion_extractor.py
new file mode 100644
index 0000000000..b231fe479b
--- /dev/null
+++ b/api/tests/unit_tests/core/rag/extractor/test_notion_extractor.py
@@ -0,0 +1,102 @@
+from unittest import mock
+
+from core.rag.extractor import notion_extractor
+
+user_id = "user1"
+database_id = "database1"
+page_id = "page1"
+
+
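+# module-level extractor with placeholder ids; every test mocks the HTTP layer,
+# so no real Notion workspace or token is needed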
+extractor = notion_extractor.NotionExtractor(
+ notion_workspace_id='x',
+ notion_obj_id='x',
+ notion_page_type='page',
+ tenant_id='x',
+ notion_access_token='x')
+
+
+def _generate_page(page_title: str):
+ return {
+ "object": "page",
+ "id": page_id,
+ "properties": {
+ "Page": {
+ "type": "title",
+ "title": [
+ {
+ "type": "text",
+ "text": {"content": page_title},
+ "plain_text": page_title
+ }
+ ]
+ }
+ }
+ }
+
+
+def _generate_block(block_id: str, block_type: str, block_text: str):
+ return {
+ "object": "block",
+ "id": block_id,
+ "parent": {
+ "type": "page_id",
+ "page_id": page_id
+ },
+ "type": block_type,
+ "has_children": False,
+ block_type: {
+ "rich_text": [
+ {
+ "type": "text",
+ "text": {"content": block_text},
+ "plain_text": block_text,
+ }]
+ }
+ }
+
+
+def _mock_response(data):
+ response = mock.Mock()
+ response.status_code = 200
+ response.json.return_value = data
+ return response
+
+
+def _remove_multiple_new_lines(text):
+ while '\n\n' in text:
+ text = text.replace("\n\n", "\n")
+ return text.strip()
+
+
+def test_notion_page(mocker):
+ texts = ["Head 1", "1.1", "paragraph 1", "1.1.1"]
+ mocked_notion_page = {
+ "object": "list",
+ "results": [
+ _generate_block("b1", "heading_1", texts[0]),
+ _generate_block("b2", "heading_2", texts[1]),
+ _generate_block("b3", "paragraph", texts[2]),
+ _generate_block("b4", "heading_3", texts[3])
+ ],
+ "next_cursor": None
+ }
+ mocker.patch("requests.request", return_value=_mock_response(mocked_notion_page))
+
+ page_docs = extractor._load_data_as_documents(page_id, "page")
+ assert len(page_docs) == 1
+ content = _remove_multiple_new_lines(page_docs[0].page_content)
+ assert content == '# Head 1\n## 1.1\nparagraph 1\n### 1.1.1'
+
+
+def test_notion_database(mocker):
+ page_title_list = ["page1", "page2", "page3"]
+ mocked_notion_database = {
+ "object": "list",
+ "results": [_generate_page(i) for i in page_title_list],
+ "next_cursor": None
+ }
+ mocker.patch("requests.post", return_value=_mock_response(mocked_notion_database))
+ database_docs = extractor._load_data_as_documents(database_id, "database")
+ assert len(database_docs) == 1
+ content = _remove_multiple_new_lines(database_docs[0].page_content)
+ assert content == '\n'.join([f'Page:{i}' for i in page_title_list])
diff --git a/api/tests/unit_tests/core/test_model_manager.py b/api/tests/unit_tests/core/test_model_manager.py
new file mode 100644
index 0000000000..3024a54a4d
--- /dev/null
+++ b/api/tests/unit_tests/core/test_model_manager.py
@@ -0,0 +1,77 @@
+from unittest.mock import MagicMock
+
+import pytest
+
+from core.entities.provider_entities import ModelLoadBalancingConfiguration
+from core.model_manager import LBModelManager
+from core.model_runtime.entities.model_entities import ModelType
+
+
+@pytest.fixture
+def lb_model_manager():
+ load_balancing_configs = [
+ ModelLoadBalancingConfiguration(
+ id='id1',
+ name='__inherit__',
+ credentials={}
+ ),
+ ModelLoadBalancingConfiguration(
+ id='id2',
+ name='first',
+ credentials={"openai_api_key": "fake_key"}
+ ),
+ ModelLoadBalancingConfiguration(
+ id='id3',
+ name='second',
+ credentials={"openai_api_key": "fake_key"}
+ )
+ ]
+
+ lb_model_manager = LBModelManager(
+ tenant_id='tenant_id',
+ provider='openai',
+ model_type=ModelType.LLM,
+ model='gpt-4',
+ load_balancing_configs=load_balancing_configs,
+ managed_credentials={"openai_api_key": "fake_key"}
+ )
+
+ lb_model_manager.cooldown = MagicMock(return_value=None)
+
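+    # mark only config 'id1' as cooling down, so round-robin should skip it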
+ def is_cooldown(config: ModelLoadBalancingConfiguration):
+ if config.id == 'id1':
+ return True
+
+ return False
+
+ lb_model_manager.in_cooldown = MagicMock(side_effect=is_cooldown)
+
+ return lb_model_manager
+
+
+def test_lb_model_manager_fetch_next(mocker, lb_model_manager):
+ assert len(lb_model_manager._load_balancing_configs) == 3
+
+ config1 = lb_model_manager._load_balancing_configs[0]
+ config2 = lb_model_manager._load_balancing_configs[1]
+ config3 = lb_model_manager._load_balancing_configs[2]
+
+ assert lb_model_manager.in_cooldown(config1) is True
+ assert lb_model_manager.in_cooldown(config2) is False
+ assert lb_model_manager.in_cooldown(config3) is False
+
+ start_index = 0
+ def incr(key):
+ nonlocal start_index
+ start_index += 1
+ return start_index
+
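+    # fetch_next presumably round-robins on a Redis counter; stub the Redis
+    # calls so each fetch advances the index deterministically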
+ mocker.patch('redis.Redis.incr', side_effect=incr)
+ mocker.patch('redis.Redis.set', return_value=None)
+ mocker.patch('redis.Redis.expire', return_value=None)
+
+ config = lb_model_manager.fetch_next()
+ assert config == config2
+
+ config = lb_model_manager.fetch_next()
+ assert config == config3
diff --git a/api/tests/unit_tests/core/test_provider_manager.py b/api/tests/unit_tests/core/test_provider_manager.py
new file mode 100644
index 0000000000..072b6f100f
--- /dev/null
+++ b/api/tests/unit_tests/core/test_provider_manager.py
@@ -0,0 +1,183 @@
+from core.entities.provider_entities import ModelSettings
+from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.model_providers import model_provider_factory
+from core.provider_manager import ProviderManager
+from models.provider import LoadBalancingModelConfig, ProviderModelSetting
+
+
+def test__to_model_settings(mocker):
+ # Get all provider entities
+ provider_entities = model_provider_factory.get_providers()
+
+ provider_entity = None
+ for provider in provider_entities:
+ if provider.provider == 'openai':
+ provider_entity = provider
+
+ # Mocking the inputs
+ provider_model_settings = [ProviderModelSetting(
+ id='id',
+ tenant_id='tenant_id',
+ provider_name='openai',
+ model_name='gpt-4',
+ model_type='text-generation',
+ enabled=True,
+ load_balancing_enabled=True
+ )]
+ load_balancing_model_configs = [
+ LoadBalancingModelConfig(
+ id='id1',
+ tenant_id='tenant_id',
+ provider_name='openai',
+ model_name='gpt-4',
+ model_type='text-generation',
+ name='__inherit__',
+ encrypted_config=None,
+ enabled=True
+ ),
+ LoadBalancingModelConfig(
+ id='id2',
+ tenant_id='tenant_id',
+ provider_name='openai',
+ model_name='gpt-4',
+ model_type='text-generation',
+ name='first',
+ encrypted_config='{"openai_api_key": "fake_key"}',
+ enabled=True
+ )
+ ]
+
+ mocker.patch('core.helper.model_provider_cache.ProviderCredentialsCache.get', return_value={"openai_api_key": "fake_key"})
+
+ provider_manager = ProviderManager()
+
+ # Running the method
+ result = provider_manager._to_model_settings(
+ provider_entity,
+ provider_model_settings,
+ load_balancing_model_configs
+ )
+
+ # Asserting that the result is as expected
+ assert len(result) == 1
+ assert isinstance(result[0], ModelSettings)
+ assert result[0].model == 'gpt-4'
+ assert result[0].model_type == ModelType.LLM
+ assert result[0].enabled is True
+ assert len(result[0].load_balancing_configs) == 2
+ assert result[0].load_balancing_configs[0].name == '__inherit__'
+ assert result[0].load_balancing_configs[1].name == 'first'
+
+
+def test__to_model_settings_only_one_lb(mocker):
+ # Get all provider entities
+ provider_entities = model_provider_factory.get_providers()
+
+ provider_entity = None
+ for provider in provider_entities:
+ if provider.provider == 'openai':
+ provider_entity = provider
+
+ # Mocking the inputs
+ provider_model_settings = [ProviderModelSetting(
+ id='id',
+ tenant_id='tenant_id',
+ provider_name='openai',
+ model_name='gpt-4',
+ model_type='text-generation',
+ enabled=True,
+ load_balancing_enabled=True
+ )]
+ load_balancing_model_configs = [
+ LoadBalancingModelConfig(
+ id='id1',
+ tenant_id='tenant_id',
+ provider_name='openai',
+ model_name='gpt-4',
+ model_type='text-generation',
+ name='__inherit__',
+ encrypted_config=None,
+ enabled=True
+ )
+ ]
+
+ mocker.patch('core.helper.model_provider_cache.ProviderCredentialsCache.get', return_value={"openai_api_key": "fake_key"})
+
+ provider_manager = ProviderManager()
+
+ # Running the method
+ result = provider_manager._to_model_settings(
+ provider_entity,
+ provider_model_settings,
+ load_balancing_model_configs
+ )
+
+ # Asserting that the result is as expected
+ assert len(result) == 1
+ assert isinstance(result[0], ModelSettings)
+ assert result[0].model == 'gpt-4'
+ assert result[0].model_type == ModelType.LLM
+ assert result[0].enabled is True
+ assert len(result[0].load_balancing_configs) == 0
+
+
+def test__to_model_settings_lb_disabled(mocker):
+ # Get all provider entities
+ provider_entities = model_provider_factory.get_providers()
+
+ provider_entity = None
+ for provider in provider_entities:
+ if provider.provider == 'openai':
+ provider_entity = provider
+
+ # Mocking the inputs
+ provider_model_settings = [ProviderModelSetting(
+ id='id',
+ tenant_id='tenant_id',
+ provider_name='openai',
+ model_name='gpt-4',
+ model_type='text-generation',
+ enabled=True,
+ load_balancing_enabled=False
+ )]
+ load_balancing_model_configs = [
+ LoadBalancingModelConfig(
+ id='id1',
+ tenant_id='tenant_id',
+ provider_name='openai',
+ model_name='gpt-4',
+ model_type='text-generation',
+ name='__inherit__',
+ encrypted_config=None,
+ enabled=True
+ ),
+ LoadBalancingModelConfig(
+ id='id2',
+ tenant_id='tenant_id',
+ provider_name='openai',
+ model_name='gpt-4',
+ model_type='text-generation',
+ name='first',
+ encrypted_config='{"openai_api_key": "fake_key"}',
+ enabled=True
+ )
+ ]
+
+ mocker.patch('core.helper.model_provider_cache.ProviderCredentialsCache.get', return_value={"openai_api_key": "fake_key"})
+
+ provider_manager = ProviderManager()
+
+ # Running the method
+ result = provider_manager._to_model_settings(
+ provider_entity,
+ provider_model_settings,
+ load_balancing_model_configs
+ )
+
+ # Asserting that the result is as expected
+ assert len(result) == 1
+ assert isinstance(result[0], ModelSettings)
+ assert result[0].model == 'gpt-4'
+ assert result[0].model_type == ModelType.LLM
+ assert result[0].enabled is True
+ assert len(result[0].load_balancing_configs) == 0
diff --git a/api/tests/unit_tests/core/tools/test_tool_parameter_converter.py b/api/tests/unit_tests/core/tools/test_tool_parameter_converter.py
new file mode 100644
index 0000000000..9addeeadca
--- /dev/null
+++ b/api/tests/unit_tests/core/tools/test_tool_parameter_converter.py
@@ -0,0 +1,56 @@
+import pytest
+
+from core.tools.entities.tool_entities import ToolParameter
+from core.tools.utils.tool_parameter_converter import ToolParameterConverter
+
+
+def test_get_parameter_type():
+ assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.STRING) == 'string'
+ assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.SELECT) == 'string'
+ assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.BOOLEAN) == 'boolean'
+ assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.NUMBER) == 'number'
+ with pytest.raises(ValueError):
+ ToolParameterConverter.get_parameter_type('unsupported_type')
+
+
+def test_cast_parameter_by_type():
+ # string
+ assert ToolParameterConverter.cast_parameter_by_type('test', ToolParameter.ToolParameterType.STRING) == 'test'
+ assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.STRING) == '1'
+ assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.STRING) == '1.0'
+ assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.STRING) == ''
+
+ # secret input
+ assert ToolParameterConverter.cast_parameter_by_type('test', ToolParameter.ToolParameterType.SECRET_INPUT) == 'test'
+ assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.SECRET_INPUT) == '1'
+ assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.SECRET_INPUT) == '1.0'
+ assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.SECRET_INPUT) == ''
+
+ # select
+ assert ToolParameterConverter.cast_parameter_by_type('test', ToolParameter.ToolParameterType.SELECT) == 'test'
+ assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.SELECT) == '1'
+ assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.SELECT) == '1.0'
+ assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.SELECT) == ''
+
+ # boolean
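+    # any value that is not an explicit false marker casts to True,
+    # including arbitrary strings like 'something'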
+ true_values = [True, 'True', 'true', '1', 'YES', 'Yes', 'yes', 'y', 'something']
+ for value in true_values:
+ assert ToolParameterConverter.cast_parameter_by_type(value, ToolParameter.ToolParameterType.BOOLEAN) is True
+
+ false_values = [False, 'False', 'false', '0', 'NO', 'No', 'no', 'n', None, '']
+ for value in false_values:
+ assert ToolParameterConverter.cast_parameter_by_type(value, ToolParameter.ToolParameterType.BOOLEAN) is False
+
+ # number
+ assert ToolParameterConverter.cast_parameter_by_type('1', ToolParameter.ToolParameterType.NUMBER) == 1
+ assert ToolParameterConverter.cast_parameter_by_type('1.0', ToolParameter.ToolParameterType.NUMBER) == 1.0
+ assert ToolParameterConverter.cast_parameter_by_type('-1.0', ToolParameter.ToolParameterType.NUMBER) == -1.0
+ assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.NUMBER) == 1
+ assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.NUMBER) == 1.0
+ assert ToolParameterConverter.cast_parameter_by_type(-1.0, ToolParameter.ToolParameterType.NUMBER) == -1.0
+ assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.NUMBER) is None
+
+ # unknown
+ assert ToolParameterConverter.cast_parameter_by_type('1', 'unknown_type') == '1'
+ assert ToolParameterConverter.cast_parameter_by_type(1, 'unknown_type') == '1'
+ assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.NUMBER) is None
diff --git a/api/tests/unit_tests/core/workflow/nodes/test_answer.py b/api/tests/unit_tests/core/workflow/nodes/test_answer.py
index cf21401eb2..102711b4b6 100644
--- a/api/tests/unit_tests/core/workflow/nodes/test_answer.py
+++ b/api/tests/unit_tests/core/workflow/nodes/test_answer.py
@@ -1,5 +1,6 @@
from unittest.mock import MagicMock
+from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import SystemVariable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.answer.answer_node import AnswerNode
@@ -15,6 +16,7 @@ def test_execute_answer():
workflow_id='1',
user_id='1',
user_from=UserFrom.ACCOUNT,
+ invoke_from=InvokeFrom.DEBUGGER,
config={
'id': 'answer',
'data': {
diff --git a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py
index 99413540c5..6860b2fd97 100644
--- a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py
+++ b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py
@@ -1,5 +1,6 @@
from unittest.mock import MagicMock
+from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import SystemVariable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import UserFrom
@@ -15,6 +16,7 @@ def test_execute_if_else_result_true():
workflow_id='1',
user_id='1',
user_from=UserFrom.ACCOUNT,
+ invoke_from=InvokeFrom.DEBUGGER,
config={
'id': 'if-else',
'data': {
@@ -155,6 +157,7 @@ def test_execute_if_else_result_false():
workflow_id='1',
user_id='1',
user_from=UserFrom.ACCOUNT,
+ invoke_from=InvokeFrom.DEBUGGER,
config={
'id': 'if-else',
'data': {
diff --git a/api/tests/unit_tests/libs/test_pandas.py b/api/tests/unit_tests/libs/test_pandas.py
new file mode 100644
index 0000000000..bbc372ed61
--- /dev/null
+++ b/api/tests/unit_tests/libs/test_pandas.py
@@ -0,0 +1,62 @@
+import pandas as pd
+
+
+def test_pandas_csv(tmp_path, monkeypatch):
+ monkeypatch.chdir(tmp_path)
+ data = {'col1': [1, 2.2, -3.3, 4.0, 5],
+ 'col2': ['A', 'B', 'C', 'D', 'E']}
+ df1 = pd.DataFrame(data)
+
+ # write to csv file
+ csv_file_path = tmp_path.joinpath('example.csv')
+ df1.to_csv(csv_file_path, index=False)
+
+ # read from csv file
+ df2 = pd.read_csv(csv_file_path, on_bad_lines='skip')
+ assert df2[df2.columns[0]].to_list() == data['col1']
+ assert df2[df2.columns[1]].to_list() == data['col2']
+
+
+def test_pandas_xlsx(tmp_path, monkeypatch):
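+    # pandas delegates xlsx I/O to an engine such as openpyxl, which must be installed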
+ monkeypatch.chdir(tmp_path)
+ data = {'col1': [1, 2.2, -3.3, 4.0, 5],
+ 'col2': ['A', 'B', 'C', 'D', 'E']}
+ df1 = pd.DataFrame(data)
+
+ # write to xlsx file
+ xlsx_file_path = tmp_path.joinpath('example.xlsx')
+ df1.to_excel(xlsx_file_path, index=False)
+
+ # read from xlsx file
+ df2 = pd.read_excel(xlsx_file_path)
+ assert df2[df2.columns[0]].to_list() == data['col1']
+ assert df2[df2.columns[1]].to_list() == data['col2']
+
+
+def test_pandas_xlsx_with_sheets(tmp_path, monkeypatch):
+ monkeypatch.chdir(tmp_path)
+ data1 = {'col1': [1, 2, 3, 4, 5],
+ 'col2': ['A', 'B', 'C', 'D', 'E']}
+ df1 = pd.DataFrame(data1)
+
+ data2 = {'col1': [6, 7, 8, 9, 10],
+ 'col2': ['F', 'G', 'H', 'I', 'J']}
+ df2 = pd.DataFrame(data2)
+
+ # write to xlsx file with sheets
+ xlsx_file_path = tmp_path.joinpath('example_with_sheets.xlsx')
+ sheet1 = 'Sheet1'
+ sheet2 = 'Sheet2'
+ with pd.ExcelWriter(xlsx_file_path) as excel_writer:
+ df1.to_excel(excel_writer, sheet_name=sheet1, index=False)
+ df2.to_excel(excel_writer, sheet_name=sheet2, index=False)
+
+ # read from xlsx file with sheets
+ with pd.ExcelFile(xlsx_file_path) as excel_file:
+ df1 = pd.read_excel(excel_file, sheet_name=sheet1)
+ assert df1[df1.columns[0]].to_list() == data1['col1']
+ assert df1[df1.columns[1]].to_list() == data1['col2']
+
+ df2 = pd.read_excel(excel_file, sheet_name=sheet2)
+ assert df2[df2.columns[0]].to_list() == data2['col1']
+ assert df2[df2.columns[1]].to_list() == data2['col2']
diff --git a/api/tests/unit_tests/libs/test_yarl.py b/api/tests/unit_tests/libs/test_yarl.py
new file mode 100644
index 0000000000..75a5344126
--- /dev/null
+++ b/api/tests/unit_tests/libs/test_yarl.py
@@ -0,0 +1,23 @@
+import pytest
+from yarl import URL
+
+
+def test_yarl_urls():
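+    # the division operator joins path segments, normalizing a single slash
+    # at the boundary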
+ expected_1 = 'https://dify.ai/api'
+ assert str(URL('https://dify.ai') / 'api') == expected_1
+ assert str(URL('https://dify.ai/') / 'api') == expected_1
+
+ expected_2 = 'http://dify.ai:12345/api'
+ assert str(URL('http://dify.ai:12345') / 'api') == expected_2
+ assert str(URL('http://dify.ai:12345/') / 'api') == expected_2
+
+ expected_3 = 'https://dify.ai/api/v1'
+ assert str(URL('https://dify.ai') / 'api' / 'v1') == expected_3
+ assert str(URL('https://dify.ai') / 'api/v1') == expected_3
+ assert str(URL('https://dify.ai/') / 'api/v1') == expected_3
+ assert str(URL('https://dify.ai/api') / 'v1') == expected_3
+ assert str(URL('https://dify.ai/api/') / 'v1') == expected_3
+
+ with pytest.raises(ValueError) as e1:
+ str(URL('https://dify.ai') / '/api')
+ assert str(e1.value) == "Appending path '/api' starting from slash is forbidden"
diff --git a/api/tests/unit_tests/utils/__init__.py b/api/tests/unit_tests/utils/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/tests/unit_tests/utils/position_helper/test_position_helper.py b/api/tests/unit_tests/utils/position_helper/test_position_helper.py
new file mode 100644
index 0000000000..c389461454
--- /dev/null
+++ b/api/tests/unit_tests/utils/position_helper/test_position_helper.py
@@ -0,0 +1,34 @@
+from textwrap import dedent
+
+import pytest
+
+from core.helper.position_helper import get_position_map
+
+
+@pytest.fixture
+def prepare_example_positions_yaml(tmp_path, monkeypatch) -> str:
+ monkeypatch.chdir(tmp_path)
+ tmp_path.joinpath("example_positions.yaml").write_text(dedent(
+ """\
+ - first
+ - second
+ # - commented
+ - third
+
+ - 9999999999999
+ - forth
+ """))
+ return str(tmp_path)
+
+
+def test_position_helper(prepare_example_positions_yaml):
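+    # only string entries should be indexed: the comment and the bare number
+    # are skipped, so 'forth' lands at position 3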
+ position_map = get_position_map(
+ folder_path=prepare_example_positions_yaml,
+ file_name='example_positions.yaml')
+ assert len(position_map) == 4
+ assert position_map == {
+ 'first': 0,
+ 'second': 1,
+ 'third': 2,
+ 'forth': 3,
+ }
diff --git a/api/tests/unit_tests/utils/yaml/__init__.py b/api/tests/unit_tests/utils/yaml/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/tests/unit_tests/utils/yaml/test_yaml_utils.py b/api/tests/unit_tests/utils/yaml/test_yaml_utils.py
new file mode 100644
index 0000000000..446588cde1
--- /dev/null
+++ b/api/tests/unit_tests/utils/yaml/test_yaml_utils.py
@@ -0,0 +1,74 @@
+from textwrap import dedent
+
+import pytest
+from yaml import YAMLError
+
+from core.tools.utils.yaml_utils import load_yaml_file
+
+EXAMPLE_YAML_FILE = 'example_yaml.yaml'
+INVALID_YAML_FILE = 'invalid_yaml.yaml'
+NON_EXISTING_YAML_FILE = 'non_existing_file.yaml'
+
+
+@pytest.fixture
+def prepare_example_yaml_file(tmp_path, monkeypatch) -> str:
+ monkeypatch.chdir(tmp_path)
+ file_path = tmp_path.joinpath(EXAMPLE_YAML_FILE)
+ file_path.write_text(dedent(
+ """\
+ address:
+ city: Example City
+ country: Example Country
+ age: 30
+ gender: male
+ languages:
+ - Python
+ - Java
+ - C++
+ empty_key:
+ """))
+ return str(file_path)
+
+
+@pytest.fixture
+def prepare_invalid_yaml_file(tmp_path, monkeypatch) -> str:
+ monkeypatch.chdir(tmp_path)
+ file_path = tmp_path.joinpath(INVALID_YAML_FILE)
+ file_path.write_text(dedent(
+ """\
+ address:
+ city: Example City
+ country: Example Country
+ age: 30
+ gender: male
+ languages:
+ - Python
+ - Java
+ - C++
+ """))
+ return str(file_path)
+
+
+def test_load_yaml_non_existing_file():
+ assert load_yaml_file(file_path=NON_EXISTING_YAML_FILE) == {}
+ assert load_yaml_file(file_path='') == {}
+
+
+def test_load_valid_yaml_file(prepare_example_yaml_file):
+ yaml_data = load_yaml_file(file_path=prepare_example_yaml_file)
+ assert len(yaml_data) > 0
+ assert yaml_data['age'] == 30
+ assert yaml_data['gender'] == 'male'
+ assert yaml_data['address']['city'] == 'Example City'
+ assert set(yaml_data['languages']) == {'Python', 'Java', 'C++'}
+ assert yaml_data.get('empty_key') is None
+ assert yaml_data.get('non_existed_key') is None
+
+
+def test_load_invalid_yaml_file(prepare_invalid_yaml_file):
+ # yaml syntax error
+ with pytest.raises(YAMLError):
+ load_yaml_file(file_path=prepare_invalid_yaml_file)
+
+ # ignore error
+ assert load_yaml_file(file_path=prepare_invalid_yaml_file, ignore_error=True) == {}
diff --git a/dev/reformat b/dev/reformat
index ebee1efb40..844bfbef58 100755
--- a/dev/reformat
+++ b/dev/reformat
@@ -9,7 +9,7 @@ if ! command -v ruff &> /dev/null; then
fi
# run ruff linter
-ruff check --fix ./api
+ruff check --fix --preview ./api
# env files linting relies on `dotenv-linter` in path
if ! command -v dotenv-linter &> /dev/null; then
diff --git a/docker/docker-compose.middleware.yaml b/docker/docker-compose.middleware.yaml
index 6bf45da9e0..38760901b1 100644
--- a/docker/docker-compose.middleware.yaml
+++ b/docker/docker-compose.middleware.yaml
@@ -53,20 +53,39 @@ services:
# The DifySandbox
sandbox:
- image: langgenius/dify-sandbox:0.1.0
+ image: langgenius/dify-sandbox:0.2.1
restart: always
- cap_add:
- # Why is sys_admin permission needed?
- # https://docs.dify.ai/getting-started/install-self-hosted/install-faq#id-16.-why-is-sys_admin-permission-needed
- - SYS_ADMIN
environment:
# The DifySandbox configurations
+      # Make sure to change this key to a strong one for your deployment.
+ # You can generate a strong key using `openssl rand -base64 42`.
API_KEY: dify-sandbox
GIN_MODE: 'release'
WORKER_TIMEOUT: 15
- ports:
- - "8194:8194"
+ ENABLE_NETWORK: 'true'
+ HTTP_PROXY: 'http://ssrf_proxy:3128'
+ HTTPS_PROXY: 'http://ssrf_proxy:3128'
+ SANDBOX_PORT: 8194
+ volumes:
+ - ./volumes/sandbox/dependencies:/dependencies
+ networks:
+ - ssrf_proxy_network
+ # ssrf_proxy server
+ # for more information, please refer to
+ # https://docs.dify.ai/getting-started/install-self-hosted/install-faq#id-16.-why-is-ssrf_proxy-needed
+ ssrf_proxy:
+ image: ubuntu/squid:latest
+ restart: always
+ ports:
+ - "3128:3128"
+ - "8194:8194"
+ volumes:
+      # please modify the squid.conf file to fit your network environment.
+ - ./volumes/ssrf_proxy/squid.conf:/etc/squid/squid.conf
+ networks:
+ - ssrf_proxy_network
+ - default
# Qdrant vector store.
# uncomment to use qdrant as vector store.
# (if uncommented, you need to comment out the weaviate service above,
@@ -81,3 +100,10 @@ services:
# ports:
# - "6333:6333"
# - "6334:6334"
+
+
+networks:
+  # create a network between sandbox, api and ssrf_proxy, with no access to the outside.
+ ssrf_proxy_network:
+ driver: bridge
+ internal: true
diff --git a/docker/docker-compose.pgvector.yaml b/docker/docker-compose.pgvector.yaml
new file mode 100644
index 0000000000..b584880abf
--- /dev/null
+++ b/docker/docker-compose.pgvector.yaml
@@ -0,0 +1,24 @@
+version: '3'
+services:
+  # The pgvector vector store.
+ pgvector:
+ image: pgvector/pgvector:pg16
+ restart: always
+ environment:
+ PGUSER: postgres
+ # The password for the default postgres user.
+ POSTGRES_PASSWORD: difyai123456
+ # The name of the default postgres database.
+ POSTGRES_DB: dify
+ # postgres data directory
+ PGDATA: /var/lib/postgresql/data/pgdata
+ volumes:
+ - ./volumes/pgvector/data:/var/lib/postgresql/data
+    # expose db (postgresql) port to host
+ ports:
+ - "5433:5432"
+ healthcheck:
+ test: [ "CMD", "pg_isready" ]
+ interval: 1s
+ timeout: 3s
+ retries: 30
diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml
index 706f7d70e1..c3b5430514 100644
--- a/docker/docker-compose.yaml
+++ b/docker/docker-compose.yaml
@@ -2,13 +2,15 @@ version: '3'
services:
# API service
api:
- image: langgenius/dify-api:0.6.6
+ image: langgenius/dify-api:0.6.10
restart: always
environment:
# Startup mode, 'api' starts the API server.
MODE: api
# The log level for the application. Supported values are `DEBUG`, `INFO`, `WARNING`, `ERROR`, `CRITICAL`
LOG_LEVEL: INFO
+ # enable DEBUG mode to output more logs
+      # DEBUG: 'true'
# A secret key that is used for securely signing the session cookie and encrypting sensitive information on the database. You can generate a strong key using `openssl rand -base64 42`.
SECRET_KEY: sk-9f73s3ljTXVcMT3Blb3ljTqtsKiGHXVcMT3BlbkFJLK7U
# The base URL of console application web frontend, refers to the Console base URL of WEB service if console domain is
@@ -34,6 +36,9 @@ services:
# used to display File preview or download Url to the front-end or as Multi-model inputs;
# Url is signed and has expiration time.
FILES_URL: ''
+ # File Access Time specifies a time interval in seconds for the file to be accessed.
+ # The default value is 300 seconds.
+ FILES_ACCESS_TIMEOUT: 300
# When enabled, migrations will be executed prior to application startup and the application will start after the migrations have completed.
MIGRATION_ENABLED: 'true'
# The configurations of postgres database connection.
@@ -88,6 +93,7 @@ services:
AZURE_BLOB_ACCOUNT_URL: 'https://<your_account_name>.blob.core.windows.net'
# The Google storage configurations, only available when STORAGE_TYPE is `google-storage`.
GOOGLE_STORAGE_BUCKET_NAME: 'your-bucket-name'
+ # if you want to use Application Default Credentials, you can leave GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64 empty.
GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64: 'your-google-service-account-json-base64-string'
# The type of vector store to use. Supported values are `weaviate`, `qdrant`, `milvus`, `relyt`, `pgvector`.
VECTOR_STORE: weaviate
@@ -122,15 +128,28 @@ services:
RELYT_USER: postgres
RELYT_PASSWORD: difyai123456
RELYT_DATABASE: postgres
+ # pgvector configurations
+ PGVECTOR_HOST: pgvector
+ PGVECTOR_PORT: 5432
+ PGVECTOR_USER: postgres
+ PGVECTOR_PASSWORD: difyai123456
+ PGVECTOR_DATABASE: dify
+ # tidb vector configurations
+ TIDB_VECTOR_HOST: tidb
+ TIDB_VECTOR_PORT: 4000
+ TIDB_VECTOR_USER: xxx.root
+ TIDB_VECTOR_PASSWORD: xxxxxx
+ TIDB_VECTOR_DATABASE: dify
# Mail configuration, support: resend, smtp
MAIL_TYPE: ''
# default send from email address, if not specified
MAIL_DEFAULT_SEND_FROM: 'YOUR EMAIL FROM (eg: no-reply <no-reply@dify.ai>)'
SMTP_SERVER: ''
- SMTP_PORT: 587
+ SMTP_PORT: 465
SMTP_USERNAME: ''
SMTP_PASSWORD: ''
SMTP_USE_TLS: 'true'
+ SMTP_OPPORTUNISTIC_TLS: 'false'
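+ # Note: port 465 with SMTP_USE_TLS 'true' means implicit TLS; for STARTTLS
+ # (typically port 587) also set SMTP_OPPORTUNISTIC_TLS to 'true'.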
# the api-key for resend (https://resend.com)
RESEND_API_KEY: ''
RESEND_API_URL: https://api.resend.com
@@ -155,6 +174,11 @@ services:
CODE_MAX_STRING_ARRAY_LENGTH: 30
CODE_MAX_OBJECT_ARRAY_LENGTH: 30
CODE_MAX_NUMBER_ARRAY_LENGTH: 1000
+ # SSRF Proxy server
+ SSRF_PROXY_HTTP_URL: 'http://ssrf_proxy:3128'
+ SSRF_PROXY_HTTPS_URL: 'http://ssrf_proxy:3128'
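+ # A rough reachability check from inside the api container (a sketch, not part of the stack):
+ # docker compose exec api curl -x http://ssrf_proxy:3128 -s -o /dev/null -w '%{http_code}\n' http://example.com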
+ # Indexing configuration
+ INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH: 1000
depends_on:
- db
- redis
@@ -164,13 +188,17 @@ services:
# uncomment to expose dify-api port to host
# ports:
# - "5001:5001"
+ networks:
+ - ssrf_proxy_network
+ - default
# worker service
# The Celery worker for processing the queue.
worker:
- image: langgenius/dify-api:0.6.6
+ image: langgenius/dify-api:0.6.10
restart: always
environment:
+ CONSOLE_WEB_URL: ''
# Startup mode, 'worker' starts the Celery worker for processing the queue.
MODE: worker
@@ -197,7 +225,7 @@ services:
REDIS_USE_SSL: 'false'
# The configurations of celery broker.
CELERY_BROKER_URL: redis://:difyai123456@redis:6379/1
- # The type of storage to use for storing user files. Supported values are `local` and `s3` and `azure-blob`, Default: `local`
+ # The type of storage to use for storing user files. Supported values are `local`, `s3`, `azure-blob` and `google-storage`. Default: `local`
STORAGE_TYPE: local
STORAGE_LOCAL_PATH: storage
# The S3 storage configurations, only available when STORAGE_TYPE is `s3`.
@@ -211,7 +239,11 @@ services:
AZURE_BLOB_ACCOUNT_KEY: 'difyai'
AZURE_BLOB_CONTAINER_NAME: 'difyai-container'
AZURE_BLOB_ACCOUNT_URL: 'https://.blob.core.windows.net'
- # The type of vector store to use. Supported values are `weaviate`, `qdrant`, `milvus`, `relyt`.
+ # The Google storage configurations, only available when STORAGE_TYPE is `google-storage`.
+ GOOGLE_STORAGE_BUCKET_NAME: 'your-bucket-name'
+ # If you want to use Application Default Credentials, leave GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64 empty.
+ GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64: 'your-google-service-account-json-base64-string'
+ # The type of vector store to use. Supported values are `weaviate`, `qdrant`, `milvus`, `relyt`, `pgvector`.
VECTOR_STORE: weaviate
# The Weaviate endpoint URL. Only available when VECTOR_STORE is `weaviate`.
WEAVIATE_ENDPOINT: http://weaviate:8080
@@ -221,7 +253,7 @@ services:
QDRANT_URL: http://qdrant:6333
# The Qdrant API key.
QDRANT_API_KEY: difyai123456
- # The Qdrant clinet timeout setting.
+ # The Qdrant client timeout setting.
QDRANT_CLIENT_TIMEOUT: 20
# The Qdrant client enable gRPC mode.
QDRANT_GRPC_ENABLED: 'false'
@@ -242,6 +274,12 @@ services:
MAIL_TYPE: ''
# default send from email address, if not specified
MAIL_DEFAULT_SEND_FROM: 'YOUR EMAIL FROM (eg: no-reply <no-reply@dify.ai>)'
+ SMTP_SERVER: ''
+ SMTP_PORT: 465
+ SMTP_USERNAME: ''
+ SMTP_PASSWORD: ''
+ SMTP_USE_TLS: 'true'
+ SMTP_OPPORTUNISTIC_TLS: 'false'
# the api-key for resend (https://resend.com)
RESEND_API_KEY: ''
RESEND_API_URL: https://api.resend.com
@@ -251,21 +289,38 @@ services:
RELYT_USER: postgres
RELYT_PASSWORD: difyai123456
RELYT_DATABASE: postgres
+ # pgvector configurations
+ PGVECTOR_HOST: pgvector
+ PGVECTOR_PORT: 5432
+ PGVECTOR_USER: postgres
+ PGVECTOR_PASSWORD: difyai123456
+ PGVECTOR_DATABASE: dify
+ # tidb vector configurations
+ TIDB_VECTOR_HOST: tidb
+ TIDB_VECTOR_PORT: 4000
+ TIDB_VECTOR_USER: xxx.root
+ TIDB_VECTOR_PASSWORD: xxxxxx
+ TIDB_VECTOR_DATABASE: dify
# Notion import configuration, support public and internal
NOTION_INTEGRATION_TYPE: public
NOTION_CLIENT_SECRET: your-client-secret
NOTION_CLIENT_ID: your-client-id
NOTION_INTERNAL_SECRET: your-internal-secret
+ # Indexing configuration
+ INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH: 1000
depends_on:
- db
- redis
volumes:
# Mount the storage directory to the container, for storing user files.
- ./volumes/app/storage:/app/api/storage
+ networks:
+ - ssrf_proxy_network
+ - default
# Frontend web application.
web:
- image: langgenius/dify-web:0.6.6
+ image: langgenius/dify-web:0.6.10
restart: always
environment:
# The base URL of console application api server, refers to the Console base URL of WEB service if console domain is
@@ -346,18 +401,36 @@ services:
# The DifySandbox
sandbox:
- image: langgenius/dify-sandbox:0.1.0
+ image: langgenius/dify-sandbox:0.2.1
restart: always
- cap_add:
- # Why is sys_admin permission needed?
- # https://docs.dify.ai/getting-started/install-self-hosted/install-faq#id-16.-why-is-sys_admin-permission-needed
- - SYS_ADMIN
environment:
# The DifySandbox configurations
+ # Make sure you are changing this key for your deployment with a strong key.
+ # You can generate a strong key using `openssl rand -base64 42`.
API_KEY: dify-sandbox
- GIN_MODE: release
+ GIN_MODE: 'release'
WORKER_TIMEOUT: 15
+ ENABLE_NETWORK: 'true'
+ HTTP_PROXY: 'http://ssrf_proxy:3128'
+ HTTPS_PROXY: 'http://ssrf_proxy:3128'
+ SANDBOX_PORT: 8194
+ volumes:
+ - ./volumes/sandbox/dependencies:/dependencies
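+ # packages listed in python-requirements.txt under this mount are expected to be
+ # installed into the sandbox at startup (standard pip requirements format)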
+ networks:
+ - ssrf_proxy_network
+ # ssrf_proxy server
+ # for more information, please refer to
+ # https://docs.dify.ai/getting-started/install-self-hosted/install-faq#id-16.-why-is-ssrf_proxy-needed
+ ssrf_proxy:
+ image: ubuntu/squid:latest
+ restart: always
+ volumes:
+ # please modify the squid.conf file to fit your network environment.
+ - ./volumes/ssrf_proxy/squid.conf:/etc/squid/squid.conf
+ networks:
+ - ssrf_proxy_network
+ - default
# Qdrant vector store.
# uncomment to use qdrant as vector store.
# (if uncommented, you need to comment out the weaviate service above,
@@ -374,6 +447,31 @@ services:
# # - "6333:6333"
# # - "6334:6334"
+ # The pgvector vector database.
+ # Uncomment to use pgvector as vector store.
+ # pgvector:
+ # image: pgvector/pgvector:pg16
+ # restart: always
+ # environment:
+ # PGUSER: postgres
+ # # The password for the default postgres user.
+ # POSTGRES_PASSWORD: difyai123456
+ # # The name of the default postgres database.
+ # POSTGRES_DB: dify
+ # # postgres data directory
+ # PGDATA: /var/lib/postgresql/data/pgdata
+ # volumes:
+ # - ./volumes/pgvector/data:/var/lib/postgresql/data
+ # # uncomment to expose db(postgresql) port to host
+ # # ports:
+ # # - "5433:5432"
+ # healthcheck:
+ # test: [ "CMD", "pg_isready" ]
+ # interval: 1s
+ # timeout: 3s
+ # retries: 30
+
+
# The nginx reverse proxy.
# used for reverse proxying the API service and Web service.
nginx:
@@ -390,3 +488,8 @@ services:
ports:
- "80:80"
#- "443:443"
+networks:
+ # create a network shared by sandbox, api and ssrf_proxy, which cannot access the outside network.
+ ssrf_proxy_network:
+ driver: bridge
+ internal: true
diff --git a/docker/volumes/sandbox/dependencies/python-requirements.txt b/docker/volumes/sandbox/dependencies/python-requirements.txt
new file mode 100644
index 0000000000..e69de29bb2
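The file ships empty; a sketch of contents it could hold (the package pins are illustrative only):

    # docker/volumes/sandbox/dependencies/python-requirements.txt
    numpy==1.26.4
    requests==2.31.0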
diff --git a/docker/volumes/ssrf_proxy/squid.conf b/docker/volumes/ssrf_proxy/squid.conf
new file mode 100644
index 0000000000..3028bf35c6
--- /dev/null
+++ b/docker/volumes/ssrf_proxy/squid.conf
@@ -0,0 +1,50 @@
+acl localnet src 0.0.0.1-0.255.255.255 # RFC 1122 "this" network (LAN)
+acl localnet src 10.0.0.0/8 # RFC 1918 local private network (LAN)
+acl localnet src 100.64.0.0/10 # RFC 6598 shared address space (CGN)
+acl localnet src 169.254.0.0/16 # RFC 3927 link-local (directly plugged) machines
+acl localnet src 172.16.0.0/12 # RFC 1918 local private network (LAN)
+acl localnet src 192.168.0.0/16 # RFC 1918 local private network (LAN)
+acl localnet src fc00::/7 # RFC 4193 local private network range
+acl localnet src fe80::/10 # RFC 4291 link-local (directly plugged) machines
+acl SSL_ports port 443
+acl Safe_ports port 80 # http
+acl Safe_ports port 21 # ftp
+acl Safe_ports port 443 # https
+acl Safe_ports port 70 # gopher
+acl Safe_ports port 210 # wais
+acl Safe_ports port 1025-65535 # unregistered ports
+acl Safe_ports port 280 # http-mgmt
+acl Safe_ports port 488 # gss-http
+acl Safe_ports port 591 # filemaker
+acl Safe_ports port 777 # multiling http
+acl CONNECT method CONNECT
+http_access deny !Safe_ports
+http_access deny CONNECT !SSL_ports
+http_access allow localhost manager
+http_access deny manager
+http_access allow localhost
+http_access allow localnet
+http_access deny all
+
+################################## Proxy Server ################################
+http_port 3128
+coredump_dir /var/spool/squid
+refresh_pattern ^ftp: 1440 20% 10080
+refresh_pattern ^gopher: 1440 0% 1440
+refresh_pattern -i (/cgi-bin/|\?) 0 0% 0
+refresh_pattern \/(Packages|Sources)(|\.bz2|\.gz|\.xz)$ 0 0% 0 refresh-ims
+refresh_pattern \/Release(|\.gpg)$ 0 0% 0 refresh-ims
+refresh_pattern \/InRelease$ 0 0% 0 refresh-ims
+refresh_pattern \/(Translation-.*)(|\.bz2|\.gz|\.xz)$ 0 0% 0 refresh-ims
+refresh_pattern . 0 20% 4320
+logfile_rotate 0
+
+# Upstream proxy: set this to your own upstream proxy IP to avoid SSRF attacks
+# cache_peer 172.1.1.1 parent 3128 0 no-query no-digest no-netdb-exchange default
+
+
+################################## Reverse Proxy To Sandbox ################################
+http_port 8194 accel vhost
+cache_peer sandbox parent 8194 0 no-query originserver
+acl all src all
+http_access allow all
\ No newline at end of file
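With the accel section above, the sandbox is reached only through squid on port 8194; a rough check from the api container (a sketch — any HTTP status code proves reachability):

    docker compose exec api curl -s -o /dev/null -w '%{http_code}\n' http://ssrf_proxy:8194/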
diff --git a/sdks/nodejs-client/index.d.ts b/sdks/nodejs-client/index.d.ts
index cf1d825221..7fdd943f63 100644
--- a/sdks/nodejs-client/index.d.ts
+++ b/sdks/nodejs-client/index.d.ts
@@ -14,15 +14,6 @@ interface HeaderParams {
interface User {
}
-interface ChatMessageConfig {
- inputs: any;
- query: string;
- user: User;
- stream?: boolean;
- conversation_id?: string | null;
- files?: File[] | null;
-}
-
export declare class DifyClient {
constructor(apiKey: string, baseUrl?: string);
@@ -54,7 +45,14 @@ export declare class CompletionClient extends DifyClient {
}
export declare class ChatClient extends DifyClient {
- createChatMessage(config: ChatMessageConfig): Promise;
+ createChatMessage(
+ inputs: any,
+ query: string,
+ user: User,
+ stream?: boolean,
+ conversation_id?: string | null,
+ files?: File[] | null
+ ): Promise;
getConversationMessages(
user: User,
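A call-site sketch for the new positional signature (the API key and user id are placeholders; in practice `user` is a plain string identifier):

    import { ChatClient } from 'dify-client'

    const client = new ChatClient('your-api-key')
    const response = await client.createChatMessage({}, 'Hello, Dify', 'user-123', false, null, null)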
diff --git a/sdks/nodejs-client/index.js b/sdks/nodejs-client/index.js
index 127d62cf87..93491fa16b 100644
--- a/sdks/nodejs-client/index.js
+++ b/sdks/nodejs-client/index.js
@@ -85,7 +85,7 @@ export class DifyClient {
response = await axios({
method,
url,
- data,
+ ...(method !== "GET" && { data }),
params,
headers,
responseType: "json",
diff --git a/sdks/nodejs-client/index.test.js b/sdks/nodejs-client/index.test.js
index e08b8e82af..f300b16fc9 100644
--- a/sdks/nodejs-client/index.test.js
+++ b/sdks/nodejs-client/index.test.js
@@ -42,7 +42,6 @@ describe('Send Requests', () => {
expect(axios).toHaveBeenCalledWith({
method,
url: `${BASE_URL}${endpoint}`,
- data: null,
params: null,
headers: {
Authorization: `Bearer ${difyClient.apiKey}`,
diff --git a/sdks/nodejs-client/package.json b/sdks/nodejs-client/package.json
index 83b2f8a4c0..cc27c5e0c0 100644
--- a/sdks/nodejs-client/package.json
+++ b/sdks/nodejs-client/package.json
@@ -1,6 +1,6 @@
{
"name": "dify-client",
- "version": "2.3.1",
+ "version": "2.3.2",
"description": "This is the Node.js SDK for the Dify.AI API, which allows you to easily integrate Dify.AI into your Node.js applications.",
"main": "index.js",
"type": "module",
diff --git a/web/.husky/pre-commit b/web/.husky/pre-commit
index 4bc7fb77ab..8d1ad1d09f 100755
--- a/web/.husky/pre-commit
+++ b/web/.husky/pre-commit
@@ -31,7 +31,7 @@ if $api_modified; then
pip install ruff
fi
- ruff check ./api || status=$?
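+ # --preview additionally enables ruff's preview (not yet stabilized) rules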
+ ruff check --preview ./api || status=$?
status=${status:-0}
diff --git a/web/.vscode/settings.example.json b/web/.vscode/settings.example.json
index 34b49e2708..6162d021d0 100644
--- a/web/.vscode/settings.example.json
+++ b/web/.vscode/settings.example.json
@@ -2,7 +2,7 @@
"prettier.enable": false,
"editor.formatOnSave": true,
"editor.codeActionsOnSave": {
- "source.fixAll.eslint": true
+ "source.fixAll.eslint": "explicit"
},
"eslint.format.enable": true,
"[python]": {
diff --git a/web/app/(commonLayout)/datasets/NewDatasetCard.tsx b/web/app/(commonLayout)/datasets/NewDatasetCard.tsx
index 5ce79adadc..f3e34ff7e2 100644
--- a/web/app/(commonLayout)/datasets/NewDatasetCard.tsx
+++ b/web/app/(commonLayout)/datasets/NewDatasetCard.tsx
@@ -9,7 +9,7 @@ const CreateAppCard = forwardRef((_, ref) => {
return (
-