diff --git a/.devcontainer/README.md b/.devcontainer/README.md index fa989584f5..df12a3c2d6 100644 --- a/.devcontainer/README.md +++ b/.devcontainer/README.md @@ -1,4 +1,4 @@ -# Devlopment with devcontainer +# Development with devcontainer This project includes a devcontainer configuration that allows you to open the project in a container with a fully configured development environment. Both frontend and backend environments are initialized when the container is started. ## GitHub Codespaces @@ -33,5 +33,5 @@ Performance Impact: While usually minimal, programs running inside a devcontaine if you see such error message when you open this project in codespaces: ![Alt text](troubleshooting.png) -a simple workaround is change `/signin` endpoint into another one, then login with github account and close the tab, then change it back to `/signin` endpoint. Then all things will be fine. +a simple workaround is change `/signin` endpoint into another one, then login with GitHub account and close the tab, then change it back to `/signin` endpoint. Then all things will be fine. The reason is `signin` endpoint is not allowed in codespaces, details can be found [here](https://github.com/orgs/community/discussions/5204) \ No newline at end of file diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml index fea45de1d3..b596bdb6b0 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.yml +++ b/.github/ISSUE_TEMPLATE/bug_report.yml @@ -8,13 +8,13 @@ body: label: Self Checks description: "To make sure we get to you in time, please check the following :)" options: - - label: This is only for bug report, if you would like to ask a quesion, please head to [Discussions](https://github.com/langgenius/dify/discussions/categories/general). + - label: This is only for bug report, if you would like to ask a question, please head to [Discussions](https://github.com/langgenius/dify/discussions/categories/general). required: true - label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones. required: true - label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)). required: true - - label: "Pleas do not modify this template :) and fill in all the required fields." + - label: "Please do not modify this template :) and fill in all the required fields." required: true - type: input diff --git a/.github/ISSUE_TEMPLATE/document_issue.yml b/.github/ISSUE_TEMPLATE/document_issue.yml index 44115b2097..c5aeb7fd73 100644 --- a/.github/ISSUE_TEMPLATE/document_issue.yml +++ b/.github/ISSUE_TEMPLATE/document_issue.yml @@ -1,7 +1,7 @@ name: "📚 Documentation Issue" description: Report issues in our documentation labels: - - ducumentation + - documentation body: - type: checkboxes attributes: @@ -12,7 +12,7 @@ body: required: true - label: I confirm that I am using English to submit report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)). required: true - - label: "Pleas do not modify this template :) and fill in all the required fields." + - label: "Please do not modify this template :) and fill in all the required fields." 
required: true - type: textarea attributes: diff --git a/.github/ISSUE_TEMPLATE/feature_request.yml b/.github/ISSUE_TEMPLATE/feature_request.yml index 694bd3975d..8730f5c11f 100644 --- a/.github/ISSUE_TEMPLATE/feature_request.yml +++ b/.github/ISSUE_TEMPLATE/feature_request.yml @@ -12,7 +12,7 @@ body: required: true - label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)). required: true - - label: "Pleas do not modify this template :) and fill in all the required fields." + - label: "Please do not modify this template :) and fill in all the required fields." required: true - type: textarea attributes: diff --git a/.github/ISSUE_TEMPLATE/translation_issue.yml b/.github/ISSUE_TEMPLATE/translation_issue.yml index 589e071e14..898e2cdf58 100644 --- a/.github/ISSUE_TEMPLATE/translation_issue.yml +++ b/.github/ISSUE_TEMPLATE/translation_issue.yml @@ -12,7 +12,7 @@ body: required: true - label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)). required: true - - label: "Pleas do not modify this template :) and fill in all the required fields." + - label: "Please do not modify this template :) and fill in all the required fields." required: true - type: input attributes: diff --git a/.github/workflows/api-tests.yml b/.github/workflows/api-tests.yml index a0407de843..82cfdcea06 100644 --- a/.github/workflows/api-tests.yml +++ b/.github/workflows/api-tests.yml @@ -4,9 +4,17 @@ on: pull_request: branches: - main + paths: + - api/** + - docker/** + +concurrency: + group: api-tests-${{ github.head_ref || github.run_id }} + cancel-in-progress: true jobs: test: + name: API Tests runs-on: ubuntu-latest strategy: matrix: @@ -46,11 +54,12 @@ jobs: docker/docker-compose.middleware.yaml services: | sandbox + ssrf_proxy - name: Run Workflow run: dev/pytest/pytest_workflow.sh - - name: Set up Vector Stores (Weaviate, Qdrant, Milvus, PgVecto-RS) + - name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS) uses: hoverkraft-tech/compose-action@v2.0.0 with: compose-file: | @@ -58,6 +67,7 @@ jobs: docker/docker-compose.qdrant.yaml docker/docker-compose.milvus.yaml docker/docker-compose.pgvecto-rs.yaml + docker/docker-compose.pgvector.yaml services: | weaviate qdrant @@ -65,6 +75,7 @@ jobs: minio milvus-standalone pgvecto-rs + pgvector - name: Test Vector Stores run: dev/pytest/pytest_vdb.sh diff --git a/.github/workflows/db-migration-test.yml b/.github/workflows/db-migration-test.yml new file mode 100644 index 0000000000..cb8dd06c5e --- /dev/null +++ b/.github/workflows/db-migration-test.yml @@ -0,0 +1,53 @@ +name: DB Migration Test + +on: + pull_request: + branches: + - main + paths: + - api/migrations/** + +concurrency: + group: db-migration-test-${{ github.ref }} + cancel-in-progress: true + +jobs: + db-migration-test: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: + - "3.10" + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + cache: 'pip' + cache-dependency-path: | + ./api/requirements.txt + + - name: Install dependencies + run: pip install -r ./api/requirements.txt + + - name: Set up Middleware + uses: hoverkraft-tech/compose-action@v2.0.0 + with: + compose-file: | + docker/docker-compose.middleware.yaml + services: | + db + + - name: Prepare configs + run: | 
+ cd api + cp .env.example .env + + - name: Run DB Migration + run: | + cd api + flask db upgrade diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml index bdbc22b489..7dad707fea 100644 --- a/.github/workflows/style.yml +++ b/.github/workflows/style.yml @@ -6,7 +6,7 @@ on: - main concurrency: - group: dep-${{ github.head_ref || github.run_id }} + group: style-${{ github.head_ref || github.run_id }} cancel-in-progress: true jobs: @@ -18,54 +18,89 @@ jobs: - name: Checkout code uses: actions/checkout@v4 + - name: Check changed files + id: changed-files + uses: tj-actions/changed-files@v44 + with: + files: api/** + - name: Set up Python uses: actions/setup-python@v5 + if: steps.changed-files.outputs.any_changed == 'true' with: python-version: '3.10' - name: Python dependencies + if: steps.changed-files.outputs.any_changed == 'true' run: pip install ruff dotenv-linter - name: Ruff check - run: ruff check ./api + if: steps.changed-files.outputs.any_changed == 'true' + run: ruff check --preview ./api - name: Dotenv check + if: steps.changed-files.outputs.any_changed == 'true' run: dotenv-linter ./api/.env.example ./web/.env.example - name: Lint hints if: failure() run: echo "Please run 'dev/reformat' to fix the fixable linting errors." - test: - name: ESLint and SuperLinter + web-style: + name: Web Style runs-on: ubuntu-latest - needs: python-style + defaults: + run: + working-directory: ./web steps: - name: Checkout code uses: actions/checkout@v4 + + - name: Check changed files + id: changed-files + uses: tj-actions/changed-files@v44 with: - fetch-depth: 0 + files: web/** - name: Setup NodeJS uses: actions/setup-node@v4 + if: steps.changed-files.outputs.any_changed == 'true' with: node-version: 20 cache: yarn cache-dependency-path: ./web/package.json - name: Web dependencies - run: | - cd ./web - yarn install --frozen-lockfile + if: steps.changed-files.outputs.any_changed == 'true' + run: yarn install --frozen-lockfile - name: Web style check - run: | - cd ./web - yarn run lint + if: steps.changed-files.outputs.any_changed == 'true' + run: yarn run lint + + + superlinter: + name: SuperLinter + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Check changed files + id: changed-files + uses: tj-actions/changed-files@v44 + with: + files: | + **.sh + **.yaml + **.yml + Dockerfile - name: Super-linter uses: super-linter/super-linter/slim@v6 + if: steps.changed-files.outputs.any_changed == 'true' env: BASH_SEVERITY: warning DEFAULT_BRANCH: main @@ -76,4 +111,5 @@ jobs: VALIDATE_BASH_EXEC: true VALIDATE_GITHUB_ACTIONS: true VALIDATE_DOCKERFILE_HADOLINT: true + VALIDATE_XML: true VALIDATE_YAML: true diff --git a/.github/workflows/tool-test-sdks.yaml b/.github/workflows/tool-test-sdks.yaml index 575ead4b3b..fb4bcb9d66 100644 --- a/.github/workflows/tool-test-sdks.yaml +++ b/.github/workflows/tool-test-sdks.yaml @@ -4,6 +4,13 @@ on: pull_request: branches: - main + paths: + - sdks/** + +concurrency: + group: sdk-tests-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + jobs: build: name: unit test for Node.js SDK diff --git a/.gitignore b/.gitignore index c957d63174..a51465efbc 100644 --- a/.gitignore +++ b/.gitignore @@ -134,7 +134,8 @@ dmypy.json web/.vscode/settings.json # Intellij IDEA Files -.idea/ +.idea/* +!.idea/vcs.xml .ideaDataSources/ api/.env diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000000..ae8b1755c5 --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,16 @@ + + + + + + + + + \ No 
newline at end of file diff --git a/CONTRIBUTING_JA.md b/CONTRIBUTING_JA.md new file mode 100644 index 0000000000..c9329d6102 --- /dev/null +++ b/CONTRIBUTING_JA.md @@ -0,0 +1,160 @@ +Dify にコントリビュートしたいとお考えなのですね。それは素晴らしいことです。 +私たちは、LLM アプリケーションの構築と管理のための最も直感的なワークフローを設計するという壮大な野望を持っています。人数も資金も限られている新興企業として、コミュニティからの支援は本当に重要です。 + +私たちは現状を鑑み、機敏かつ迅速に開発をする必要がありますが、同時にあなたのようなコントリビューターの方々に、可能な限りスムーズな貢献体験をしていただきたいと思っています。そのためにこのコントリビュートガイドを作成しました。 +コードベースやコントリビュータの方々と私たちがどのように仕事をしているのかに慣れていただき、楽しいパートにすぐに飛び込めるようにすることが目的です。 + +このガイドは Dify そのものと同様に、継続的に改善されています。実際のプロジェクトに遅れをとることがあるかもしれませんが、ご理解をお願いします。 + +ライセンスに関しては、私たちの短い[ライセンスおよびコントリビューター規約](./LICENSE)をお読みください。また、コミュニティは[行動規範](https://github.com/langgenius/.github/blob/main/CODE_OF_CONDUCT.md)を遵守しています。 + +## 飛び込む前に + +[既存の Issue](https://github.com/langgenius/dify/issues?q=is:issue+is:closed) を探すか、[新しい Issue](https://github.com/langgenius/dify/issues/new/choose) を作成してください。私たちは Issue を 2 つのタイプに分類しています。 + +### 機能リクエスト + +* 新しい機能要望を出す場合は、提案する機能が何を実現するものなのかを説明し、可能な限り多くの文脈を含めてください。[@perzeusss](https://github.com/perzeuss)は、あなたの要望を書き出すのに役立つ [Feature Request Copilot](https://udify.app/chat/MK2kVSnw1gakVwMX) を作ってくれました。気軽に試してみてください。 + +* 既存の課題から 1 つ選びたい場合は、その下にコメントを書いてください。 + + 関連する方向で作業しているチームメンバーが参加します。すべてが良好であれば、コーディングを開始する許可が与えられます。私たちが変更を提案した場合にあなたの作業が無駄になることがないよう、それまでこの機能の作業を控えていただくようお願いいたします。 + + 提案された機能がどの分野に属するかによって、あなたは異なるチーム・メンバーと話をするかもしれません。以下は、各チームメンバーが現在取り組んでいる分野の概要です。 + +| Member | Scope | +| --------------------------------------------------------------------------------------- | ------------------------------------ | +| [@yeuoly](https://github.com/Yeuoly) | エージェントアーキテクチャ | +| [@jyong](https://github.com/JohnJyong) | RAG パイプライン設計 | +| [@GarfieldDai](https://github.com/GarfieldDai) | workflow orchestrations の構築 | +| [@iamjoel](https://github.com/iamjoel) & [@zxhlyh](https://github.com/zxhlyh) | フロントエンドを使いやすくする | +| [@guchenhe](https://github.com/guchenhe) & [@crazywoola](https://github.com/crazywoola) | 開発者体験、何でも相談できる窓口 | +| [@takatost](https://github.com/takatost) | 全体的な製品の方向性とアーキテクチャ | + +優先順位の付け方: + +| Feature Type | Priority | +| --------------------------------------------------------------------------------------------------------------------- | --------------- | +| チームメンバーによってラベル付けされた優先度の高い機能 | High Priority | +| [community feedback board](https://github.com/langgenius/dify/discussions/categories/feedbacks)の人気の機能リクエスト | Medium Priority | +| 非コア機能とマイナーな機能強化 | Low Priority | +| 価値はあるが即効性はない | Future-Feature | + +### その他 (バグレポート、パフォーマンスの最適化、誤字の修正など) + +* すぐにコーディングを始めてください + +優先順位の付け方: + +| Issue Type | Priority | +| -------------------------------------------------------------------------------------- | --------------- | +| コア機能のバグ(ログインできない、アプリケーションが動作しない、セキュリティの抜け穴) | Critical | +| 致命的でないバグ、パフォーマンス向上 | Medium Priority | +| 細かな修正(誤字脱字、機能はするが分かりにくい UI) | Low Priority | + +## インストール + +Dify を開発用にセットアップする手順は以下の通りです。 + +### 1. このリポジトリをフォークする + +### 2. リポジトリをクローンする + +フォークしたリポジトリをターミナルからクローンします。 + +``` +git clone git@github.com:/dify.git +``` + +### 3. 依存関係の確認 + +Dify を構築するには次の依存関係が必要です。それらがシステムにインストールされていることを確認してください。 + +- [Docker](https://www.docker.com/) +- [Docker Compose](https://docs.docker.com/compose/install/) +- [Node.js v18.x (LTS)](http://nodejs.org) +- [npm](https://www.npmjs.com/) version 8.x.x or [Yarn](https://yarnpkg.com/) +- [Python](https://www.python.org/) version 3.10.x + +### 4. 
インストール + +Dify はバックエンドとフロントエンドから構成されています。 +まず`cd api/`でバックエンドのディレクトリに移動し、[Backend README](api/README.md)に従ってインストールします。 +次に別のターミナルで、`cd web/`でフロントエンドのディレクトリに移動し、[Frontend README](web/README.md)に従ってインストールしてください。 + +よくある問題とトラブルシューティングの手順については、[installation FAQ](https://docs.dify.ai/getting-started/faq/install-faq) を確認してください。 + +### 5. ブラウザで dify にアクセスする + +設定を確認するために、ブラウザで[http://localhost:3000](http://localhost:3000)(デフォルト、または自分で設定した URL とポート)にアクセスしてください。Dify が起動して実行中であることが確認できるはずです。 + +## 開発中 + +モデルプロバイダーを追加する場合は、[このガイド](https://github.com/langgenius/dify/blob/main/api/core/model_runtime/README.md)が役立ちます。 + +Agent や Workflow にツールプロバイダーを追加する場合は、[このガイド](./api/core/tools/README.md)が役立ちます。 + +Dify のバックエンドとフロントエンドの概要を簡単に説明します。 + +### バックエンド + +Dify のバックエンドは[Flask](https://flask.palletsprojects.com/en/3.0.x/)を使って Python で書かれています。ORM には[SQLAlchemy](https://www.sqlalchemy.org/)を、タスクキューには[Celery](https://docs.celeryq.dev/en/stable/getting-started/introduction.html)を使っています。認証ロジックは Flask-login 経由で行われます。 + +``` +[api/] +├── constants // コードベース全体で使用される定数設定 +├── controllers // APIルート定義とリクエスト処理ロジック +├── core // アプリケーションの中核的な管理、モデル統合、およびツール +├── docker // Dockerおよびコンテナ関連の設定 +├── events // イベントのハンドリングと処理 +├── extensions // 第三者のフレームワーク/プラットフォームとの拡張 +├── fields // シリアライゼーション/マーシャリング用のフィールド定義 +├── libs // 再利用可能なライブラリとヘルパー +├── migrations // データベースマイグレーションスクリプト +├── models // データベースモデルとスキーマ定義 +├── services // ビジネスロジックの定義 +├── storage // 秘密鍵の保存 +├── tasks // 非同期タスクとバックグラウンドジョブの処理 +└── tests // テスト関連のファイル +``` + +### フロントエンド + +このウェブサイトは、Typescript の[Next.js](https://nextjs.org/)ボイラープレートでブートストラップされており、スタイリングには[Tailwind CSS](https://tailwindcss.com/)を使用しています。国際化には[React-i18next](https://react.i18next.com/)を使用しています。 + +``` +[web/] +├── app // レイアウト、ページ、コンポーネント +│ ├── (commonLayout) // アプリ全体で共通のレイアウト +│ ├── (shareLayout) // トークン特有のセッションで共有されるレイアウト +│ ├── activate // アクティベートページ +│ ├── components // ページやレイアウトで共有されるコンポーネント +│ ├── install // インストールページ +│ ├── signin // サインインページ +│ └── styles // グローバルに共有されるスタイル +├── assets // 静的アセット +├── bin // ビルドステップで実行されるスクリプト +├── config // 調整可能な設定とオプション +├── context // アプリの異なる部分で使用される共有コンテキスト +├── dictionaries // 言語別の翻訳ファイル +├── docker // コンテナ設定 +├── hooks // 再利用可能なフック +├── i18n // 国際化設定 +├── models // データモデルとAPIレスポンスの形状を記述 +├── public // ファビコンなどのメタアセット +├── service // APIアクションの形状を指定 +├── test +├── types // 関数のパラメータと戻り値の記述 +└── utils // 共有ユーティリティ関数 +``` + +## PR を投稿する + +いよいよ、私たちのリポジトリにプルリクエスト (PR) を提出する時が来ました。主要な機能については、まず `deploy/dev` ブランチにマージしてテストしてから `main` ブランチにマージします。 +マージ競合などの問題が発生した場合、またはプル リクエストを開く方法がわからない場合は、[GitHub's pull request tutorial](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests) をチェックしてみてください。 + +これで完了です!あなたの PR がマージされると、[README](https://github.com/langgenius/dify/blob/main/README.md) にコントリビューターとして紹介されます。 + +## ヘルプを得る + +コントリビュート中に行き詰まったり、疑問が生じたりした場合は、GitHub の関連する issue から質問していただくか、[Discord](https://discord.gg/8Tpq4AcN9c)でチャットしてください。 diff --git a/README.md b/README.md index 0dabb16c67..c43b52d7ad 100644 --- a/README.md +++ b/README.md @@ -35,13 +35,10 @@ README en Español README en Français README tlhIngan Hol + README in Korean

-# -

- langgenius%2Fdify | Trendshift -

Dify is an open-source LLM app development platform. Its intuitive interface combines AI workflow, RAG pipeline, agent capabilities, model management, observability features and more, letting you quickly go from prototype to production. Here's a list of the core features:

@@ -109,7 +106,7 @@ Dify is an open-source LLM app development platform. Its intuitive interface com Agent ✅ ✅ - ✅ + ❌ ✅ @@ -127,7 +124,7 @@ Dify is an open-source LLM app development platform. Its intuitive interface com ❌ - Enterprise Feature (SSO/Access control) + Enterprise Features (SSO/Access control) ✅ ❌ ❌ diff --git a/README_CN.md b/README_CN.md index 6a7f178e63..c6e81b532a 100644 --- a/README_CN.md +++ b/README_CN.md @@ -35,6 +35,7 @@ 上个月的提交次数 上个月的提交次数 上个月的提交次数 + 上个月的提交次数 @@ -111,7 +112,7 @@ Dify 是一个开源的 LLM 应用开发平台。其直观的界面结合了 AI Agent ✅ ✅ - ✅ + ❌ ✅ diff --git a/README_ES.md b/README_ES.md index ae6ab4e382..efc1bdfd41 100644 --- a/README_ES.md +++ b/README_ES.md @@ -35,6 +35,7 @@ Actividad de Commits el último mes Actividad de Commits el último mes Actividad de Commits el último mes + Actividad de Commits el último mes

# @@ -111,7 +112,7 @@ es basados en LLM Function Calling o ReAct, y agregar herramientas preconstruida Agente ✅ ✅ - ✅ + ❌ ✅ diff --git a/README_FR.md b/README_FR.md index ae7df183e2..4f12f3788e 100644 --- a/README_FR.md +++ b/README_FR.md @@ -35,6 +35,7 @@ Commits le mois dernier Commits le mois dernier Commits le mois dernier + Commits le mois dernier

# @@ -111,7 +112,7 @@ ités d'agent**: Agent ✅ ✅ - ✅ + ❌ ✅ diff --git a/README_JA.md b/README_JA.md index af97252eae..11de404c7d 100644 --- a/README_JA.md +++ b/README_JA.md @@ -2,7 +2,7 @@

Dify Cloud · - 自己ホスティング · + セルフホスト · ドキュメント · デモのスケジュール

@@ -35,6 +35,7 @@ 先月のコミット 先月のコミット 先月のコミット + 先月のコミット

# @@ -54,7 +55,7 @@ DifyはオープンソースのLLMアプリケーション開発プラットフ -**2. 網羅的なモデルサポート**: +**2. 包括的なモデルサポート**: 数百のプロプライエタリ/オープンソースのLLMと、数十の推論プロバイダーおよびセルフホスティングソリューションとのシームレスな統合を提供します。GPT、Mistral、Llama3、およびOpenAI API互換のモデルをカバーします。サポートされているモデルプロバイダーの完全なリストは[こちら](https://docs.dify.ai/getting-started/readme/model-providers)をご覧ください。 ![providers-v5](https://github.com/langgenius/dify/assets/13230914/5a17bdbe-097a-4100-8363-40255b70f6e3) @@ -94,9 +95,9 @@ DifyはオープンソースのLLMアプリケーション開発プラットフ サポートされているLLM - 豊富なバリエーション - 豊富なバリエーション - 豊富なバリエーション + バリエーション豊富 + バリエーション豊富 + バリエーション豊富 OpenAIのみ @@ -110,7 +111,7 @@ DifyはオープンソースのLLMアプリケーション開発プラットフ エージェント ✅ ✅ - ✅ + ❌ ✅ @@ -146,34 +147,34 @@ DifyはオープンソースのLLMアプリケーション開発プラットフ ## Difyの使用方法 - **クラウド
** -[こちら](https://dify.ai)のDify Cloudサービスを利用して、セットアップが不要で誰でも試すことができます。サンドボックスプランでは、200回の無料のGPT-4呼び出しが含まれています。 +[こちら](https://dify.ai)のDify Cloudサービスを利用して、セットアップ不要で試すことができます。サンドボックスプランには、200回の無料のGPT-4呼び出しが含まれています。 - **Dify Community Editionのセルフホスティング
** -この[スターターガイド](#quick-start)を使用して、環境でDifyをすばやく実行できます。 -さらなる参照や詳細な手順については、[ドキュメント](https://docs.dify.ai)をご覧ください。 +この[スターターガイド](#quick-start)を使用して、ローカル環境でDifyを簡単に実行できます。 +さらなる参考資料や詳細な手順については、[ドキュメント](https://docs.dify.ai)をご覧ください。 - **エンタープライズ/組織向けのDify
** 追加のエンタープライズ向け機能を提供しています。[こちらからミーティングを予約](https://cal.com/guchenhe/30min)したり、[メールを送信](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry)してエンタープライズのニーズについて相談してください。
> AWSを使用しているスタートアップや中小企業の場合は、[AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6)のDify Premiumをチェックして、ワンクリックで独自のAWS VPCにデプロイできます。カスタムロゴとブランディングでアプリを作成するオプションを備えた手頃な価格のAMIオファリングです。 -## 先を見る +## 最新の情報を入手 -GitHubでDifyにスターを付け、新しいリリースをすぐに通知されます。 +GitHub上でDifyにスターを付けることで、Difyに関する新しいニュースを受け取れます。 ![star-us](https://github.com/langgenius/dify/assets/13230914/b823edc1-6388-4e25-ad45-2f6b187adbb4) ## クイックスタート -> Difyをインストールする前に、マシンが以下の最小システム要件を満たしていることを確認してください: +> Difyをインストールする前に、お使いのマシンが以下の最小システム要件を満たしていることを確認してください: > >- CPU >= 2コア >- RAM >= 4GB
-Difyサーバーを起動する最も簡単な方法は、当社の[docker-compose.yml](docker/docker-compose.yaml)ファイルを実行することです。インストールコマンドを実行する前に、マシンに[Docker](https://docs.docker.com/get-docker/)と[Docker Compose](https://docs.docker.com/compose/install/)がインストールされていることを確認してください。 +Difyサーバーを起動する最も簡単な方法は、[docker-compose.yml](docker/docker-compose.yaml)ファイルを実行することです。インストールコマンドを実行する前に、マシンに[Docker](https://docs.docker.com/get-docker/)と[Docker Compose](https://docs.docker.com/compose/install/)がインストールされていることを確認してください。 ```bash cd docker @@ -216,7 +217,7 @@ docker compose up -d * [Discord](https://discord.gg/FngNHpbcY7). 主に: アプリケーションの共有やコミュニティとの交流。 * [Twitter](https://twitter.com/dify_ai). 主に: アプリケーションの共有やコミュニティとの交流。 -または、直接チームメンバーとミーティングをスケジュールします: +または、直接チームメンバーとミーティングをスケジュール: @@ -227,7 +228,7 @@ docker compose up -d - + @@ -242,4 +243,4 @@ docker compose up -d ## ライセンス -このリポジトリは、Dify Open Source License にいくつかの追加制限を加えた[Difyオープンソースライセンス](LICENSE)の下で利用可能です。 \ No newline at end of file +このリポジトリは、Dify Open Source License にいくつかの追加制限を加えた[Difyオープンソースライセンス](LICENSE)の下で利用可能です。 diff --git a/README_KL.md b/README_KL.md index 600649c459..b1eb5073f6 100644 --- a/README_KL.md +++ b/README_KL.md @@ -35,6 +35,7 @@ Commits last monthCommits last monthCommits last month + Commits last month

# @@ -111,7 +112,7 @@ Dify is an open-source LLM app development platform. Its intuitive interface com - + diff --git a/README_KR.md b/README_KR.md new file mode 100644 index 0000000000..9c809fa017 --- /dev/null +++ b/README_KR.md @@ -0,0 +1,243 @@ +![cover-v5-optimized](https://github.com/langgenius/dify/assets/13230914/f9e19af5-61ba-4119-b926-d10c4c06ebab) + +

+ Dify 클라우드 · + 셀프-호스팅 · + 문서 · + 기업 문의 +

+ +

+ + Static Badge + + Static Badge + + chat on Discord + + follow on Twitter + + Docker Pulls + + Commits last month + + Issues closed + + Discussion posts +

+ +

+ README in English + 简体中文版自述文件 + 日本語のREADME + README en Español + README en Français + README tlhIngan Hol + 한국어 README + +

+ + + Dify는 오픈 소스 LLM 앱 개발 플랫폼입니다. 직관적인 인터페이스를 통해 AI 워크플로우, RAG 파이프라인, 에이전트 기능, 모델 관리, 관찰 기능 등을 결합하여 프로토타입에서 프로덕션까지 빠르게 전환할 수 있습니다. 주요 기능 목록은 다음과 같습니다:

+
+**1. 워크플로우**:
+  다음 기능들을 비롯한 다양한 기능을 활용하여 시각적 캔버스에서 강력한 AI 워크플로우를 구축하고 테스트하세요.
+
+  https://github.com/langgenius/dify/assets/13230914/356df23e-1604-483d-80a6-9517ece318aa
+
+**2. 포괄적인 모델 지원**:
+
+수십 개의 추론 제공업체와 자체 호스팅 솔루션에서 제공하는 수백 개의 독점 및 오픈 소스 LLM과 원활하게 통합되며, GPT, Mistral, Llama3 및 모든 OpenAI API 호환 모델을 포함합니다. 지원되는 모델 제공업체의 전체 목록은 [여기](https://docs.dify.ai/getting-started/readme/model-providers)에서 확인할 수 있습니다.
+![providers-v5](https://github.com/langgenius/dify/assets/13230914/5a17bdbe-097a-4100-8363-40255b70f6e3)
+
+**3. 통합 개발환경**:
+  프롬프트를 작성하고, 모델 성능을 비교하며, 텍스트-음성 변환과 같은 추가 기능을 채팅 기반 앱에 추가할 수 있는 직관적인 인터페이스를 제공합니다.
+
+**4. RAG 파이프라인**:
+  문서 수집부터 검색까지 모든 것을 다루며, PDF, PPT 및 기타 일반적인 문서 형식에서 텍스트 추출을 위한 기본 지원이 포함되어 있는 광범위한 RAG 기능을 제공합니다.
+
+**5. 에이전트 기능**:
+  LLM 함수 호출 또는 ReAct를 기반으로 에이전트를 정의하고 에이전트에 대해 사전 구축된 도구나 사용자 정의 도구를 추가할 수 있습니다. Dify는 Google Search, DALL·E, Stable Diffusion, WolframAlpha 등 AI 에이전트를 위한 50개 이상의 내장 도구를 제공합니다.
+
+**6. LLMOps**:
+  시간 경과에 따른 애플리케이션 로그와 성능을 모니터링하고 분석합니다. 생산 데이터와 주석을 기반으로 프롬프트, 데이터세트, 모델을 지속적으로 개선할 수 있습니다.
+
+**7. Backend-as-a-Service**:
+  Dify의 모든 제품에는 해당 API가 함께 제공되므로 Dify를 자신의 비즈니스 로직에 쉽게 통합할 수 있습니다. (기능 비교 표 바로 아래에 간단한 API 호출 스케치가 있습니다.)
+
+## 기능 비교
+
+| 기능 | Dify.AI | LangChain | Flowise | OpenAI Assistants API |
+| --- | --- | --- | --- | --- |
+| 프로그래밍 접근 방식 | API + 앱 중심 | Python 코드 | 앱 중심 | API 중심 |
+| 지원되는 LLMs | 다양한 종류 | 다양한 종류 | 다양한 종류 | OpenAI 전용 |
+| RAG 엔진 | ✅ | ✅ | ✅ | ✅ |
+| 에이전트 | ✅ | ✅ | ❌ | ✅ |
+| 워크플로우 | ✅ | ❌ | ✅ | ❌ |
+| 가시성 | ✅ | ✅ | ❌ | ❌ |
+| 기업용 기능 (SSO/접근 제어) | ✅ | ❌ | ❌ | ❌ |
+| 로컬 배포 | ✅ | ✅ | ✅ | ❌ |
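The "Backend-as-a-Service" point above (feature 7) means every Dify app also exposes an HTTP API. As a minimal illustrative sketch (an editorial aside, not part of this patch): calling a chat app might look roughly like this in Python. The base URL, API key, query, and user ID are placeholders; the `/v1/chat-messages` endpoint shape follows Dify's published API.

```python
import requests

# Placeholder values: point these at your own Dify deployment and app API key.
API_BASE = "http://localhost/v1"
API_KEY = "app-xxxxxxxxxxxx"

response = requests.post(
    f"{API_BASE}/chat-messages",
    headers={"Authorization": f"Bearer {API_KEY}"},
    json={
        "inputs": {},                 # app-defined input variables, if any
        "query": "Summarize our refund policy in two sentences.",
        "response_mode": "blocking",  # "streaming" returns SSE chunks instead
        "user": "demo-user",          # stable identifier for the end user
    },
    timeout=60,
)
response.raise_for_status()
print(response.json()["answer"])
```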
+ +## Dify 사용하기 + +- **클라우드
** + 우리는 누구나 설정이 필요 없이 사용해 볼 수 있도록 [Dify 클라우드](https://dify.ai) 서비스를 호스팅합니다. 이는 자체 배포 버전의 모든 기능을 제공하며, 샌드박스 플랜에서 무료로 200회의 GPT-4 호출을 포함합니다. + +- **셀프-호스팅 Dify 커뮤니티 에디션
** + 환경에서 Dify를 빠르게 실행하려면 이 [스타터 가이드를](#quick-start) 참조하세요. + 추가 참조 및 더 심층적인 지침은 [문서](https://docs.dify.ai)를 사용하세요. + +- **기업 / 조직을 위한 Dify
** + 우리는 추가적인 기업 중심 기능을 제공합니다. 당사와 [미팅일정](https://cal.com/guchenhe/30min)을 잡거나 [이메일 보내기](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry)를 통해 기업 요구 사항을 논의하십시오.
+ > AWS를 사용하는 스타트업 및 중소기업의 경우 [AWS Marketplace에서 Dify Premium](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6)을 확인하고 한 번의 클릭으로 자체 AWS VPC에 배포하십시오. 맞춤형 로고와 브랜딩이 포함된 앱을 생성할 수 있는 옵션이 포함된 저렴한 AMI 제품입니다. + + + +## 앞서가기 + +GitHub에서 Dify에 별표를 찍어 새로운 릴리스를 즉시 알림 받으세요. + +![star-us](https://github.com/langgenius/dify/assets/13230914/b823edc1-6388-4e25-ad45-2f6b187adbb4) + + + +## 빠른 시작 +>Dify를 설치하기 전에 컴퓨터가 다음과 같은 최소 시스템 요구 사항을 충족하는지 확인하세요 : +>- CPU >= 2 Core +>- RAM >= 4GB + +
+ +Dify 서버를 시작하는 가장 쉬운 방법은 [docker-compose.yml](docker/docker-compose.yaml) 파일을 실행하는 것입니다. 설치 명령을 실행하기 전에 [Docker](https://docs.docker.com/get-docker/) 및 [Docker Compose](https://docs.docker.com/compose/install/)가 머신에 설치되어 있는지 확인하세요. + +```bash +cd docker +docker compose up -d +``` + +실행 후 브라우저의 [http://localhost/install](http://localhost/install) 에서 Dify 대시보드에 액세스하고 초기화 프로세스를 시작할 수 있습니다. + +> Dify에 기여하거나 추가 개발을 하고 싶다면 소스 코드에서 [배포에 대한 가이드](https://docs.dify.ai/getting-started/install-self-hosted/local-source-code)를 참조하세요. + +## 다음 단계 + +구성 커스터마이징이 필요한 경우, [docker-compose.yml](docker/docker-compose.yaml) 파일의 코멘트를 참조하여 환경 구성을 수동으로 설정하십시오. 변경 후 `docker-compose up -d` 를 다시 실행하십시오. 환경 변수의 전체 목록은 [여기](https://docs.dify.ai/getting-started/install-self-hosted/environments)에서 확인할 수 있습니다. + + +고가용성 설정을 구성하려면 Dify를 Kubernetes에 배포할 수 있는 커뮤니티 제공 [Helm Charts](https://helm.sh/)가 있습니다. + +- [Helm Chart by @LeoQuote](https://github.com/douban/charts/tree/master/charts/dify) +- [Helm Chart by @BorisPolonsky](https://github.com/BorisPolonsky/dify-helm) + + +## 기여 + +코드에 기여하고 싶은 분들은 [기여 가이드](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)를 참조하세요. +동시에 Dify를 소셜 미디어와 행사 및 컨퍼런스에 공유하여 지원하는 것을 고려해 주시기 바랍니다. + + +> 우리는 Dify를 중국어나 영어 이외의 언어로 번역하는 데 도움을 줄 수 있는 기여자를 찾고 있습니다. 도움을 주고 싶으시다면 [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n/README.md)에서 더 많은 정보를 확인하시고 [Discord 커뮤니티 서버](https://discord.gg/8Tpq4AcN9c)의 `global-users` 채널에 댓글을 남겨주세요. + +**기여자** + + + + + +## 커뮤니티 & 연락처 + +* [Github 토론](https://github.com/langgenius/dify/discussions). 피드백 공유 및 질문하기에 적합합니다. +* [GitHub 이슈](https://github.com/langgenius/dify/issues). Dify.AI 사용 중 발견한 버그와 기능 제안에 적합합니다. [기여 가이드](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)를 참조하세요. +* [이메일](mailto:support@dify.ai?subject=[GitHub]Questions%20About%20Dify). Dify.AI 사용에 대한 질문하기에 적합합니다. +* [디스코드](https://discord.gg/FngNHpbcY7). 애플리케이션 공유 및 커뮤니티와 소통하기에 적합합니다. +* [트위터](https://twitter.com/dify_ai). 애플리케이션 공유 및 커뮤니티와 소통하기에 적합합니다. + +또는 팀원과 직접 미팅을 예약하세요: + + + + + + + + + + + + + + +
+| 연락처 | 목적 |
+| --- | --- |
+| Git-Hub-README-Button-3x | 비즈니스 문의 및 제품 피드백 |
+| Git-Hub-README-Button-2x | 기여, 이슈 및 기능 요청 |
+ +## Star 히스토리 + +[![Star History Chart](https://api.star-history.com/svg?repos=langgenius/dify&type=Date)](https://star-history.com/#langgenius/dify&Date) + + +## 보안 공개 + +개인정보 보호를 위해 보안 문제를 GitHub에 게시하지 마십시오. 대신 security@dify.ai로 질문을 보내주시면 더 자세한 답변을 드리겠습니다. + +## 라이선스 + +이 저장소는 기본적으로 몇 가지 추가 제한 사항이 있는 Apache 2.0인 [Dify 오픈 소스 라이선스](LICENSE)에 따라 사용할 수 있습니다. diff --git a/api/.env.example b/api/.env.example index 30bbf331a4..f112721a7e 100644 --- a/api/.env.example +++ b/api/.env.example @@ -17,6 +17,9 @@ APP_WEB_URL=http://127.0.0.1:3000 # Files URL FILES_URL=http://127.0.0.1:5001 +# The time in seconds after the signature is rejected +FILES_ACCESS_TIMEOUT=300 + # celery configuration CELERY_BROKER_URL=redis://:difyai123456@localhost:6379/1 @@ -65,7 +68,7 @@ GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON=your-google-service-account-json-base64-stri WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,* CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,* -# Vector database configuration, support: weaviate, qdrant, milvus, relyt, pgvecto_rs +# Vector database configuration, support: weaviate, qdrant, milvus, relyt, pgvecto_rs, pgvector VECTOR_STORE=weaviate # Weaviate configuration @@ -102,6 +105,20 @@ PGVECTO_RS_USER=postgres PGVECTO_RS_PASSWORD=difyai123456 PGVECTO_RS_DATABASE=postgres +# PGVector configuration +PGVECTOR_HOST=127.0.0.1 +PGVECTOR_PORT=5433 +PGVECTOR_USER=postgres +PGVECTOR_PASSWORD=postgres +PGVECTOR_DATABASE=postgres + +# Tidb Vector configuration +TIDB_VECTOR_HOST=xxx.eu-central-1.xxx.aws.tidbcloud.com +TIDB_VECTOR_PORT=4000 +TIDB_VECTOR_USER=xxx.root +TIDB_VECTOR_PASSWORD=xxxxxx +TIDB_VECTOR_DATABASE=dify + # Upload configuration UPLOAD_FILE_SIZE_LIMIT=15 UPLOAD_FILE_BATCH_LIMIT=5 @@ -117,10 +134,11 @@ RESEND_API_KEY= RESEND_API_URL=https://api.resend.com # smtp configuration SMTP_SERVER=smtp.gmail.com -SMTP_PORT=587 +SMTP_PORT=465 SMTP_USERNAME=123 SMTP_PASSWORD=abc -SMTP_USE_TLS=false +SMTP_USE_TLS=true +SMTP_OPPORTUNISTIC_TLS=false # Sentry configuration SENTRY_DSN= @@ -137,6 +155,7 @@ NOTION_INTERNAL_SECRET=you-internal-secret ETL_TYPE=dify UNSTRUCTURED_API_URL= +UNSTRUCTURED_API_KEY= SSRF_PROXY_HTTP_URL= SSRF_PROXY_HTTPS_URL= @@ -163,6 +182,16 @@ API_TOOL_DEFAULT_READ_TIMEOUT=60 HTTP_REQUEST_MAX_CONNECT_TIMEOUT=300 HTTP_REQUEST_MAX_READ_TIMEOUT=600 HTTP_REQUEST_MAX_WRITE_TIMEOUT=600 +HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760 # 10MB +HTTP_REQUEST_NODE_MAX_TEXT_SIZE=1048576 # 1MB # Log file path LOG_FILE= + +# Indexing configuration +INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH=1000 + +# Workflow runtime configuration +WORKFLOW_MAX_EXECUTION_STEPS=500 +WORKFLOW_MAX_EXECUTION_TIME=1200 +WORKFLOW_CALL_MAX_DEPTH=5 diff --git a/api/commands.py b/api/commands.py index b82f7ac3f8..186b97c3fa 100644 --- a/api/commands.py +++ b/api/commands.py @@ -1,11 +1,13 @@ import base64 import json import secrets +from typing import Optional import click from flask import current_app from werkzeug.exceptions import NotFound +from constants.languages import languages from core.rag.datasource.vdb.vector_factory import Vector from core.rag.models.document import Document from extensions.ext_database import db @@ -17,6 +19,7 @@ from models.dataset import Dataset, DatasetCollectionBinding, DocumentSegment from models.dataset import Document as DatasetDocument from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation from models.provider import Provider, ProviderModel +from services.account_service import RegisterService, TenantService @click.command('reset-password', 
help='Reset the account password.') @@ -57,7 +60,7 @@ def reset_password(email, new_password, password_confirm): account.password = base64_password_hashed account.password_salt = base64_salt db.session.commit() - click.echo(click.style('Congratulations!, password has been reset.', fg='green')) + click.echo(click.style('Congratulations! Password has been reset.', fg='green')) @click.command('reset-email', help='Reset the account email.') @@ -305,6 +308,14 @@ def migrate_knowledge_vector_database(): "vector_store": {"class_prefix": collection_name} } dataset.index_struct = json.dumps(index_struct_dict) + elif vector_type == "pgvector": + dataset_id = dataset.id + collection_name = Dataset.gen_collection_name_by_id(dataset_id) + index_struct_dict = { + "type": 'pgvector', + "vector_store": {"class_prefix": collection_name} + } + dataset.index_struct = json.dumps(index_struct_dict) else: raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.") @@ -440,9 +451,105 @@ def convert_to_agent_apps(): click.echo(click.style('Congratulations! Converted {} agent apps.'.format(len(proceeded_app_ids)), fg='green')) +@click.command('add-qdrant-doc-id-index', help='add qdrant doc_id index.') +@click.option('--field', default='metadata.doc_id', prompt=False, help='index field , default is metadata.doc_id.') +def add_qdrant_doc_id_index(field: str): + click.echo(click.style('Start add qdrant doc_id index.', fg='green')) + config = current_app.config + vector_type = config.get('VECTOR_STORE') + if vector_type != "qdrant": + click.echo(click.style('Sorry, only support qdrant vector store.', fg='red')) + return + create_count = 0 + + try: + bindings = db.session.query(DatasetCollectionBinding).all() + if not bindings: + click.echo(click.style('Sorry, no dataset collection bindings found.', fg='red')) + return + import qdrant_client + from qdrant_client.http.exceptions import UnexpectedResponse + from qdrant_client.http.models import PayloadSchemaType + + from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantConfig + for binding in bindings: + qdrant_config = QdrantConfig( + endpoint=config.get('QDRANT_URL'), + api_key=config.get('QDRANT_API_KEY'), + root_path=current_app.root_path, + timeout=config.get('QDRANT_CLIENT_TIMEOUT'), + grpc_port=config.get('QDRANT_GRPC_PORT'), + prefer_grpc=config.get('QDRANT_GRPC_ENABLED') + ) + try: + client = qdrant_client.QdrantClient(**qdrant_config.to_qdrant_params()) + # create payload index + client.create_payload_index(binding.collection_name, field, + field_schema=PayloadSchemaType.KEYWORD) + create_count += 1 + except UnexpectedResponse as e: + # Collection does not exist, so return + if e.status_code == 404: + click.echo(click.style(f'Collection not found, collection_name:{binding.collection_name}.', fg='red')) + continue + # Some other error occurred, so re-raise the exception + else: + click.echo(click.style(f'Failed to create qdrant index, collection_name:{binding.collection_name}.', fg='red')) + + except Exception as e: + click.echo(click.style('Failed to create qdrant client.', fg='red')) + + click.echo( + click.style(f'Congratulations! 
Create {create_count} collection indexes.', + fg='green')) + + +@click.command('create-tenant', help='Create account and tenant.') +@click.option('--email', prompt=True, help='The email address of the tenant account.') +@click.option('--language', prompt=True, help='Account language, default: en-US.') +def create_tenant(email: str, language: Optional[str] = None): + """ + Create tenant account + """ + if not email: + click.echo(click.style('Sorry, email is required.', fg='red')) + return + + # Create account + email = email.strip() + + if '@' not in email: + click.echo(click.style('Sorry, invalid email address.', fg='red')) + return + + account_name = email.split('@')[0] + + if language not in languages: + language = 'en-US' + + # generate random password + new_password = secrets.token_urlsafe(16) + + # register account + account = RegisterService.register( + email=email, + name=account_name, + password=new_password, + language=language + ) + + TenantService.create_owner_tenant_if_not_exist(account) + + click.echo(click.style('Congratulations! Account and tenant created.\n' + 'Account: {}\nPassword: {}'.format(email, new_password), fg='green')) + + def register_commands(app): app.cli.add_command(reset_password) app.cli.add_command(reset_email) app.cli.add_command(reset_encrypt_key_pair) app.cli.add_command(vdb_migrate) app.cli.add_command(convert_to_agent_apps) + app.cli.add_command(add_qdrant_doc_id_index) + app.cli.add_command(create_tenant) + diff --git a/api/config.py b/api/config.py index d6c345c579..286b3336a2 100644 --- a/api/config.py +++ b/api/config.py @@ -23,14 +23,17 @@ DEFAULTS = { 'SERVICE_API_URL': 'https://api.dify.ai', 'APP_WEB_URL': 'https://udify.app', 'FILES_URL': '', + 'FILES_ACCESS_TIMEOUT': 300, 'S3_ADDRESS_STYLE': 'auto', 'STORAGE_TYPE': 'local', 'STORAGE_LOCAL_PATH': 'storage', 'CHECK_UPDATE_URL': 'https://updates.dify.ai', 'DEPLOY_ENV': 'PRODUCTION', + 'SQLALCHEMY_DATABASE_URI_SCHEME': 'postgresql', 'SQLALCHEMY_POOL_SIZE': 30, 'SQLALCHEMY_MAX_OVERFLOW': 10, 'SQLALCHEMY_POOL_RECYCLE': 3600, + 'SQLALCHEMY_POOL_PRE_PING': 'False', 'SQLALCHEMY_ECHO': 'False', 'SENTRY_TRACES_SAMPLE_RATE': 1.0, 'SENTRY_PROFILES_SAMPLE_RATE': 1.0, @@ -67,6 +70,7 @@ DEFAULTS = { 'INVITE_EXPIRY_HOURS': 72, 'BILLING_ENABLED': 'False', 'CAN_REPLACE_LOGO': 'False', + 'MODEL_LB_ENABLED': 'False', 'ETL_TYPE': 'dify', 'KEYWORD_STORE': 'jieba', 'BATCH_UPLOAD_LIMIT': 20, @@ -77,6 +81,10 @@ DEFAULTS = { 'KEYWORD_DATA_SOURCE_TYPE': 'database', 'INNER_API': 'False', 'ENTERPRISE_ENABLED': 'False', + 'INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH': 1000, + 'WORKFLOW_MAX_EXECUTION_STEPS': 500, + 'WORKFLOW_MAX_EXECUTION_TIME': 1200, + 'WORKFLOW_CALL_MAX_DEPTH': 5, } @@ -107,7 +115,7 @@ class Config: # ------------------------ # General Configurations. # ------------------------ - self.CURRENT_VERSION = "0.6.6" + self.CURRENT_VERSION = "0.6.10" self.COMMIT_SHA = get_env('COMMIT_SHA') self.EDITION = get_env('EDITION') self.DEPLOY_ENV = get_env('DEPLOY_ENV') @@ -116,6 +124,7 @@ class Config: self.LOG_FILE = get_env('LOG_FILE') self.LOG_FORMAT = get_env('LOG_FORMAT') self.LOG_DATEFORMAT = get_env('LOG_DATEFORMAT') + self.API_COMPRESSION_ENABLED = get_bool_env('API_COMPRESSION_ENABLED') # The backend URL prefix of the console API. # used to concatenate the login authorization callback or notion integration callback. @@ -138,6 +147,10 @@ class Config: # Url is signed and has expiration time. 
self.FILES_URL = get_env('FILES_URL') if get_env('FILES_URL') else self.CONSOLE_API_URL + # File Access Time specifies a time interval in seconds for the file to be accessed. + # The default value is 300 seconds. + self.FILES_ACCESS_TIMEOUT = int(get_env('FILES_ACCESS_TIMEOUT')) + # Your App secret key will be used for securely signing the session cookie # Make sure you are changing this key for your deployment with a strong key. # You can generate a strong key using `openssl rand -base64 42`. @@ -165,14 +178,17 @@ class Config: key: get_env(key) for key in ['DB_USERNAME', 'DB_PASSWORD', 'DB_HOST', 'DB_PORT', 'DB_DATABASE', 'DB_CHARSET'] } + self.SQLALCHEMY_DATABASE_URI_SCHEME = get_env('SQLALCHEMY_DATABASE_URI_SCHEME') db_extras = f"?client_encoding={db_credentials['DB_CHARSET']}" if db_credentials['DB_CHARSET'] else "" - self.SQLALCHEMY_DATABASE_URI = f"postgresql://{db_credentials['DB_USERNAME']}:{db_credentials['DB_PASSWORD']}@{db_credentials['DB_HOST']}:{db_credentials['DB_PORT']}/{db_credentials['DB_DATABASE']}{db_extras}" + self.SQLALCHEMY_DATABASE_URI = f"{self.SQLALCHEMY_DATABASE_URI_SCHEME}://{db_credentials['DB_USERNAME']}:{db_credentials['DB_PASSWORD']}@{db_credentials['DB_HOST']}:{db_credentials['DB_PORT']}/{db_credentials['DB_DATABASE']}{db_extras}" self.SQLALCHEMY_ENGINE_OPTIONS = { 'pool_size': int(get_env('SQLALCHEMY_POOL_SIZE')), 'max_overflow': int(get_env('SQLALCHEMY_MAX_OVERFLOW')), - 'pool_recycle': int(get_env('SQLALCHEMY_POOL_RECYCLE')) + 'pool_recycle': int(get_env('SQLALCHEMY_POOL_RECYCLE')), + 'pool_pre_ping': get_bool_env('SQLALCHEMY_POOL_PRE_PING'), + 'connect_args': {'options': '-c timezone=UTC'}, } self.SQLALCHEMY_ECHO = get_bool_env('SQLALCHEMY_ECHO') @@ -196,36 +212,51 @@ class Config: if self.CELERY_BACKEND == 'database' else self.CELERY_BROKER_URL self.BROKER_USE_SSL = self.CELERY_BROKER_URL.startswith('rediss://') + # ------------------------ + # Code Execution Sandbox Configurations. + # ------------------------ + self.CODE_EXECUTION_ENDPOINT = get_env('CODE_EXECUTION_ENDPOINT') + self.CODE_EXECUTION_API_KEY = get_env('CODE_EXECUTION_API_KEY') + # ------------------------ # File Storage Configurations. 
# ------------------------ self.STORAGE_TYPE = get_env('STORAGE_TYPE') self.STORAGE_LOCAL_PATH = get_env('STORAGE_LOCAL_PATH') + + # S3 Storage settings self.S3_ENDPOINT = get_env('S3_ENDPOINT') self.S3_BUCKET_NAME = get_env('S3_BUCKET_NAME') self.S3_ACCESS_KEY = get_env('S3_ACCESS_KEY') self.S3_SECRET_KEY = get_env('S3_SECRET_KEY') self.S3_REGION = get_env('S3_REGION') self.S3_ADDRESS_STYLE = get_env('S3_ADDRESS_STYLE') + + # Azure Blob Storage settings self.AZURE_BLOB_ACCOUNT_NAME = get_env('AZURE_BLOB_ACCOUNT_NAME') self.AZURE_BLOB_ACCOUNT_KEY = get_env('AZURE_BLOB_ACCOUNT_KEY') self.AZURE_BLOB_CONTAINER_NAME = get_env('AZURE_BLOB_CONTAINER_NAME') self.AZURE_BLOB_ACCOUNT_URL = get_env('AZURE_BLOB_ACCOUNT_URL') - self.ALIYUN_OSS_BUCKET_NAME=get_env('ALIYUN_OSS_BUCKET_NAME') - self.ALIYUN_OSS_ACCESS_KEY=get_env('ALIYUN_OSS_ACCESS_KEY') - self.ALIYUN_OSS_SECRET_KEY=get_env('ALIYUN_OSS_SECRET_KEY') - self.ALIYUN_OSS_ENDPOINT=get_env('ALIYUN_OSS_ENDPOINT') - self.ALIYUN_OSS_REGION=get_env('ALIYUN_OSS_REGION') - self.ALIYUN_OSS_AUTH_VERSION=get_env('ALIYUN_OSS_AUTH_VERSION') + + # Aliyun Storage settings + self.ALIYUN_OSS_BUCKET_NAME = get_env('ALIYUN_OSS_BUCKET_NAME') + self.ALIYUN_OSS_ACCESS_KEY = get_env('ALIYUN_OSS_ACCESS_KEY') + self.ALIYUN_OSS_SECRET_KEY = get_env('ALIYUN_OSS_SECRET_KEY') + self.ALIYUN_OSS_ENDPOINT = get_env('ALIYUN_OSS_ENDPOINT') + self.ALIYUN_OSS_REGION = get_env('ALIYUN_OSS_REGION') + self.ALIYUN_OSS_AUTH_VERSION = get_env('ALIYUN_OSS_AUTH_VERSION') + + # Google Cloud Storage settings self.GOOGLE_STORAGE_BUCKET_NAME = get_env('GOOGLE_STORAGE_BUCKET_NAME') self.GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64 = get_env('GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64') # ------------------------ # Vector Store Configurations. - # Currently, only support: qdrant, milvus, zilliz, weaviate, relyt + # Currently, only support: qdrant, milvus, zilliz, weaviate, relyt, pgvector # ------------------------ self.VECTOR_STORE = get_env('VECTOR_STORE') self.KEYWORD_STORE = get_env('KEYWORD_STORE') + # qdrant settings self.QDRANT_URL = get_env('QDRANT_URL') self.QDRANT_API_KEY = get_env('QDRANT_API_KEY') @@ -261,6 +292,20 @@ class Config: self.PGVECTO_RS_PASSWORD = get_env('PGVECTO_RS_PASSWORD') self.PGVECTO_RS_DATABASE = get_env('PGVECTO_RS_DATABASE') + # pgvector settings + self.PGVECTOR_HOST = get_env('PGVECTOR_HOST') + self.PGVECTOR_PORT = get_env('PGVECTOR_PORT') + self.PGVECTOR_USER = get_env('PGVECTOR_USER') + self.PGVECTOR_PASSWORD = get_env('PGVECTOR_PASSWORD') + self.PGVECTOR_DATABASE = get_env('PGVECTOR_DATABASE') + + # tidb-vector settings + self.TIDB_VECTOR_HOST = get_env('TIDB_VECTOR_HOST') + self.TIDB_VECTOR_PORT = get_env('TIDB_VECTOR_PORT') + self.TIDB_VECTOR_USER = get_env('TIDB_VECTOR_USER') + self.TIDB_VECTOR_PASSWORD = get_env('TIDB_VECTOR_PASSWORD') + self.TIDB_VECTOR_DATABASE = get_env('TIDB_VECTOR_DATABASE') + # ------------------------ # Mail Configurations. # ------------------------ @@ -274,7 +319,8 @@ class Config: self.SMTP_USERNAME = get_env('SMTP_USERNAME') self.SMTP_PASSWORD = get_env('SMTP_PASSWORD') self.SMTP_USE_TLS = get_bool_env('SMTP_USE_TLS') - + self.SMTP_OPPORTUNISTIC_TLS = get_bool_env('SMTP_OPPORTUNISTIC_TLS') + # ------------------------ # Workspace Configurations. 
# ------------------------ @@ -301,6 +347,23 @@ class Config: self.UPLOAD_FILE_SIZE_LIMIT = int(get_env('UPLOAD_FILE_SIZE_LIMIT')) self.UPLOAD_FILE_BATCH_LIMIT = int(get_env('UPLOAD_FILE_BATCH_LIMIT')) self.UPLOAD_IMAGE_FILE_SIZE_LIMIT = int(get_env('UPLOAD_IMAGE_FILE_SIZE_LIMIT')) + self.BATCH_UPLOAD_LIMIT = get_env('BATCH_UPLOAD_LIMIT') + + # RAG ETL Configurations. + self.ETL_TYPE = get_env('ETL_TYPE') + self.UNSTRUCTURED_API_URL = get_env('UNSTRUCTURED_API_URL') + self.UNSTRUCTURED_API_KEY = get_env('UNSTRUCTURED_API_KEY') + self.KEYWORD_DATA_SOURCE_TYPE = get_env('KEYWORD_DATA_SOURCE_TYPE') + + # Indexing Configurations. + self.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH = get_env('INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH') + + # Tool Configurations. + self.TOOL_ICON_CACHE_MAX_AGE = get_env('TOOL_ICON_CACHE_MAX_AGE') + + self.WORKFLOW_MAX_EXECUTION_STEPS = int(get_env('WORKFLOW_MAX_EXECUTION_STEPS')) + self.WORKFLOW_MAX_EXECUTION_TIME = int(get_env('WORKFLOW_MAX_EXECUTION_TIME')) + self.WORKFLOW_CALL_MAX_DEPTH = int(get_env('WORKFLOW_CALL_MAX_DEPTH')) # Moderation in app Configurations. self.OUTPUT_MODERATION_BUFFER_SIZE = int(get_env('OUTPUT_MODERATION_BUFFER_SIZE')) @@ -352,18 +415,15 @@ class Config: self.HOSTED_FETCH_APP_TEMPLATES_MODE = get_env('HOSTED_FETCH_APP_TEMPLATES_MODE') self.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN = get_env('HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN') - self.ETL_TYPE = get_env('ETL_TYPE') - self.UNSTRUCTURED_API_URL = get_env('UNSTRUCTURED_API_URL') + # Model Load Balancing Configurations. + self.MODEL_LB_ENABLED = get_bool_env('MODEL_LB_ENABLED') + + # Platform Billing Configurations. self.BILLING_ENABLED = get_bool_env('BILLING_ENABLED') - self.CAN_REPLACE_LOGO = get_bool_env('CAN_REPLACE_LOGO') - self.BATCH_UPLOAD_LIMIT = get_env('BATCH_UPLOAD_LIMIT') - - self.CODE_EXECUTION_ENDPOINT = get_env('CODE_EXECUTION_ENDPOINT') - self.CODE_EXECUTION_API_KEY = get_env('CODE_EXECUTION_API_KEY') - - self.API_COMPRESSION_ENABLED = get_bool_env('API_COMPRESSION_ENABLED') - self.TOOL_ICON_CACHE_MAX_AGE = get_env('TOOL_ICON_CACHE_MAX_AGE') - - self.KEYWORD_DATA_SOURCE_TYPE = get_env('KEYWORD_DATA_SOURCE_TYPE') + # ------------------------ + # Enterprise feature Configurations. 
+ # **Before using, please contact business@dify.ai by email to inquire about licensing matters.** + # ------------------------ self.ENTERPRISE_ENABLED = get_bool_env('ENTERPRISE_ENABLED') + self.CAN_REPLACE_LOGO = get_bool_env('CAN_REPLACE_LOGO') diff --git a/api/constants/languages.py b/api/constants/languages.py index bdfd8022a3..b4626cf51f 100644 --- a/api/constants/languages.py +++ b/api/constants/languages.py @@ -1,6 +1,6 @@ -languages = ['en-US', 'zh-Hans', 'zh-Hant', 'pt-BR', 'es-ES', 'fr-FR', 'de-DE', 'ja-JP', 'ko-KR', 'ru-RU', 'it-IT', 'uk-UA', 'vi-VN'] +languages = ['en-US', 'zh-Hans', 'zh-Hant', 'pt-BR', 'es-ES', 'fr-FR', 'de-DE', 'ja-JP', 'ko-KR', 'ru-RU', 'it-IT', 'uk-UA', 'vi-VN', 'pl-PL'] language_timezone_mapping = { 'en-US': 'America/New_York', @@ -16,6 +16,8 @@ language_timezone_mapping = { 'it-IT': 'Europe/Rome', 'uk-UA': 'Europe/Kyiv', 'vi-VN': 'Asia/Ho_Chi_Minh', + 'ro-RO': 'Europe/Bucharest', + 'pl-PL': 'Europe/Warsaw', } diff --git a/api/constants/recommended_apps.json b/api/constants/recommended_apps.json index 8a1ee808e4..68c913f80a 100644 --- a/api/constants/recommended_apps.json +++ b/api/constants/recommended_apps.json @@ -24,7 +24,8 @@ "description": "Welcome to your personalized Investment Analysis Copilot service, where we delve into the depths of stock analysis to provide you with comprehensive insights. \n", "is_listed": true, "position": 0, - "privacy_policy": null + "privacy_policy": null, + "custom_disclaimer": null }, { "app": { @@ -40,7 +41,8 @@ "description": "Code interpreter, clarifying the syntax and semantics of the code.", "is_listed": true, "position": 13, - "privacy_policy": "https://dify.ai" + "privacy_policy": "https://dify.ai", + "custom_disclaimer": null }, { "app": { @@ -56,7 +58,8 @@ "description": "Hello, I am your creative partner in bringing ideas to vivid life! I can assist you in creating stunning designs by leveraging abilities of DALL\u00b7E 3. ", "is_listed": true, "position": 4, - "privacy_policy": null + "privacy_policy": null, + "custom_disclaimer": null }, { "app": { @@ -72,7 +75,8 @@ "description": "Fully SEO Optimized Article including FAQs", "is_listed": true, "position": 1, - "privacy_policy": null + "privacy_policy": null, + "custom_disclaimer": null }, { "app": { @@ -88,7 +92,8 @@ "description": "Generate Flat Style Image", "is_listed": true, "position": 10, - "privacy_policy": null + "privacy_policy": null, + "custom_disclaimer": null }, { "app": { @@ -104,7 +109,8 @@ "description": "A multilingual translator that provides translation capabilities in multiple languages. Input the text you need to translate and select the target language.", "is_listed": true, "position": 10, - "privacy_policy": "https://dify.ai" + "privacy_policy": "https://dify.ai", + "custom_disclaimer": null }, { "app": { @@ -120,7 +126,8 @@ "description": "I am a YouTube Channel Data Analysis Copilot, I am here to provide expert data analysis tailored to your needs. ", "is_listed": true, "position": 2, - "privacy_policy": null + "privacy_policy": null, + "custom_disclaimer": null }, { "app": { @@ -136,7 +143,8 @@ "description": "Meeting minutes generator", "is_listed": true, "position": 0, - "privacy_policy": "https://dify.ai" + "privacy_policy": "https://dify.ai", + "custom_disclaimer": null }, { "app": { @@ -152,7 +160,8 @@ "description": "Tell me the main elements, I will generate a cyberpunk style image for you. 
", "is_listed": true, "position": 10, - "privacy_policy": null + "privacy_policy": null, + "custom_disclaimer": null }, { "app": { @@ -168,7 +177,8 @@ "description": "Write SQL from natural language by pasting in your schema with the request.Please describe your query requirements in natural language and select the target database type.", "is_listed": true, "position": 13, - "privacy_policy": "https://dify.ai" + "privacy_policy": "https://dify.ai", + "custom_disclaimer": null }, { "app": { @@ -184,7 +194,8 @@ "description": "Welcome to your personalized travel service with Consultant! \ud83c\udf0d\u2708\ufe0f Ready to embark on a journey filled with adventure and relaxation? Let's dive into creating your unforgettable travel experience. ", "is_listed": true, "position": 3, - "privacy_policy": null + "privacy_policy": null, + "custom_disclaimer": null }, { "app": { @@ -200,7 +211,8 @@ "description": "I can answer your questions related to strategic marketing.", "is_listed": true, "position": 10, - "privacy_policy": "https://dify.ai" + "privacy_policy": "https://dify.ai", + "custom_disclaimer": null }, { "app": { @@ -216,7 +228,8 @@ "description": "A simulated front-end interviewer that tests the skill level of front-end development through questioning.", "is_listed": true, "position": 19, - "privacy_policy": "https://dify.ai" + "privacy_policy": "https://dify.ai", + "custom_disclaimer": null }, { "app": { @@ -232,7 +245,8 @@ "description": "I'm here to hear about your feature request about Dify and help you flesh it out further. What's on your mind?", "is_listed": true, "position": 6, - "privacy_policy": null + "privacy_policy": null, + "custom_disclaimer": null } ] }, @@ -261,7 +275,8 @@ "description": "\u4e00\u4e2a\u6a21\u62df\u7684\u524d\u7aef\u9762\u8bd5\u5b98\uff0c\u901a\u8fc7\u63d0\u95ee\u7684\u65b9\u5f0f\u5bf9\u524d\u7aef\u5f00\u53d1\u7684\u6280\u80fd\u6c34\u5e73\u8fdb\u884c\u68c0\u9a8c\u3002", "is_listed": true, "position": 20, - "privacy_policy": null + "privacy_policy": null, + "custom_disclaimer": null }, { "app": { @@ -277,7 +292,8 @@ "description": "\u8f93\u5165\u76f8\u5173\u5143\u7d20\uff0c\u4e3a\u4f60\u751f\u6210\u6241\u5e73\u63d2\u753b\u98ce\u683c\u7684\u5c01\u9762\u56fe\u7247", "is_listed": true, "position": 10, - "privacy_policy": null + "privacy_policy": null, + "custom_disclaimer": null }, { "app": { @@ -293,7 +309,8 @@ "description": "\u4e00\u4e2a\u591a\u8bed\u8a00\u7ffb\u8bd1\u5668\uff0c\u63d0\u4f9b\u591a\u79cd\u8bed\u8a00\u7ffb\u8bd1\u80fd\u529b\uff0c\u8f93\u5165\u4f60\u9700\u8981\u7ffb\u8bd1\u7684\u6587\u672c\uff0c\u9009\u62e9\u76ee\u6807\u8bed\u8a00\u5373\u53ef\u3002\u63d0\u793a\u8bcd\u6765\u81ea\u5b9d\u7389\u3002", "is_listed": true, "position": 10, - "privacy_policy": null + "privacy_policy": null, + "custom_disclaimer": null }, { "app": { @@ -309,7 +326,8 @@ "description": "\u6211\u5c06\u5e2e\u52a9\u4f60\u628a\u81ea\u7136\u8bed\u8a00\u8f6c\u5316\u6210\u6307\u5b9a\u7684\u6570\u636e\u5e93\u67e5\u8be2 SQL \u8bed\u53e5\uff0c\u8bf7\u5728\u4e0b\u65b9\u8f93\u5165\u4f60\u9700\u8981\u67e5\u8be2\u7684\u6761\u4ef6\uff0c\u5e76\u9009\u62e9\u76ee\u6807\u6570\u636e\u5e93\u7c7b\u578b\u3002", "is_listed": true, "position": 12, - "privacy_policy": null + "privacy_policy": null, + "custom_disclaimer": null }, { "app": { @@ -325,7 +343,8 @@ "description": "\u9610\u660e\u4ee3\u7801\u7684\u8bed\u6cd5\u548c\u8bed\u4e49\u3002", "is_listed": true, "position": 2, - "privacy_policy": "https://dify.ai" + "privacy_policy": "https://dify.ai", + "custom_disclaimer": null }, { "app": { @@ 
-341,7 +360,8 @@ "description": "\u8f93\u5165\u76f8\u5173\u5143\u7d20\uff0c\u4e3a\u4f60\u751f\u6210\u8d5b\u535a\u670b\u514b\u98ce\u683c\u7684\u63d2\u753b", "is_listed": true, "position": 10, - "privacy_policy": null + "privacy_policy": null, + "custom_disclaimer": null }, { "app": { @@ -357,7 +377,8 @@ "description": "\u6211\u662f\u4e00\u540dSEO\u4e13\u5bb6\uff0c\u53ef\u4ee5\u6839\u636e\u60a8\u63d0\u4f9b\u7684\u6807\u9898\u3001\u5173\u952e\u8bcd\u3001\u76f8\u5173\u4fe1\u606f\u6765\u6279\u91cf\u751f\u6210SEO\u6587\u7ae0\u3002", "is_listed": true, "position": 10, - "privacy_policy": null + "privacy_policy": null, + "custom_disclaimer": null }, { "app": { @@ -373,7 +394,8 @@ "description": "\u5e2e\u4f60\u91cd\u65b0\u7ec4\u7ec7\u548c\u8f93\u51fa\u6df7\u4e71\u590d\u6742\u7684\u4f1a\u8bae\u7eaa\u8981\u3002", "is_listed": true, "position": 6, - "privacy_policy": "https://dify.ai" + "privacy_policy": "https://dify.ai", + "custom_disclaimer": null }, { "app": { @@ -389,7 +411,8 @@ "description": "\u6b22\u8fce\u4f7f\u7528\u60a8\u7684\u4e2a\u6027\u5316\u7f8e\u80a1\u6295\u8d44\u5206\u6790\u52a9\u624b\uff0c\u5728\u8fd9\u91cc\u6211\u4eec\u6df1\u5165\u7684\u8fdb\u884c\u80a1\u7968\u5206\u6790\uff0c\u4e3a\u60a8\u63d0\u4f9b\u5168\u9762\u7684\u6d1e\u5bdf\u3002", "is_listed": true, "position": 0, - "privacy_policy": null + "privacy_policy": null, + "custom_disclaimer": null }, { "app": { @@ -405,7 +428,8 @@ "description": "\u60a8\u597d\uff0c\u6211\u662f\u60a8\u7684\u521b\u610f\u4f19\u4f34\uff0c\u5c06\u5e2e\u52a9\u60a8\u5c06\u60f3\u6cd5\u751f\u52a8\u5730\u5b9e\u73b0\uff01\u6211\u53ef\u4ee5\u534f\u52a9\u60a8\u5229\u7528DALL\u00b7E 3\u7684\u80fd\u529b\u521b\u9020\u51fa\u4ee4\u4eba\u60ca\u53f9\u7684\u8bbe\u8ba1\u3002", "is_listed": true, "position": 4, - "privacy_policy": null + "privacy_policy": null, + "custom_disclaimer": null }, { "app": { @@ -421,7 +445,8 @@ "description": "\u7ffb\u8bd1\u4e13\u5bb6\uff1a\u63d0\u4f9b\u4e2d\u82f1\u6587\u4e92\u8bd1", "is_listed": true, "position": 4, - "privacy_policy": "https://dify.ai" + "privacy_policy": "https://dify.ai", + "custom_disclaimer": null }, { "app": { @@ -437,7 +462,8 @@ "description": "\u60a8\u7684\u79c1\u4eba\u5b66\u4e60\u5bfc\u5e08\uff0c\u5e2e\u60a8\u5236\u5b9a\u5b66\u4e60\u8ba1\u5212\u5e76\u8f85\u5bfc", "is_listed": true, "position": 26, - "privacy_policy": "https://dify.ai" + "privacy_policy": "https://dify.ai", + "custom_disclaimer": null }, { "app": { @@ -453,7 +479,8 @@ "description": "\u5e2e\u4f60\u64b0\u5199\u8bba\u6587\u6587\u732e\u7efc\u8ff0", "is_listed": true, "position": 7, - "privacy_policy": "https://dify.ai" + "privacy_policy": "https://dify.ai", + "custom_disclaimer": null }, { "app": { @@ -469,7 +496,8 @@ "description": "\u4f60\u597d\uff0c\u544a\u8bc9\u6211\u60a8\u60f3\u5206\u6790\u7684 YouTube \u9891\u9053\uff0c\u6211\u5c06\u4e3a\u60a8\u6574\u7406\u4e00\u4efd\u5b8c\u6574\u7684\u6570\u636e\u5206\u6790\u62a5\u544a\u3002", "is_listed": true, "position": 0, - "privacy_policy": null + "privacy_policy": null, + "custom_disclaimer": null }, { "app": { @@ -485,7 +513,8 @@ "description": "\u6b22\u8fce\u4f7f\u7528\u60a8\u7684\u4e2a\u6027\u5316\u65c5\u884c\u670d\u52a1\u987e\u95ee\uff01\ud83c\udf0d\u2708\ufe0f \u51c6\u5907\u597d\u8e0f\u4e0a\u4e00\u6bb5\u5145\u6ee1\u5192\u9669\u4e0e\u653e\u677e\u7684\u65c5\u7a0b\u4e86\u5417\uff1f\u8ba9\u6211\u4eec\u4e00\u8d77\u6df1\u5165\u6253\u9020\u60a8\u96be\u5fd8\u7684\u65c5\u884c\u4f53\u9a8c\u5427\u3002", "is_listed": true, "position": 0, - "privacy_policy": null + "privacy_policy": null, + "custom_disclaimer": null 
} ] }, diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index 498557cd51..306b7384cf 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -37,9 +37,6 @@ from .billing import billing # Import datasets controllers from .datasets import data_source, datasets, datasets_document, datasets_segments, file, hit_testing -# Import enterprise controllers -from .enterprise import enterprise_sso - # Import explore controllers from .explore import ( audio, @@ -57,4 +54,4 @@ from .explore import ( from .tag import tags # Import workspace controllers -from .workspace import account, members, model_providers, models, tool_providers, workspace +from .workspace import account, load_balancing_config, members, model_providers, models, tool_providers, workspace diff --git a/api/controllers/console/admin.py b/api/controllers/console/admin.py index aaa737f83a..028be5de54 100644 --- a/api/controllers/console/admin.py +++ b/api/controllers/console/admin.py @@ -48,6 +48,7 @@ class InsertExploreAppListApi(Resource): parser.add_argument('desc', type=str, location='json') parser.add_argument('copyright', type=str, location='json') parser.add_argument('privacy_policy', type=str, location='json') + parser.add_argument('custom_disclaimer', type=str, location='json') parser.add_argument('language', type=supported_language, required=True, nullable=False, location='json') parser.add_argument('category', type=str, required=True, nullable=False, location='json') parser.add_argument('position', type=int, required=True, nullable=False, location='json') @@ -62,6 +63,7 @@ class InsertExploreAppListApi(Resource): desc = args['desc'] if args['desc'] else '' copy_right = args['copyright'] if args['copyright'] else '' privacy_policy = args['privacy_policy'] if args['privacy_policy'] else '' + custom_disclaimer = args['custom_disclaimer'] if args['custom_disclaimer'] else '' else: desc = site.description if site.description else \ args['desc'] if args['desc'] else '' @@ -69,6 +71,8 @@ class InsertExploreAppListApi(Resource): args['copyright'] if args['copyright'] else '' privacy_policy = site.privacy_policy if site.privacy_policy else \ args['privacy_policy'] if args['privacy_policy'] else '' + custom_disclaimer = site.custom_disclaimer if site.custom_disclaimer else \ + args['custom_disclaimer'] if args['custom_disclaimer'] else '' recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args['app_id']).first() @@ -78,6 +82,7 @@ class InsertExploreAppListApi(Resource): description=desc, copyright=copy_right, privacy_policy=privacy_policy, + custom_disclaimer=custom_disclaimer, language=args['language'], category=args['category'], position=args['position'] @@ -93,6 +98,7 @@ class InsertExploreAppListApi(Resource): recommended_app.description = desc recommended_app.copyright = copy_right recommended_app.privacy_policy = privacy_policy + recommended_app.custom_disclaimer = custom_disclaimer recommended_app.language = args['language'] recommended_app.category = args['category'] recommended_app.position = args['position'] diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py index 29d89ae460..51322c92d3 100644 --- a/api/controllers/console/app/audio.py +++ b/api/controllers/console/app/audio.py @@ -85,7 +85,7 @@ class ChatMessageTextApi(Resource): response = AudioService.transcript_tts( app_model=app_model, text=request.form['text'], - voice=request.form.get('voice'), + voice=request.form['voice'], streaming=False 
) diff --git a/api/controllers/console/app/error.py b/api/controllers/console/app/error.py index b1abb38248..fbe42fbd2a 100644 --- a/api/controllers/console/app/error.py +++ b/api/controllers/console/app/error.py @@ -91,3 +91,9 @@ class DraftWorkflowNotExist(BaseHTTPException): error_code = 'draft_workflow_not_exist' description = "Draft workflow need to be initialized." code = 400 + + +class DraftWorkflowNotSync(BaseHTTPException): + error_code = 'draft_workflow_not_sync' + description = "Workflow graph might have been modified, please refresh and resubmit." + code = 400 diff --git a/api/controllers/console/app/site.py b/api/controllers/console/app/site.py index 256824981e..592009fd88 100644 --- a/api/controllers/console/app/site.py +++ b/api/controllers/console/app/site.py @@ -23,6 +23,7 @@ def parse_app_site_args(): parser.add_argument('customize_domain', type=str, required=False, location='json') parser.add_argument('copyright', type=str, required=False, location='json') parser.add_argument('privacy_policy', type=str, required=False, location='json') + parser.add_argument('custom_disclaimer', type=str, required=False, location='json') parser.add_argument('customize_token_strategy', type=str, choices=['must', 'allow', 'not_allow'], required=False, location='json') @@ -56,6 +57,7 @@ class AppSite(Resource): 'customize_domain', 'copyright', 'privacy_policy', + 'custom_disclaimer', 'customize_token_strategy', 'prompt_public' ]: diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index b88a9b7fcc..641997f3f3 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -7,7 +7,7 @@ from werkzeug.exceptions import InternalServerError, NotFound import services from controllers.console import api -from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist +from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync from controllers.console.app.wraps import get_app_model from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required @@ -20,6 +20,7 @@ from libs.helper import TimestampField, uuid_value from libs.login import current_user, login_required from models.model import App, AppMode from services.app_generate_service import AppGenerateService +from services.errors.app import WorkflowHashNotEqualError from services.workflow_service import WorkflowService logger = logging.getLogger(__name__) @@ -59,6 +60,7 @@ class DraftWorkflowApi(Resource): parser = reqparse.RequestParser() parser.add_argument('graph', type=dict, required=True, nullable=False, location='json') parser.add_argument('features', type=dict, required=True, nullable=False, location='json') + parser.add_argument('hash', type=str, required=False, location='json') args = parser.parse_args() elif 'text/plain' in content_type: try: @@ -71,7 +73,8 @@ class DraftWorkflowApi(Resource): args = { 'graph': data.get('graph'), - 'features': data.get('features') + 'features': data.get('features'), + 'hash': data.get('hash') } except json.JSONDecodeError: return {'message': 'Invalid JSON data'}, 400 @@ -79,15 +82,21 @@ class DraftWorkflowApi(Resource): abort(415) workflow_service = WorkflowService() - workflow = workflow_service.sync_draft_workflow( - app_model=app_model, - graph=args.get('graph'), - features=args.get('features'), - account=current_user - ) + + try: + workflow = workflow_service.sync_draft_workflow( + 
app_model=app_model, + graph=args.get('graph'), + features=args.get('features'), + unique_hash=args.get('hash'), + account=current_user + ) + except WorkflowHashNotEqualError: + raise DraftWorkflowNotSync() return { "result": "success", + "hash": workflow.unique_hash, "updated_at": TimestampField().format(workflow.updated_at or workflow.created_at) } @@ -128,6 +137,71 @@ class AdvancedChatDraftWorkflowRunApi(Resource): logging.exception("internal server error.") raise InternalServerError() +class AdvancedChatDraftRunIterationNodeApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.ADVANCED_CHAT]) + def post(self, app_model: App, node_id: str): + """ + Run draft workflow iteration node + """ + parser = reqparse.RequestParser() + parser.add_argument('inputs', type=dict, location='json') + args = parser.parse_args() + + try: + response = AppGenerateService.generate_single_iteration( + app_model=app_model, + user=current_user, + node_id=node_id, + args=args, + streaming=True + ) + + return helper.compact_generate_response(response) + except services.errors.conversation.ConversationNotExistsError: + raise NotFound("Conversation Not Exists.") + except services.errors.conversation.ConversationCompletedError: + raise ConversationCompletedError() + except ValueError as e: + raise e + except Exception as e: + logging.exception("internal server error.") + raise InternalServerError() + +class WorkflowDraftRunIterationNodeApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.WORKFLOW]) + def post(self, app_model: App, node_id: str): + """ + Run draft workflow iteration node + """ + parser = reqparse.RequestParser() + parser.add_argument('inputs', type=dict, location='json') + args = parser.parse_args() + + try: + response = AppGenerateService.generate_single_iteration( + app_model=app_model, + user=current_user, + node_id=node_id, + args=args, + streaming=True + ) + + return helper.compact_generate_response(response) + except services.errors.conversation.ConversationNotExistsError: + raise NotFound("Conversation Not Exists.") + except services.errors.conversation.ConversationCompletedError: + raise ConversationCompletedError() + except ValueError as e: + raise e + except Exception as e: + logging.exception("internal server error.") + raise InternalServerError() class DraftWorkflowRunApi(Resource): @setup_required @@ -317,6 +391,8 @@ api.add_resource(AdvancedChatDraftWorkflowRunApi, '/apps/<uuid:app_id>/advanced-chat/workflows/draft/run') api.add_resource(DraftWorkflowRunApi, '/apps/<uuid:app_id>/workflows/draft/run') api.add_resource(WorkflowTaskStopApi, '/apps/<uuid:app_id>/workflow-runs/tasks/<string:task_id>/stop') api.add_resource(DraftWorkflowNodeRunApi, '/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/run') +api.add_resource(AdvancedChatDraftRunIterationNodeApi, '/apps/<uuid:app_id>/advanced-chat/workflows/draft/iteration/nodes/<string:node_id>/run') +api.add_resource(WorkflowDraftRunIterationNodeApi, '/apps/<uuid:app_id>/workflows/draft/iteration/nodes/<string:node_id>/run') api.add_resource(PublishedWorkflowApi, '/apps/<uuid:app_id>/workflows/publish') api.add_resource(DefaultBlockConfigsApi, '/apps/<uuid:app_id>/workflows/default-workflow-block-configs') api.add_resource(DefaultBlockConfigApi, '/apps/<uuid:app_id>/workflows/default-workflow-block-configs/<string:block_type>') diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 40ded54120..72c4c09055 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -476,13 +476,13 @@ class DatasetRetrievalSettingApi(Resource):
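The `hash` round-trip above is an optimistic lock on the draft: the client echoes back the `unique_hash` it last received, and the save is rejected with `DraftWorkflowNotSync` when the stored draft has moved on. A minimal sketch of that handshake (illustrative names only; not the actual `WorkflowService` code):

```python
import hashlib
import json


class WorkflowHashNotEqualError(Exception):
    """Raised when the client's hash no longer matches the stored draft."""


def compute_unique_hash(graph: dict, features: dict) -> str:
    # Canonical JSON (sorted keys) so logically equal drafts hash identically.
    payload = json.dumps({'graph': graph, 'features': features}, sort_keys=True)
    return hashlib.sha256(payload.encode('utf-8')).hexdigest()


def sync_draft(stored: dict, graph: dict, features: dict,
               client_hash: str | None) -> dict:
    # If the client supplied the hash it last saw, reject the save when the
    # draft has since been changed by someone else.
    if client_hash and client_hash != stored['unique_hash']:
        raise WorkflowHashNotEqualError()

    stored.update(graph=graph, features=features,
                  unique_hash=compute_unique_hash(graph, features))
    return stored


draft = {'graph': {}, 'features': {}, 'unique_hash': compute_unique_hash({}, {})}
saved = sync_draft(draft, {'nodes': [{'id': 'start'}]}, {},
                   client_hash=draft['unique_hash'])
print(saved['unique_hash'])  # new hash the client must echo on its next save
```

On a 400 `draft_workflow_not_sync`, the frontend can refresh the draft, reapply its edits, and resubmit with the fresh hash, which is exactly what the error description asks the user to do.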
@account_initialization_required def get(self): vector_type = current_app.config['VECTOR_STORE'] - if vector_type == 'milvus' or vector_type == 'pgvecto_rs' or vector_type == 'relyt': + if vector_type in {"milvus", "relyt", "pgvector", "pgvecto_rs", 'tidb_vector'}: return { 'retrieval_method': [ 'semantic_search' ] } - elif vector_type == 'qdrant' or vector_type == 'weaviate': + elif vector_type in {"qdrant", "weaviate"}: return { 'retrieval_method': [ 'semantic_search', 'full_text_search', 'hybrid_search' @@ -497,14 +497,13 @@ class DatasetRetrievalSettingMockApi(Resource): @login_required @account_initialization_required def get(self, vector_type): - - if vector_type == 'milvus' or vector_type == 'relyt': + if vector_type in {'milvus', 'relyt', 'pgvector', 'tidb_vector'}: return { 'retrieval_method': [ 'semantic_search' ] } - elif vector_type == 'qdrant' or vector_type == 'weaviate': + elif vector_type in {'qdrant', 'weaviate'}: return { 'retrieval_method': [ 'semantic_search', 'full_text_search', 'hybrid_search' diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index 9dedcefe0f..5498d22e78 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -1,10 +1,12 @@ import logging +from argparse import ArgumentTypeError from datetime import datetime, timezone from flask import request from flask_login import current_user from flask_restful import Resource, fields, marshal, marshal_with, reqparse from sqlalchemy import asc, desc +from transformers.hf_argparser import string_to_bool from werkzeug.exceptions import Forbidden, NotFound import services @@ -141,7 +143,11 @@ class DatasetDocumentListApi(Resource): limit = request.args.get('limit', default=20, type=int) search = request.args.get('keyword', default=None, type=str) sort = request.args.get('sort', default='-created_at', type=str) - fetch = request.args.get('fetch', default=False, type=bool) + # "yes", "true", "t", "y" and "1" are converted to True; any other value is converted to False.
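The replaced `type=bool` parsing above was a footgun: `bool('false')` is `True`, because any non-empty string is truthy. The commit switches to `string_to_bool` from `transformers.hf_argparser`, whose behavior is roughly the sketch below (an approximation for illustration); combined with the `except` fallback in the controller, every unrecognized value ends up `False`:

```python
from argparse import ArgumentTypeError


def string_to_bool(value) -> bool:
    # Query-string flags arrive as strings, so they need explicit parsing.
    if isinstance(value, bool):
        return value
    if value.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    if value.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    raise ArgumentTypeError(f'Truthy value expected, got {value!r}.')


assert string_to_bool('true') and string_to_bool('1')
assert not string_to_bool('false') and not string_to_bool('0')
# 'maybe' raises ArgumentTypeError; the controller catches it and uses False.
```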
+ try: + fetch = string_to_bool(request.args.get('fetch', default='false')) + except (ArgumentTypeError, ValueError, Exception) as e: + fetch = False dataset = DatasetService.get_dataset(dataset_id) if not dataset: raise NotFound('Dataset not found.') @@ -924,6 +930,28 @@ class DocumentRetryApi(DocumentResource): return {'result': 'success'}, 204 +class DocumentRenameApi(DocumentResource): + @setup_required + @login_required + @account_initialization_required + @marshal_with(document_fields) + def post(self, dataset_id, document_id): + # The role of the current user in the ta table must be admin or owner + if not current_user.is_admin_or_owner: + raise Forbidden() + + parser = reqparse.RequestParser() + parser.add_argument('name', type=str, required=True, nullable=False, location='json') + args = parser.parse_args() + + try: + document = DocumentService.rename_document(dataset_id, document_id, args['name']) + except services.errors.document.DocumentIndexingError: + raise DocumentIndexingError('Cannot rename document during indexing.') + + return document + + api.add_resource(GetProcessRuleApi, '/datasets/process-rule') api.add_resource(DatasetDocumentListApi, '/datasets/<uuid:dataset_id>/documents') @@ -950,3 +978,5 @@ api.add_resource(DocumentStatusApi, api.add_resource(DocumentPauseApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/pause') api.add_resource(DocumentRecoverApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/resume') api.add_resource(DocumentRetryApi, '/datasets/<uuid:dataset_id>/retry') +api.add_resource(DocumentRenameApi, + '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/rename') diff --git a/api/controllers/console/enterprise/enterprise_sso.py b/api/controllers/console/enterprise/enterprise_sso.py deleted file mode 100644 index f6a2897d5a..0000000000 --- a/api/controllers/console/enterprise/enterprise_sso.py +++ /dev/null @@ -1,59 +0,0 @@ -from flask import current_app, redirect -from flask_restful import Resource, reqparse - -from controllers.console import api -from controllers.console.setup import setup_required -from services.enterprise.enterprise_sso_service import EnterpriseSSOService - - -class EnterpriseSSOSamlLogin(Resource): - - @setup_required - def get(self): - return EnterpriseSSOService.get_sso_saml_login() - - -class EnterpriseSSOSamlAcs(Resource): - - @setup_required - def post(self): - parser = reqparse.RequestParser() - parser.add_argument('SAMLResponse', type=str, required=True, location='form') - args = parser.parse_args() - saml_response = args['SAMLResponse'] - - try: - token = EnterpriseSSOService.post_sso_saml_acs(saml_response) - return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}/signin?console_token={token}') - except Exception as e: - return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}/signin?message={str(e)}') - - -class EnterpriseSSOOidcLogin(Resource): - - @setup_required - def get(self): - return EnterpriseSSOService.get_sso_oidc_login() - - -class EnterpriseSSOOidcCallback(Resource): - - @setup_required - def get(self): - parser = reqparse.RequestParser() - parser.add_argument('state', type=str, required=True, location='args') - parser.add_argument('code', type=str, required=True, location='args') - parser.add_argument('oidc-state', type=str, required=True, location='cookies') - args = parser.parse_args() - - try: - token = EnterpriseSSOService.get_sso_oidc_callback(args) - return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}/signin?console_token={token}') - except Exception as e: - return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}/signin?message={str(e)}') - -
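With `DocumentRenameApi` registered above, the rename endpoint can be exercised like any other console API. A hypothetical client call follows (the base URL, token, and UUIDs are placeholders; the caller must be a workspace admin or owner):

```python
import requests

# Hypothetical values: substitute a real console session token and real UUIDs.
base_url = 'http://localhost:5001/console/api'
dataset_id = '00000000-0000-0000-0000-000000000000'
document_id = '00000000-0000-0000-0000-000000000001'

resp = requests.post(
    f'{base_url}/datasets/{dataset_id}/documents/{document_id}/rename',
    headers={'Authorization': 'Bearer YOUR_CONSOLE_TOKEN'},
    json={'name': 'renamed-document.pdf'},  # 'name' is required and non-nullable
)
resp.raise_for_status()  # 403 Forbidden unless the caller is admin or owner
print(resp.json())       # marshalled document fields on success
```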
-api.add_resource(EnterpriseSSOSamlLogin, '/enterprise/sso/saml/login') -api.add_resource(EnterpriseSSOSamlAcs, '/enterprise/sso/saml/acs') -api.add_resource(EnterpriseSSOOidcLogin, '/enterprise/sso/oidc/login') -api.add_resource(EnterpriseSSOOidcCallback, '/enterprise/sso/oidc/callback') diff --git a/api/controllers/console/explore/audio.py b/api/controllers/console/explore/audio.py index f03663f1a2..d869cd38ed 100644 --- a/api/controllers/console/explore/audio.py +++ b/api/controllers/console/explore/audio.py @@ -76,7 +76,7 @@ class ChatTextApi(InstalledAppResource): response = AudioService.transcript_tts( app_model=app_model, text=request.form['text'], - voice=request.form.get('voice'), + voice=request.form['voice'] if request.form.get('voice') else app_model.app_model_config.text_to_speech_dict.get('voice'), streaming=False ) return {'data': response.data.decode('latin1')} diff --git a/api/controllers/console/explore/recommended_app.py b/api/controllers/console/explore/recommended_app.py index 2787b7cdba..6e10e2ec92 100644 --- a/api/controllers/console/explore/recommended_app.py +++ b/api/controllers/console/explore/recommended_app.py @@ -21,6 +21,7 @@ recommended_app_fields = { 'description': fields.String(attribute='description'), 'copyright': fields.String, 'privacy_policy': fields.String, + 'custom_disclaimer': fields.String, 'category': fields.String, 'position': fields.Integer, 'is_listed': fields.Boolean diff --git a/api/controllers/console/feature.py b/api/controllers/console/feature.py index 325652a447..44d9d67522 100644 --- a/api/controllers/console/feature.py +++ b/api/controllers/console/feature.py @@ -1,24 +1,28 @@ from flask_login import current_user from flask_restful import Resource -from services.enterprise.enterprise_feature_service import EnterpriseFeatureService +from libs.login import login_required from services.feature_service import FeatureService from . 
import api -from .wraps import cloud_utm_record +from .setup import setup_required +from .wraps import account_initialization_required, cloud_utm_record class FeatureApi(Resource): + @setup_required + @login_required + @account_initialization_required @cloud_utm_record def get(self): return FeatureService.get_features(current_user.current_tenant_id).dict() -class EnterpriseFeatureApi(Resource): +class SystemFeatureApi(Resource): def get(self): - return EnterpriseFeatureService.get_enterprise_features().dict() + return FeatureService.get_system_features().dict() api.add_resource(FeatureApi, '/features') -api.add_resource(EnterpriseFeatureApi, '/enterprise-features') +api.add_resource(SystemFeatureApi, '/system-features') diff --git a/api/controllers/console/version.py b/api/controllers/console/version.py index a50e4c41a8..faf36c4f40 100644 --- a/api/controllers/console/version.py +++ b/api/controllers/console/version.py @@ -17,13 +17,19 @@ class VersionApi(Resource): args = parser.parse_args() check_update_url = current_app.config['CHECK_UPDATE_URL'] - if not check_update_url: - return { - 'version': '0.0.0', - 'release_date': '', - 'release_notes': '', - 'can_auto_update': False + result = { + 'version': current_app.config['CURRENT_VERSION'], + 'release_date': '', + 'release_notes': '', + 'can_auto_update': False, + 'features': { + 'can_replace_logo': current_app.config['CAN_REPLACE_LOGO'], + 'model_load_balancing_enabled': current_app.config['MODEL_LB_ENABLED'] } + } + + if not check_update_url: + return result try: response = requests.get(check_update_url, { @@ -31,20 +37,15 @@ class VersionApi(Resource): }) except Exception as error: logging.warning("Check update version error: {}.".format(str(error))) - return { - 'version': args.get('current_version'), - 'release_date': '', - 'release_notes': '', - 'can_auto_update': False - } + result['version'] = args.get('current_version') + return result content = json.loads(response.content) - return { - 'version': content['version'], - 'release_date': content['releaseDate'], - 'release_notes': content['releaseNotes'], - 'can_auto_update': content['canAutoUpdate'] - } + result['version'] = content['version'] + result['release_date'] = content['releaseDate'] + result['release_notes'] = content['releaseNotes'] + result['can_auto_update'] = content['canAutoUpdate'] + return result api.add_resource(VersionApi, '/version') diff --git a/api/controllers/console/workspace/load_balancing_config.py b/api/controllers/console/workspace/load_balancing_config.py new file mode 100644 index 0000000000..50514e39f6 --- /dev/null +++ b/api/controllers/console/workspace/load_balancing_config.py @@ -0,0 +1,106 @@ +from flask_restful import Resource, reqparse +from werkzeug.exceptions import Forbidden + +from controllers.console import api +from controllers.console.setup import setup_required +from controllers.console.wraps import account_initialization_required +from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from libs.login import current_user, login_required +from models.account import TenantAccountRole +from services.model_load_balancing_service import ModelLoadBalancingService + + +class LoadBalancingCredentialsValidateApi(Resource): + @setup_required + @login_required + @account_initialization_required + def post(self, provider: str): + if not TenantAccountRole.is_privileged_role(current_user.current_tenant.current_role): + raise Forbidden() + + tenant_id = 
current_user.current_tenant_id + + parser = reqparse.RequestParser() + parser.add_argument('model', type=str, required=True, nullable=False, location='json') + parser.add_argument('model_type', type=str, required=True, nullable=False, + choices=[mt.value for mt in ModelType], location='json') + parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json') + args = parser.parse_args() + + # validate model load balancing credentials + model_load_balancing_service = ModelLoadBalancingService() + + result = True + error = None + + try: + model_load_balancing_service.validate_load_balancing_credentials( + tenant_id=tenant_id, + provider=provider, + model=args['model'], + model_type=args['model_type'], + credentials=args['credentials'] + ) + except CredentialsValidateFailedError as ex: + result = False + error = str(ex) + + response = {'result': 'success' if result else 'error'} + + if not result: + response['error'] = error + + return response + + +class LoadBalancingConfigCredentialsValidateApi(Resource): + @setup_required + @login_required + @account_initialization_required + def post(self, provider: str, config_id: str): + if not TenantAccountRole.is_privileged_role(current_user.current_tenant.current_role): + raise Forbidden() + + tenant_id = current_user.current_tenant_id + + parser = reqparse.RequestParser() + parser.add_argument('model', type=str, required=True, nullable=False, location='json') + parser.add_argument('model_type', type=str, required=True, nullable=False, + choices=[mt.value for mt in ModelType], location='json') + parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json') + args = parser.parse_args() + + # validate model load balancing config credentials + model_load_balancing_service = ModelLoadBalancingService() + + result = True + error = None + + try: + model_load_balancing_service.validate_load_balancing_credentials( + tenant_id=tenant_id, + provider=provider, + model=args['model'], + model_type=args['model_type'], + credentials=args['credentials'], + config_id=config_id, + ) + except CredentialsValidateFailedError as ex: + result = False + error = str(ex) + + response = {'result': 'success' if result else 'error'} + + if not result: + response['error'] = error + + return response + + +# Load Balancing Config +api.add_resource(LoadBalancingCredentialsValidateApi, + '/workspaces/current/model-providers/<string:provider>/models/load-balancing-configs/credentials-validate') + +api.add_resource(LoadBalancingConfigCredentialsValidateApi, + '/workspaces/current/model-providers/<string:provider>/models/load-balancing-configs/<string:config_id>/credentials-validate') diff --git a/api/controllers/console/workspace/models.py b/api/controllers/console/workspace/models.py index 23239b1902..76ae6a4ab9 100644 --- a/api/controllers/console/workspace/models.py +++ b/api/controllers/console/workspace/models.py @@ -12,6 +12,7 @@ from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.utils.encoders import jsonable_encoder from libs.login import login_required from models.account import TenantAccountRole +from services.model_load_balancing_service import ModelLoadBalancingService from services.model_provider_service import ModelProviderService @@ -104,21 +105,56 @@ class ModelProviderModelApi(Resource): parser.add_argument('model', type=str, required=True, nullable=False, location='json') parser.add_argument('model_type', type=str, required=True, nullable=False, choices=[mt.value for mt in ModelType], location='json') -
parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json') + parser.add_argument('credentials', type=dict, required=False, nullable=True, location='json') + parser.add_argument('load_balancing', type=dict, required=False, nullable=True, location='json') + parser.add_argument('config_from', type=str, required=False, nullable=True, location='json') args = parser.parse_args() - model_provider_service = ModelProviderService() + model_load_balancing_service = ModelLoadBalancingService() - try: - model_provider_service.save_model_credentials( + if ('load_balancing' in args and args['load_balancing'] and + 'enabled' in args['load_balancing'] and args['load_balancing']['enabled']): + if 'configs' not in args['load_balancing']: + raise ValueError('invalid load balancing configs') + + # save load balancing configs + model_load_balancing_service.update_load_balancing_configs( tenant_id=tenant_id, provider=provider, model=args['model'], model_type=args['model_type'], - credentials=args['credentials'] + configs=args['load_balancing']['configs'] ) - except CredentialsValidateFailedError as ex: - raise ValueError(str(ex)) + + # enable load balancing + model_load_balancing_service.enable_model_load_balancing( + tenant_id=tenant_id, + provider=provider, + model=args['model'], + model_type=args['model_type'] + ) + else: + # disable load balancing + model_load_balancing_service.disable_model_load_balancing( + tenant_id=tenant_id, + provider=provider, + model=args['model'], + model_type=args['model_type'] + ) + + if args.get('config_from', '') != 'predefined-model': + model_provider_service = ModelProviderService() + + try: + model_provider_service.save_model_credentials( + tenant_id=tenant_id, + provider=provider, + model=args['model'], + model_type=args['model_type'], + credentials=args['credentials'] + ) + except CredentialsValidateFailedError as ex: + raise ValueError(str(ex)) return {'result': 'success'}, 200 @@ -170,11 +206,73 @@ class ModelProviderModelCredentialApi(Resource): model=args['model'] ) + model_load_balancing_service = ModelLoadBalancingService() + is_load_balancing_enabled, load_balancing_configs = model_load_balancing_service.get_load_balancing_configs( + tenant_id=tenant_id, + provider=provider, + model=args['model'], + model_type=args['model_type'] + ) + return { - "credentials": credentials + "credentials": credentials, + "load_balancing": { + "enabled": is_load_balancing_enabled, + "configs": load_balancing_configs + } } +class ModelProviderModelEnableApi(Resource): + + @setup_required + @login_required + @account_initialization_required + def patch(self, provider: str): + tenant_id = current_user.current_tenant_id + + parser = reqparse.RequestParser() + parser.add_argument('model', type=str, required=True, nullable=False, location='json') + parser.add_argument('model_type', type=str, required=True, nullable=False, + choices=[mt.value for mt in ModelType], location='json') + args = parser.parse_args() + + model_provider_service = ModelProviderService() + model_provider_service.enable_model( + tenant_id=tenant_id, + provider=provider, + model=args['model'], + model_type=args['model_type'] + ) + + return {'result': 'success'} + + +class ModelProviderModelDisableApi(Resource): + + @setup_required + @login_required + @account_initialization_required + def patch(self, provider: str): + tenant_id = current_user.current_tenant_id + + parser = reqparse.RequestParser() + parser.add_argument('model', type=str, required=True, nullable=False, location='json') + 
parser.add_argument('model_type', type=str, required=True, nullable=False, + choices=[mt.value for mt in ModelType], location='json') + args = parser.parse_args() + + model_provider_service = ModelProviderService() + model_provider_service.disable_model( + tenant_id=tenant_id, + provider=provider, + model=args['model'], + model_type=args['model_type'] + ) + + return {'result': 'success'} + + class ModelProviderModelValidateApi(Resource): @setup_required @@ -259,6 +357,10 @@ class ModelProviderAvailableModelApi(Resource): api.add_resource(ModelProviderModelApi, '/workspaces/current/model-providers/<string:provider>/models') +api.add_resource(ModelProviderModelEnableApi, '/workspaces/current/model-providers/<string:provider>/models/enable', + endpoint='model-provider-model-enable') +api.add_resource(ModelProviderModelDisableApi, '/workspaces/current/model-providers/<string:provider>/models/disable', + endpoint='model-provider-model-disable') api.add_resource(ModelProviderModelCredentialApi, '/workspaces/current/model-providers/<string:provider>/models/credentials') api.add_resource(ModelProviderModelValidateApi, diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index b02008339e..a911e9b2cb 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -9,8 +9,13 @@ from controllers.console import api from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required from core.model_runtime.utils.encoders import jsonable_encoder +from libs.helper import alphanumeric, uuid_value from libs.login import login_required -from services.tools_manage_service import ToolManageService +from services.tools.api_tools_manage_service import ApiToolManageService +from services.tools.builtin_tools_manage_service import BuiltinToolManageService +from services.tools.tool_labels_service import ToolLabelsService +from services.tools.tools_manage_service import ToolCommonService +from services.tools.workflow_tools_manage_service import WorkflowToolManageService class ToolProviderListApi(Resource): @@ -21,7 +26,11 @@ user_id = current_user.id tenant_id = current_user.current_tenant_id - return ToolManageService.list_tool_providers(user_id, tenant_id) + req = reqparse.RequestParser() + req.add_argument('type', type=str, choices=['builtin', 'model', 'api', 'workflow'], required=False, nullable=True, location='args') + args = req.parse_args() + + return ToolCommonService.list_tool_providers(user_id, tenant_id, args.get('type', None)) class ToolBuiltinProviderListToolsApi(Resource): @setup_required @@ -31,7 +40,7 @@ user_id = current_user.id tenant_id = current_user.current_tenant_id - return jsonable_encoder(ToolManageService.list_builtin_tool_provider_tools( + return jsonable_encoder(BuiltinToolManageService.list_builtin_tool_provider_tools( user_id, tenant_id, provider, @@ -48,7 +57,7 @@ user_id = current_user.id tenant_id = current_user.current_tenant_id - return ToolManageService.delete_builtin_tool_provider( + return BuiltinToolManageService.delete_builtin_tool_provider( user_id, tenant_id, provider, @@ -70,7 +79,7 @@ args = parser.parse_args() - return ToolManageService.update_builtin_tool_provider( + return BuiltinToolManageService.update_builtin_tool_provider( user_id, tenant_id, provider, @@ -85,7 +94,7 @@ class
ToolBuiltinProviderGetCredentialsApi(Resource): user_id = current_user.id tenant_id = current_user.current_tenant_id - return ToolManageService.get_builtin_tool_provider_credentials( + return BuiltinToolManageService.get_builtin_tool_provider_credentials( user_id, tenant_id, provider, @@ -94,35 +103,10 @@ class ToolBuiltinProviderGetCredentialsApi(Resource): class ToolBuiltinProviderIconApi(Resource): @setup_required def get(self, provider): - icon_bytes, mimetype = ToolManageService.get_builtin_tool_provider_icon(provider) + icon_bytes, mimetype = BuiltinToolManageService.get_builtin_tool_provider_icon(provider) icon_cache_max_age = int(current_app.config.get('TOOL_ICON_CACHE_MAX_AGE')) return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age) -class ToolModelProviderIconApi(Resource): - @setup_required - def get(self, provider): - icon_bytes, mimetype = ToolManageService.get_model_tool_provider_icon(provider) - return send_file(io.BytesIO(icon_bytes), mimetype=mimetype) - -class ToolModelProviderListToolsApi(Resource): - @setup_required - @login_required - @account_initialization_required - def get(self): - user_id = current_user.id - tenant_id = current_user.current_tenant_id - - parser = reqparse.RequestParser() - parser.add_argument('provider', type=str, required=True, nullable=False, location='args') - - args = parser.parse_args() - - return jsonable_encoder(ToolManageService.list_model_tool_provider_tools( - user_id, - tenant_id, - args['provider'], - )) - class ToolApiProviderAddApi(Resource): @setup_required @login_required @@ -141,10 +125,12 @@ class ToolApiProviderAddApi(Resource): parser.add_argument('provider', type=str, required=True, nullable=False, location='json') parser.add_argument('icon', type=dict, required=True, nullable=False, location='json') parser.add_argument('privacy_policy', type=str, required=False, nullable=True, location='json') + parser.add_argument('labels', type=list[str], required=False, nullable=True, location='json', default=[]) + parser.add_argument('custom_disclaimer', type=str, required=False, nullable=True, location='json') args = parser.parse_args() - return ToolManageService.create_api_tool_provider( + return ApiToolManageService.create_api_tool_provider( user_id, tenant_id, args['provider'], @@ -153,6 +139,8 @@ class ToolApiProviderAddApi(Resource): args['schema_type'], args['schema'], args.get('privacy_policy', ''), + args.get('custom_disclaimer', ''), + args.get('labels', []), ) class ToolApiProviderGetRemoteSchemaApi(Resource): @@ -166,7 +154,7 @@ class ToolApiProviderGetRemoteSchemaApi(Resource): args = parser.parse_args() - return ToolManageService.get_api_tool_provider_remote_schema( + return ApiToolManageService.get_api_tool_provider_remote_schema( current_user.id, current_user.current_tenant_id, args['url'], @@ -186,7 +174,7 @@ class ToolApiProviderListToolsApi(Resource): args = parser.parse_args() - return jsonable_encoder(ToolManageService.list_api_tool_provider_tools( + return jsonable_encoder(ApiToolManageService.list_api_tool_provider_tools( user_id, tenant_id, args['provider'], @@ -211,10 +199,12 @@ class ToolApiProviderUpdateApi(Resource): parser.add_argument('original_provider', type=str, required=True, nullable=False, location='json') parser.add_argument('icon', type=dict, required=True, nullable=False, location='json') parser.add_argument('privacy_policy', type=str, required=True, nullable=True, location='json') + parser.add_argument('labels', type=list[str], required=False, nullable=True, 
location='json') + parser.add_argument('custom_disclaimer', type=str, required=True, nullable=True, location='json') args = parser.parse_args() - return ToolManageService.update_api_tool_provider( + return ApiToolManageService.update_api_tool_provider( user_id, tenant_id, args['provider'], @@ -224,6 +214,8 @@ class ToolApiProviderUpdateApi(Resource): args['schema_type'], args['schema'], args['privacy_policy'], + args['custom_disclaimer'], + args.get('labels', []), ) class ToolApiProviderDeleteApi(Resource): @@ -243,7 +235,7 @@ class ToolApiProviderDeleteApi(Resource): args = parser.parse_args() - return ToolManageService.delete_api_tool_provider( + return ApiToolManageService.delete_api_tool_provider( user_id, tenant_id, args['provider'], @@ -263,7 +255,7 @@ class ToolApiProviderGetApi(Resource): args = parser.parse_args() - return ToolManageService.get_api_tool_provider( + return ApiToolManageService.get_api_tool_provider( user_id, tenant_id, args['provider'], @@ -274,7 +266,7 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource): @login_required @account_initialization_required def get(self, provider): - return ToolManageService.list_builtin_provider_credentials_schema(provider) + return BuiltinToolManageService.list_builtin_provider_credentials_schema(provider) class ToolApiProviderSchemaApi(Resource): @setup_required @@ -287,7 +279,7 @@ class ToolApiProviderSchemaApi(Resource): args = parser.parse_args() - return ToolManageService.parser_api_schema( + return ApiToolManageService.parser_api_schema( schema=args['schema'], ) @@ -307,7 +299,7 @@ class ToolApiProviderPreviousTestApi(Resource): args = parser.parse_args() - return ToolManageService.test_api_tool_preview( + return ApiToolManageService.test_api_tool_preview( current_user.current_tenant_id, args['provider_name'] if args['provider_name'] else '', args['tool_name'], @@ -317,6 +309,153 @@ class ToolApiProviderPreviousTestApi(Resource): args['schema'], ) +class ToolWorkflowProviderCreateApi(Resource): + @setup_required + @login_required + @account_initialization_required + def post(self): + if not current_user.is_admin_or_owner: + raise Forbidden() + + user_id = current_user.id + tenant_id = current_user.current_tenant_id + + reqparser = reqparse.RequestParser() + reqparser.add_argument('workflow_app_id', type=uuid_value, required=True, nullable=False, location='json') + reqparser.add_argument('name', type=alphanumeric, required=True, nullable=False, location='json') + reqparser.add_argument('label', type=str, required=True, nullable=False, location='json') + reqparser.add_argument('description', type=str, required=True, nullable=False, location='json') + reqparser.add_argument('icon', type=dict, required=True, nullable=False, location='json') + reqparser.add_argument('parameters', type=list[dict], required=True, nullable=False, location='json') + reqparser.add_argument('privacy_policy', type=str, required=False, nullable=True, location='json', default='') + reqparser.add_argument('labels', type=list[str], required=False, nullable=True, location='json') + + args = reqparser.parse_args() + + return WorkflowToolManageService.create_workflow_tool( + user_id, + tenant_id, + args['workflow_app_id'], + args['name'], + args['label'], + args['icon'], + args['description'], + args['parameters'], + args['privacy_policy'], + args.get('labels', []), + ) + +class ToolWorkflowProviderUpdateApi(Resource): + @setup_required + @login_required + @account_initialization_required + def post(self): + if not current_user.is_admin_or_owner: + raise 
Forbidden() + + user_id = current_user.id + tenant_id = current_user.current_tenant_id + + reqparser = reqparse.RequestParser() + reqparser.add_argument('workflow_tool_id', type=uuid_value, required=True, nullable=False, location='json') + reqparser.add_argument('name', type=alphanumeric, required=True, nullable=False, location='json') + reqparser.add_argument('label', type=str, required=True, nullable=False, location='json') + reqparser.add_argument('description', type=str, required=True, nullable=False, location='json') + reqparser.add_argument('icon', type=dict, required=True, nullable=False, location='json') + reqparser.add_argument('parameters', type=list[dict], required=True, nullable=False, location='json') + reqparser.add_argument('privacy_policy', type=str, required=False, nullable=True, location='json', default='') + reqparser.add_argument('labels', type=list[str], required=False, nullable=True, location='json') + + args = reqparser.parse_args() + + if not args['workflow_tool_id']: + raise ValueError('incorrect workflow_tool_id') + + return WorkflowToolManageService.update_workflow_tool( + user_id, + tenant_id, + args['workflow_tool_id'], + args['name'], + args['label'], + args['icon'], + args['description'], + args['parameters'], + args['privacy_policy'], + args.get('labels', []), + ) + +class ToolWorkflowProviderDeleteApi(Resource): + @setup_required + @login_required + @account_initialization_required + def post(self): + if not current_user.is_admin_or_owner: + raise Forbidden() + + user_id = current_user.id + tenant_id = current_user.current_tenant_id + + reqparser = reqparse.RequestParser() + reqparser.add_argument('workflow_tool_id', type=uuid_value, required=True, nullable=False, location='json') + + args = reqparser.parse_args() + + return WorkflowToolManageService.delete_workflow_tool( + user_id, + tenant_id, + args['workflow_tool_id'], + ) + +class ToolWorkflowProviderGetApi(Resource): + @setup_required + @login_required + @account_initialization_required + def get(self): + user_id = current_user.id + tenant_id = current_user.current_tenant_id + + parser = reqparse.RequestParser() + parser.add_argument('workflow_tool_id', type=uuid_value, required=False, nullable=True, location='args') + parser.add_argument('workflow_app_id', type=uuid_value, required=False, nullable=True, location='args') + + args = parser.parse_args() + + if args.get('workflow_tool_id'): + tool = WorkflowToolManageService.get_workflow_tool_by_tool_id( + user_id, + tenant_id, + args['workflow_tool_id'], + ) + elif args.get('workflow_app_id'): + tool = WorkflowToolManageService.get_workflow_tool_by_app_id( + user_id, + tenant_id, + args['workflow_app_id'], + ) + else: + raise ValueError('incorrect workflow_tool_id or workflow_app_id') + + return jsonable_encoder(tool) + +class ToolWorkflowProviderListToolApi(Resource): + @setup_required + @login_required + @account_initialization_required + def get(self): + user_id = current_user.id + tenant_id = current_user.current_tenant_id + + parser = reqparse.RequestParser() + parser.add_argument('workflow_tool_id', type=uuid_value, required=True, nullable=False, location='args') + + args = parser.parse_args() + + return jsonable_encoder(WorkflowToolManageService.list_single_workflow_tools( + user_id, + tenant_id, + args['workflow_tool_id'], + )) + class ToolBuiltinListApi(Resource): @setup_required @login_required @@ -325,7 +464,7 @@ class ToolBuiltinListApi(Resource): user_id = current_user.id tenant_id = current_user.current_tenant_id - return 
jsonable_encoder([provider.to_dict() for provider in ToolManageService.list_builtin_tools( + return jsonable_encoder([provider.to_dict() for provider in BuiltinToolManageService.list_builtin_tools( user_id, tenant_id, )]) @@ -338,20 +477,43 @@ class ToolApiListApi(Resource): user_id = current_user.id tenant_id = current_user.current_tenant_id - return jsonable_encoder([provider.to_dict() for provider in ToolManageService.list_api_tools( + return jsonable_encoder([provider.to_dict() for provider in ApiToolManageService.list_api_tools( user_id, tenant_id, )]) + +class ToolWorkflowListApi(Resource): + @setup_required + @login_required + @account_initialization_required + def get(self): + user_id = current_user.id + tenant_id = current_user.current_tenant_id + return jsonable_encoder([provider.to_dict() for provider in WorkflowToolManageService.list_tenant_workflow_tools( + user_id, + tenant_id, + )]) + +class ToolLabelsApi(Resource): + @setup_required + @login_required + @account_initialization_required + def get(self): + return jsonable_encoder(ToolLabelsService.list_tool_labels()) + +# tool provider api.add_resource(ToolProviderListApi, '/workspaces/current/tool-providers') + +# builtin tool provider api.add_resource(ToolBuiltinProviderListToolsApi, '/workspaces/current/tool-provider/builtin/<provider>/tools') api.add_resource(ToolBuiltinProviderDeleteApi, '/workspaces/current/tool-provider/builtin/<provider>/delete') api.add_resource(ToolBuiltinProviderUpdateApi, '/workspaces/current/tool-provider/builtin/<provider>/update') api.add_resource(ToolBuiltinProviderGetCredentialsApi, '/workspaces/current/tool-provider/builtin/<provider>/credentials') api.add_resource(ToolBuiltinProviderCredentialsSchemaApi, '/workspaces/current/tool-provider/builtin/<provider>/credentials_schema') api.add_resource(ToolBuiltinProviderIconApi, '/workspaces/current/tool-provider/builtin/<provider>/icon') -api.add_resource(ToolModelProviderIconApi, '/workspaces/current/tool-provider/model/<provider>/icon') -api.add_resource(ToolModelProviderListToolsApi, '/workspaces/current/tool-provider/model/tools') + +# api tool provider api.add_resource(ToolApiProviderAddApi, '/workspaces/current/tool-provider/api/add') api.add_resource(ToolApiProviderGetRemoteSchemaApi, '/workspaces/current/tool-provider/api/remote') api.add_resource(ToolApiProviderListToolsApi, '/workspaces/current/tool-provider/api/tools') @@ -361,5 +523,15 @@ api.add_resource(ToolApiProviderGetApi, '/workspaces/current/tool-provider/api/get') api.add_resource(ToolApiProviderSchemaApi, '/workspaces/current/tool-provider/api/schema') api.add_resource(ToolApiProviderPreviousTestApi, '/workspaces/current/tool-provider/api/test/pre') +# workflow tool provider +api.add_resource(ToolWorkflowProviderCreateApi, '/workspaces/current/tool-provider/workflow/create') +api.add_resource(ToolWorkflowProviderUpdateApi, '/workspaces/current/tool-provider/workflow/update') +api.add_resource(ToolWorkflowProviderDeleteApi, '/workspaces/current/tool-provider/workflow/delete') +api.add_resource(ToolWorkflowProviderGetApi, '/workspaces/current/tool-provider/workflow/get') +api.add_resource(ToolWorkflowProviderListToolApi, '/workspaces/current/tool-provider/workflow/tools') + api.add_resource(ToolBuiltinListApi, '/workspaces/current/tools/builtin') -api.add_resource(ToolApiListApi, '/workspaces/current/tools/api') \ No newline at end of file +api.add_resource(ToolApiListApi, '/workspaces/current/tools/api') +api.add_resource(ToolWorkflowListApi, '/workspaces/current/tools/workflow') + +api.add_resource(ToolLabelsApi, '/workspaces/current/tool-labels') \
No newline at end of file diff --git a/api/controllers/console/workspace/workspace.py b/api/controllers/console/workspace/workspace.py index cd72872b62..7a11a45ae8 100644 --- a/api/controllers/console/workspace/workspace.py +++ b/api/controllers/console/workspace/workspace.py @@ -161,13 +161,13 @@ class CustomConfigWorkspaceApi(Resource): parser.add_argument('replace_webapp_logo', type=str, location='json') args = parser.parse_args() + tenant = db.session.query(Tenant).filter(Tenant.id == current_user.current_tenant_id).one_or_404() + custom_config_dict = { 'remove_webapp_brand': args['remove_webapp_brand'], - 'replace_webapp_logo': args['replace_webapp_logo'], + 'replace_webapp_logo': args['replace_webapp_logo'] if args['replace_webapp_logo'] is not None else tenant.custom_config_dict.get('replace_webapp_logo') , } - tenant = db.session.query(Tenant).filter(Tenant.id == current_user.current_tenant_id).one_or_404() - tenant.custom_config_dict = custom_config_dict db.session.commit() diff --git a/api/controllers/service_api/app/message.py b/api/controllers/service_api/app/message.py index 08f382b0a7..c8b44cfa38 100644 --- a/api/controllers/service_api/app/message.py +++ b/api/controllers/service_api/app/message.py @@ -97,7 +97,7 @@ class MessageListApi(Resource): class MessageFeedbackApi(Resource): - @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON)) + @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) def post(self, app_model: App, end_user: EndUser, message_id): message_id = str(message_id) @@ -114,7 +114,7 @@ class MessageFeedbackApi(Resource): class MessageSuggestedApi(Resource): - @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY)) + @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY, required=True)) def get(self, app_model: App, end_user: EndUser, message_id): message_id = str(message_id) app_mode = AppMode.value_of(app_model.mode) diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py index 8ae81531ae..819512edf0 100644 --- a/api/controllers/service_api/wraps.py +++ b/api/controllers/service_api/wraps.py @@ -8,7 +8,7 @@ from flask import current_app, request from flask_login import user_logged_in from flask_restful import Resource from pydantic import BaseModel -from werkzeug.exceptions import Forbidden, NotFound, Unauthorized +from werkzeug.exceptions import Forbidden, Unauthorized from extensions.ext_database import db from libs.login import _get_user @@ -39,17 +39,17 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio app_model = db.session.query(App).filter(App.id == api_token.app_id).first() if not app_model: - raise NotFound() + raise Forbidden("The app no longer exists.") if app_model.status != 'normal': - raise NotFound() + raise Forbidden("The app's status is abnormal.") if not app_model.enable_api: - raise NotFound() + raise Forbidden("The app's API service has been disabled.") tenant = db.session.query(Tenant).filter(Tenant.id == app_model.tenant_id).first() if tenant.status == TenantStatus.ARCHIVE: - raise NotFound() + raise Forbidden("The workspace's status is archived.") kwargs['app_model'] = app_model diff --git a/api/controllers/web/__init__.py b/api/controllers/web/__init__.py index b6d46d4081..aa19bdc034 100644 --- a/api/controllers/web/__init__.py +++ b/api/controllers/web/__init__.py @@ -6,4 +6,4 @@ bp = Blueprint('web', __name__, url_prefix='/api') api = 
ExternalApi(bp) -from . import app, audio, completion, conversation, file, message, passport, saved_message, site, workflow +from . import app, audio, completion, conversation, feature, file, message, passport, saved_message, site, workflow diff --git a/api/controllers/web/app.py b/api/controllers/web/app.py index 2586f2e6ec..91d9015c33 100644 --- a/api/controllers/web/app.py +++ b/api/controllers/web/app.py @@ -1,14 +1,10 @@ -import json - from flask import current_app from flask_restful import fields, marshal_with from controllers.web import api from controllers.web.error import AppUnavailableError from controllers.web.wraps import WebApiResource -from extensions.ext_database import db -from models.model import App, AppMode, AppModelConfig -from models.tools import ApiToolProvider +from models.model import App, AppMode from services.app_service import AppService diff --git a/api/controllers/web/audio.py b/api/controllers/web/audio.py index e0074c452f..ca6d774e9d 100644 --- a/api/controllers/web/audio.py +++ b/api/controllers/web/audio.py @@ -74,7 +74,7 @@ class TextApi(WebApiResource): app_model=app_model, text=request.form['text'], end_user=end_user.external_user_id, - voice=request.form.get('voice'), + voice=request.form['voice'] if request.form.get('voice') else app_model.app_model_config.text_to_speech_dict.get('voice'), streaming=False ) diff --git a/api/controllers/web/error.py b/api/controllers/web/error.py index 390e3fe7d1..bc87f51051 100644 --- a/api/controllers/web/error.py +++ b/api/controllers/web/error.py @@ -115,3 +115,9 @@ class UnsupportedFileTypeError(BaseHTTPException): error_code = 'unsupported_file_type' description = "File type not allowed." code = 415 + + +class WebSSOAuthRequiredError(BaseHTTPException): + error_code = 'web_sso_auth_required' + description = "Web SSO authentication required." 
+ code = 401 diff --git a/api/controllers/web/feature.py b/api/controllers/web/feature.py new file mode 100644 index 0000000000..65842d78c6 --- /dev/null +++ b/api/controllers/web/feature.py @@ -0,0 +1,12 @@ +from flask_restful import Resource + +from controllers.web import api +from services.feature_service import FeatureService + + +class SystemFeatureApi(Resource): + def get(self): + return FeatureService.get_system_features().dict() + + +api.add_resource(SystemFeatureApi, '/system-features') diff --git a/api/controllers/web/passport.py b/api/controllers/web/passport.py index 92b28d8125..ccc8683a79 100644 --- a/api/controllers/web/passport.py +++ b/api/controllers/web/passport.py @@ -5,14 +5,21 @@ from flask_restful import Resource from werkzeug.exceptions import NotFound, Unauthorized from controllers.web import api +from controllers.web.error import WebSSOAuthRequiredError from extensions.ext_database import db from libs.passport import PassportService from models.model import App, EndUser, Site +from services.feature_service import FeatureService class PassportResource(Resource): """Base resource for passport.""" def get(self): + + system_features = FeatureService.get_system_features() + if system_features.sso_enforced_for_web: + raise WebSSOAuthRequiredError() + app_code = request.headers.get('X-App-Code') if app_code is None: raise Unauthorized('X-App-Code header is missing.') @@ -28,7 +35,7 @@ class PassportResource(Resource): app_model = db.session.query(App).filter(App.id == site.app_id).first() if not app_model or app_model.status != 'normal' or not app_model.enable_site: raise NotFound() - + end_user = EndUser( tenant_id=app_model.tenant_id, app_id=app_model.id, @@ -36,6 +43,7 @@ class PassportResource(Resource): is_anonymous=True, session_id=generate_session_id(), ) + db.session.add(end_user) db.session.commit() @@ -53,8 +61,10 @@ class PassportResource(Resource): 'access_token': tk, } + api.add_resource(PassportResource, '/passport') + def generate_session_id(): """ Generate a unique session ID. diff --git a/api/controllers/web/site.py b/api/controllers/web/site.py index 49b0a8bfc0..a084b56b08 100644 --- a/api/controllers/web/site.py +++ b/api/controllers/web/site.py @@ -31,6 +31,7 @@ class AppSiteApi(WebApiResource): 'description': fields.String, 'copyright': fields.String, 'privacy_policy': fields.String, + 'custom_disclaimer': fields.String, 'default_language': fields.String, 'prompt_public': fields.Boolean } diff --git a/api/controllers/web/wraps.py b/api/controllers/web/wraps.py index bdaa476f34..f5ab49d7e1 100644 --- a/api/controllers/web/wraps.py +++ b/api/controllers/web/wraps.py @@ -2,11 +2,13 @@ from functools import wraps from flask import request from flask_restful import Resource -from werkzeug.exceptions import NotFound, Unauthorized +from werkzeug.exceptions import BadRequest, NotFound, Unauthorized +from controllers.web.error import WebSSOAuthRequiredError from extensions.ext_database import db from libs.passport import PassportService from models.model import App, EndUser, Site +from services.feature_service import FeatureService def validate_jwt_token(view=None): @@ -21,34 +23,60 @@ def validate_jwt_token(view=None): return decorator(view) return decorator + def decode_jwt_token(): - auth_header = request.headers.get('Authorization') - if auth_header is None: - raise Unauthorized('Authorization header is missing.') + system_features = FeatureService.get_system_features() - if ' ' not in auth_header: - raise Unauthorized('Invalid Authorization header format. 
Expected \'Bearer <api-token>\' format.') - - auth_scheme, tk = auth_header.split(None, 1) - auth_scheme = auth_scheme.lower() + try: + auth_header = request.headers.get('Authorization') + if auth_header is None: + raise Unauthorized('Authorization header is missing.') - if auth_scheme != 'bearer': - raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-token>\' format.') - decoded = PassportService().verify(tk) - app_code = decoded.get('app_code') - app_model = db.session.query(App).filter(App.id == decoded['app_id']).first() - site = db.session.query(Site).filter(Site.code == app_code).first() - if not app_model: - raise NotFound() - if not app_code or not site: - raise Unauthorized('Site URL is no longer valid.') - if app_model.enable_site is False: - raise Unauthorized('Site is disabled.') - end_user = db.session.query(EndUser).filter(EndUser.id == decoded['end_user_id']).first() - if not end_user: - raise NotFound() + if ' ' not in auth_header: + raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-token>\' format.') + + auth_scheme, tk = auth_header.split(None, 1) + auth_scheme = auth_scheme.lower() + + if auth_scheme != 'bearer': + raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-token>\' format.') + decoded = PassportService().verify(tk) + app_code = decoded.get('app_code') + app_model = db.session.query(App).filter(App.id == decoded['app_id']).first() + site = db.session.query(Site).filter(Site.code == app_code).first() + if not app_model: + raise NotFound() + if not app_code or not site: + raise BadRequest('Site URL is no longer valid.') + if app_model.enable_site is False: + raise BadRequest('Site is disabled.') + end_user = db.session.query(EndUser).filter(EndUser.id == decoded['end_user_id']).first() + if not end_user: + raise NotFound() + + _validate_web_sso_token(decoded, system_features) + + return app_model, end_user + except Unauthorized as e: + if system_features.sso_enforced_for_web: + raise WebSSOAuthRequiredError() + + raise Unauthorized(e.description) + + +def _validate_web_sso_token(decoded, system_features): + # Check if SSO is enforced for web, and if the token source is not SSO, raise an error and redirect to SSO login + if system_features.sso_enforced_for_web: + source = decoded.get('token_source') + if not source or source != 'sso': + raise WebSSOAuthRequiredError() + + # Check if SSO is not enforced for web, and if the token source is SSO, raise an error and redirect to normal passport login + if not system_features.sso_enforced_for_web: + source = decoded.get('token_source') + if source and source == 'sso': + raise Unauthorized('sso token expired.') - return app_model, end_user class WebApiResource(Resource): method_decorators = [validate_jwt_token] diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index 485633cab1..d228a3ac29 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -39,6 +39,7 @@ from core.tools.entities.tool_entities import ( from core.tools.tool.dataset_retriever_tool import DatasetRetrieverTool from core.tools.tool.tool import Tool from core.tools.tool_manager import ToolManager +from core.tools.utils.tool_parameter_converter import ToolParameterConverter from extensions.ext_database import db from models.model import Conversation, Message, MessageAgentThought from models.tools import ToolConversationVariables @@ -128,6 +129,8 @@ class BaseAgentRunner(AppRunner): self.files = application_generate_entity.files else: self.files
= [] + self.query = None + self._current_thoughts: list[PromptMessage] = [] def _repack_app_generate_entity(self, app_generate_entity: AgentChatAppGenerateEntity) \ -> AgentChatAppGenerateEntity: @@ -165,6 +168,7 @@ class BaseAgentRunner(AppRunner): tenant_id=self.tenant_id, app_id=self.app_config.app_id, agent_tool=tool, + invoke_from=self.application_generate_entity.invoke_from ) tool_entity.load_variables(self.variables_pool) @@ -183,21 +187,11 @@ class BaseAgentRunner(AppRunner): if parameter.form != ToolParameter.ToolParameterForm.LLM: continue - parameter_type = 'string' + parameter_type = ToolParameterConverter.get_parameter_type(parameter.type) enum = [] - if parameter.type == ToolParameter.ToolParameterType.STRING: - parameter_type = 'string' - elif parameter.type == ToolParameter.ToolParameterType.BOOLEAN: - parameter_type = 'boolean' - elif parameter.type == ToolParameter.ToolParameterType.NUMBER: - parameter_type = 'number' - elif parameter.type == ToolParameter.ToolParameterType.SELECT: - for option in parameter.options: - enum.append(option.value) - parameter_type = 'string' - else: - raise ValueError(f"parameter type {parameter.type} is not supported") - + if parameter.type == ToolParameter.ToolParameterType.SELECT: + enum = [option.value for option in parameter.options] + message_tool.parameters['properties'][parameter.name] = { "type": parameter_type, "description": parameter.llm_description or '', @@ -278,20 +272,10 @@ class BaseAgentRunner(AppRunner): if parameter.form != ToolParameter.ToolParameterForm.LLM: continue - parameter_type = 'string' + parameter_type = ToolParameterConverter.get_parameter_type(parameter.type) enum = [] - if parameter.type == ToolParameter.ToolParameterType.STRING: - parameter_type = 'string' - elif parameter.type == ToolParameter.ToolParameterType.BOOLEAN: - parameter_type = 'boolean' - elif parameter.type == ToolParameter.ToolParameterType.NUMBER: - parameter_type = 'number' - elif parameter.type == ToolParameter.ToolParameterType.SELECT: - for option in parameter.options: - enum.append(option.value) - parameter_type = 'string' - else: - raise ValueError(f"parameter type {parameter.type} is not supported") + if parameter.type == ToolParameter.ToolParameterType.SELECT: + enum = [option.value for option in parameter.options] prompt_tool.parameters['properties'][parameter.name] = { "type": parameter_type, @@ -463,7 +447,7 @@ class BaseAgentRunner(AppRunner): for message in messages: if message.id == self.message.id: continue - + result.append(self.organize_agent_user_prompt(message)) agent_thoughts: list[MessageAgentThought] = message.agent_thoughts if agent_thoughts: diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py index 12554f42b3..40cfb20d0b 100644 --- a/api/core/agent/cot_agent_runner.py +++ b/api/core/agent/cot_agent_runner.py @@ -15,6 +15,7 @@ from core.model_runtime.entities.message_entities import ( ToolPromptMessage, UserPromptMessage, ) +from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform from core.tools.entities.tool_entities import ToolInvokeMeta from core.tools.tool.tool import Tool from core.tools.tool_engine import ToolEngine @@ -121,7 +122,7 @@ class CotAgentRunner(BaseAgentRunner, ABC): raise ValueError("failed to invoke llm") usage_dict = {} - react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks) + react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict) scratchpad = AgentScratchpadUnit( agent_response='', thought='', 
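`AgentHistoryPromptTransform`, introduced in the hunk above, slides the agent's conversation history into whatever context-window budget remains after the current session messages. A simplified sketch of that trimming idea (illustrative only; the real transform works on `PromptMessage` objects, memory configuration, and model-specific token counting):

```python
def trim_history(history: list[str], current: list[str],
                 max_tokens: int, count_tokens) -> list[str]:
    """Keep the newest history entries that still fit the token budget.

    history and current are prompt texts oldest-first; count_tokens is any
    callable estimating tokens for a string (model-specific in practice).
    """
    budget = max_tokens - sum(count_tokens(m) for m in current)
    kept: list[str] = []
    # Walk newest-first so the most recent turns survive truncation.
    for message in reversed(history):
        cost = count_tokens(message)
        if cost > budget:
            break
        kept.append(message)
        budget -= cost
    return list(reversed(kept))


# Example with a crude whitespace token estimate: only 'turn three' fits.
history = ['turn one', 'turn two is a bit longer', 'turn three']
print(trim_history(history, ['current query'], max_tokens=8,
                   count_tokens=lambda s: len(s.split())))
```

Passing `current_session_messages` into `_organize_historic_prompt_messages`, as the hunks above now do, is what lets the transform reserve space for the system prompt, the query, and the in-flight scratchpad before deciding how much history to keep.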
@@ -189,7 +190,7 @@ class CotAgentRunner(BaseAgentRunner, ABC): if not scratchpad.action: # failed to extract action, return final answer directly - final_answer = scratchpad.agent_response or '' + final_answer = '' else: if scratchpad.action.action_name.lower() == "final answer": # action is final answer, return final answer directly @@ -373,7 +374,7 @@ class CotAgentRunner(BaseAgentRunner, ABC): return message - def _organize_historic_prompt_messages(self) -> list[PromptMessage]: + def _organize_historic_prompt_messages(self, current_session_messages: list[PromptMessage] = None) -> list[PromptMessage]: """ organize historic prompt messages """ @@ -381,6 +382,13 @@ class CotAgentRunner(BaseAgentRunner, ABC): scratchpad: list[AgentScratchpadUnit] = [] current_scratchpad: AgentScratchpadUnit = None + self.history_prompt_messages = AgentHistoryPromptTransform( + model_config=self.model_config, + prompt_messages=current_session_messages or [], + history_messages=self.history_prompt_messages, + memory=self.memory + ).get_prompt() + for message in self.history_prompt_messages: if isinstance(message, AssistantPromptMessage): current_scratchpad = AgentScratchpadUnit( diff --git a/api/core/agent/cot_chat_agent_runner.py b/api/core/agent/cot_chat_agent_runner.py index a904f3e641..e8b05373ab 100644 --- a/api/core/agent/cot_chat_agent_runner.py +++ b/api/core/agent/cot_chat_agent_runner.py @@ -32,9 +32,6 @@ class CotChatAgentRunner(CotAgentRunner): # organize system prompt system_message = self._organize_system_prompt() - # organize historic prompt messages - historic_messages = self._historic_prompt_messages - # organize current assistant messages agent_scratchpad = self._agent_scratchpad if not agent_scratchpad: @@ -57,6 +54,13 @@ class CotChatAgentRunner(CotAgentRunner): query_messages = UserPromptMessage(content=self._query) if assistant_messages: + # organize historic prompt messages + historic_messages = self._organize_historic_prompt_messages([ + system_message, + query_messages, + *assistant_messages, + UserPromptMessage(content='continue') + ]) messages = [ system_message, *historic_messages, @@ -65,6 +69,8 @@ class CotChatAgentRunner(CotAgentRunner): UserPromptMessage(content='continue') ] else: + # organize historic prompt messages + historic_messages = self._organize_historic_prompt_messages([system_message, query_messages]) messages = [system_message, *historic_messages, query_messages] # join all messages diff --git a/api/core/agent/cot_completion_agent_runner.py b/api/core/agent/cot_completion_agent_runner.py index 3f0298d5a3..9e6eb54f4f 100644 --- a/api/core/agent/cot_completion_agent_runner.py +++ b/api/core/agent/cot_completion_agent_runner.py @@ -19,11 +19,11 @@ class CotCompletionAgentRunner(CotAgentRunner): return system_prompt - def _organize_historic_prompt(self) -> str: + def _organize_historic_prompt(self, current_session_messages: list[PromptMessage] = None) -> str: """ Organize historic prompt """ - historic_prompt_messages = self._historic_prompt_messages + historic_prompt_messages = self._organize_historic_prompt_messages(current_session_messages) historic_prompt = "" for message in historic_prompt_messages: diff --git a/api/core/agent/entities.py b/api/core/agent/entities.py index 5284faa02e..5274224de5 100644 --- a/api/core/agent/entities.py +++ b/api/core/agent/entities.py @@ -8,7 +8,7 @@ class AgentToolEntity(BaseModel): """ Agent Tool Entity. 
""" - provider_type: Literal["builtin", "api"] + provider_type: Literal["builtin", "api", "workflow"] provider_id: str tool_name: str tool_parameters: dict[str, Any] = {} diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py index a9b3a80073..d416a319a4 100644 --- a/api/core/agent/fc_agent_runner.py +++ b/api/core/agent/fc_agent_runner.py @@ -17,6 +17,7 @@ from core.model_runtime.entities.message_entities import ( ToolPromptMessage, UserPromptMessage, ) +from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform from core.tools.entities.tool_entities import ToolInvokeMeta from core.tools.tool_engine import ToolEngine from models.model import Message @@ -24,21 +25,18 @@ from models.model import Message logger = logging.getLogger(__name__) class FunctionCallAgentRunner(BaseAgentRunner): + def run(self, message: Message, query: str, **kwargs: Any ) -> Generator[LLMResultChunk, None, None]: """ Run FunctionCall agent application """ + self.query = query app_generate_entity = self.application_generate_entity app_config = self.app_config - prompt_template = app_config.prompt_template.simple_prompt_template or '' - prompt_messages = self.history_prompt_messages - prompt_messages = self._init_system_message(prompt_template, prompt_messages) - prompt_messages = self._organize_user_query(query, prompt_messages) - # convert tools into ModelRuntime Tool format tool_instances, prompt_messages_tools = self._init_prompt_tools() @@ -81,6 +79,7 @@ class FunctionCallAgentRunner(BaseAgentRunner): ) # recalc llm max tokens + prompt_messages = self._organize_prompt_messages() self.recalc_llm_max_tokens(self.model_config, prompt_messages) # invoke model chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = model_instance.invoke_llm( @@ -203,7 +202,7 @@ class FunctionCallAgentRunner(BaseAgentRunner): else: assistant_message.content = response - prompt_messages.append(assistant_message) + self._current_thoughts.append(assistant_message) # save thought self.save_agent_thought( @@ -265,12 +264,14 @@ class FunctionCallAgentRunner(BaseAgentRunner): } tool_responses.append(tool_response) - prompt_messages = self._organize_assistant_message( - tool_call_id=tool_call_id, - tool_call_name=tool_call_name, - tool_response=tool_response['tool_response'], - prompt_messages=prompt_messages, - ) + if tool_response['tool_response'] is not None: + self._current_thoughts.append( + ToolPromptMessage( + content=tool_response['tool_response'], + tool_call_id=tool_call_id, + name=tool_call_name, + ) + ) if len(tool_responses) > 0: # save agent thought @@ -300,8 +301,6 @@ class FunctionCallAgentRunner(BaseAgentRunner): iteration_step += 1 - prompt_messages = self._clear_user_prompt_image_messages(prompt_messages) - self.update_db_variables(self.variables_pool, self.db_variables_pool) # publish end event self.queue_manager.publish(QueueMessageEndEvent(llm_result=LLMResult( @@ -393,24 +392,6 @@ class FunctionCallAgentRunner(BaseAgentRunner): return prompt_messages - def _organize_assistant_message(self, tool_call_id: str = None, tool_call_name: str = None, tool_response: str = None, - prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]: - """ - Organize assistant message - """ - prompt_messages = deepcopy(prompt_messages) - - if tool_response is not None: - prompt_messages.append( - ToolPromptMessage( - content=tool_response, - tool_call_id=tool_call_id, - name=tool_call_name, - ) - ) - - return prompt_messages - def _clear_user_prompt_image_messages(self, 
prompt_messages: list[PromptMessage]) -> list[PromptMessage]: """ As for now, gpt supports both fc and vision at the first iteration. @@ -428,4 +409,26 @@ class FunctionCallAgentRunner(BaseAgentRunner): for content in prompt_message.content ]) - return prompt_messages \ No newline at end of file + return prompt_messages + + def _organize_prompt_messages(self): + prompt_template = self.app_config.prompt_template.simple_prompt_template or '' + self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages) + query_prompt_messages = self._organize_user_query(self.query, []) + + self.history_prompt_messages = AgentHistoryPromptTransform( + model_config=self.model_config, + prompt_messages=[*query_prompt_messages, *self._current_thoughts], + history_messages=self.history_prompt_messages, + memory=self.memory + ).get_prompt() + + prompt_messages = [ + *self.history_prompt_messages, + *query_prompt_messages, + *self._current_thoughts + ] + if len(self._current_thoughts) != 0: + # clear messages after the first iteration + prompt_messages = self._clear_user_prompt_image_messages(prompt_messages) + return prompt_messages diff --git a/api/core/agent/output_parser/cot_output_parser.py b/api/core/agent/output_parser/cot_output_parser.py index 91ac41143b..5a445e9e59 100644 --- a/api/core/agent/output_parser/cot_output_parser.py +++ b/api/core/agent/output_parser/cot_output_parser.py @@ -9,7 +9,7 @@ from core.model_runtime.entities.llm_entities import LLMResultChunk class CotAgentOutputParser: @classmethod - def handle_react_stream_output(cls, llm_response: Generator[LLMResultChunk, None, None]) -> \ + def handle_react_stream_output(cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict) -> \ Generator[Union[str, AgentScratchpadUnit.Action], None, None]: def parse_action(json_str): try: @@ -58,6 +58,8 @@ class CotAgentOutputParser: thought_idx = 0 for response in llm_response: + if response.delta.usage: + usage_dict['usage'] = response.delta.usage response = response.delta.message.content if not isinstance(response, str): continue diff --git a/api/core/tools/prompt/template.py b/api/core/agent/prompt/template.py similarity index 100% rename from api/core/tools/prompt/template.py rename to api/core/agent/prompt/template.py diff --git a/api/core/app/app_config/common/sensitive_word_avoidance/manager.py b/api/core/app/app_config/common/sensitive_word_avoidance/manager.py index 66d4a3275b..1ca8b1e3b8 100644 --- a/api/core/app/app_config/common/sensitive_word_avoidance/manager.py +++ b/api/core/app/app_config/common/sensitive_word_avoidance/manager.py @@ -11,7 +11,7 @@ class SensitiveWordAvoidanceConfigManager: if not sensitive_word_avoidance_dict: return None - if 'enabled' in sensitive_word_avoidance_dict and sensitive_word_avoidance_dict['enabled']: + if sensitive_word_avoidance_dict.get('enabled'): return SensitiveWordAvoidanceEntity( type=sensitive_word_avoidance_dict.get('type'), config=sensitive_word_avoidance_dict.get('config'), diff --git a/api/core/app/app_config/easy_ui_based_app/agent/manager.py b/api/core/app/app_config/easy_ui_based_app/agent/manager.py index a48316728b..f271aeed0c 100644 --- a/api/core/app/app_config/easy_ui_based_app/agent/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/agent/manager.py @@ -1,7 +1,7 @@ from typing import Optional from core.agent.entities import AgentEntity, AgentPromptEntity, AgentToolEntity -from core.tools.prompt.template import REACT_PROMPT_TEMPLATES +from core.agent.prompt.template 
import REACT_PROMPT_TEMPLATES class AgentConfigManager: diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py index 101e25d582..d6b6d89416 100644 --- a/api/core/app/app_config/entities.py +++ b/api/core/app/app_config/entities.py @@ -239,4 +239,4 @@ class WorkflowUIBasedAppConfig(AppConfig): """ Workflow UI Based App Config Entity. """ - workflow_id: str + workflow_id: str \ No newline at end of file diff --git a/api/core/app/app_config/features/file_upload/manager.py b/api/core/app/app_config/features/file_upload/manager.py index acffa6e9e7..2049b573cd 100644 --- a/api/core/app/app_config/features/file_upload/manager.py +++ b/api/core/app/app_config/features/file_upload/manager.py @@ -14,7 +14,7 @@ class FileUploadConfigManager: """ file_upload_dict = config.get('file_upload') if file_upload_dict: - if 'image' in file_upload_dict and file_upload_dict['image']: + if file_upload_dict.get('image'): if 'enabled' in file_upload_dict['image'] and file_upload_dict['image']['enabled']: image_config = { 'number_limits': file_upload_dict['image']['number_limits'], diff --git a/api/core/app/app_config/features/more_like_this/manager.py b/api/core/app/app_config/features/more_like_this/manager.py index ec2a9a6796..2ba99a5c40 100644 --- a/api/core/app/app_config/features/more_like_this/manager.py +++ b/api/core/app/app_config/features/more_like_this/manager.py @@ -9,7 +9,7 @@ class MoreLikeThisConfigManager: more_like_this = False more_like_this_dict = config.get('more_like_this') if more_like_this_dict: - if 'enabled' in more_like_this_dict and more_like_this_dict['enabled']: + if more_like_this_dict.get('enabled'): more_like_this = True return more_like_this diff --git a/api/core/app/app_config/features/retrieval_resource/manager.py b/api/core/app/app_config/features/retrieval_resource/manager.py index 0694cb954e..fca58e12e8 100644 --- a/api/core/app/app_config/features/retrieval_resource/manager.py +++ b/api/core/app/app_config/features/retrieval_resource/manager.py @@ -4,7 +4,7 @@ class RetrievalResourceConfigManager: show_retrieve_source = False retriever_resource_dict = config.get('retriever_resource') if retriever_resource_dict: - if 'enabled' in retriever_resource_dict and retriever_resource_dict['enabled']: + if retriever_resource_dict.get('enabled'): show_retrieve_source = True return show_retrieve_source diff --git a/api/core/app/app_config/features/speech_to_text/manager.py b/api/core/app/app_config/features/speech_to_text/manager.py index b98699bfff..88b4be25d3 100644 --- a/api/core/app/app_config/features/speech_to_text/manager.py +++ b/api/core/app/app_config/features/speech_to_text/manager.py @@ -9,7 +9,7 @@ class SpeechToTextConfigManager: speech_to_text = False speech_to_text_dict = config.get('speech_to_text') if speech_to_text_dict: - if 'enabled' in speech_to_text_dict and speech_to_text_dict['enabled']: + if speech_to_text_dict.get('enabled'): speech_to_text = True return speech_to_text diff --git a/api/core/app/app_config/features/suggested_questions_after_answer/manager.py b/api/core/app/app_config/features/suggested_questions_after_answer/manager.py index 5aacd3b32d..c6cab01220 100644 --- a/api/core/app/app_config/features/suggested_questions_after_answer/manager.py +++ b/api/core/app/app_config/features/suggested_questions_after_answer/manager.py @@ -9,7 +9,7 @@ class SuggestedQuestionsAfterAnswerConfigManager: suggested_questions_after_answer = False suggested_questions_after_answer_dict = config.get('suggested_questions_after_answer') if 
suggested_questions_after_answer_dict: - if 'enabled' in suggested_questions_after_answer_dict and suggested_questions_after_answer_dict['enabled']: + if suggested_questions_after_answer_dict.get('enabled'): suggested_questions_after_answer = True return suggested_questions_after_answer diff --git a/api/core/app/app_config/features/text_to_speech/manager.py b/api/core/app/app_config/features/text_to_speech/manager.py index 1ff31034ad..b516fa46ab 100644 --- a/api/core/app/app_config/features/text_to_speech/manager.py +++ b/api/core/app/app_config/features/text_to_speech/manager.py @@ -12,7 +12,7 @@ class TextToSpeechConfigManager: text_to_speech = False text_to_speech_dict = config.get('text_to_speech') if text_to_speech_dict: - if 'enabled' in text_to_speech_dict and text_to_speech_dict['enabled']: + if text_to_speech_dict.get('enabled'): text_to_speech = TextToSpeechEntity( enabled=text_to_speech_dict.get('enabled'), voice=text_to_speech_dict.get('voice'), diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index e5cf585f82..3b1ee3578d 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -66,7 +66,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user) # parse files - files = args['files'] if 'files' in args and args['files'] else [] + files = args['files'] if args.get('files') else [] message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False) if file_extra_config: @@ -98,6 +98,90 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): extras=extras ) + return self._generate( + app_model=app_model, + workflow=workflow, + user=user, + invoke_from=invoke_from, + application_generate_entity=application_generate_entity, + conversation=conversation, + stream=stream + ) + + def single_iteration_generate(self, app_model: App, + workflow: Workflow, + node_id: str, + user: Account, + args: dict, + stream: bool = True) \ + -> Union[dict, Generator[dict, None, None]]: + """ + Generate App response. 
+
+        :param app_model: App
+        :param workflow: Workflow
+        :param node_id: the id of the iteration node to run
+        :param user: account
+        :param args: request args
+        :param stream: is stream
+        """
+        if not node_id:
+            raise ValueError('node_id is required')
+
+        if args.get('inputs') is None:
+            raise ValueError('inputs is required')
+
+        extras = {
+            "auto_generate_conversation_name": False
+        }
+
+        # get conversation
+        conversation = None
+        if args.get('conversation_id'):
+            conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user)
+
+        # convert to app config
+        app_config = AdvancedChatAppConfigManager.get_app_config(
+            app_model=app_model,
+            workflow=workflow
+        )
+
+        # init application generate entity
+        application_generate_entity = AdvancedChatAppGenerateEntity(
+            task_id=str(uuid.uuid4()),
+            app_config=app_config,
+            conversation_id=conversation.id if conversation else None,
+            inputs={},
+            query='',
+            files=[],
+            user_id=user.id,
+            stream=stream,
+            invoke_from=InvokeFrom.DEBUGGER,
+            extras=extras,
+            single_iteration_run=AdvancedChatAppGenerateEntity.SingleIterationRunEntity(
+                node_id=node_id,
+                inputs=args['inputs']
+            )
+        )
+
+        return self._generate(
+            app_model=app_model,
+            workflow=workflow,
+            user=user,
+            invoke_from=InvokeFrom.DEBUGGER,
+            application_generate_entity=application_generate_entity,
+            conversation=conversation,
+            stream=stream
+        )
+
+    def _generate(self, app_model: App,
+                  workflow: Workflow,
+                  user: Union[Account, EndUser],
+                  invoke_from: InvokeFrom,
+                  application_generate_entity: AdvancedChatAppGenerateEntity,
+                  conversation: Conversation = None,
+                  stream: bool = True) \
+            -> Union[dict, Generator[dict, None, None]]:
         is_first_conversation = False
         if not conversation:
             is_first_conversation = True
@@ -167,18 +251,29 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
         """
         with flask_app.app_context():
             try:
-                # get conversation and message
-                conversation = self._get_conversation(conversation_id)
-                message = self._get_message(message_id)
-
-                # chatbot app
                 runner = AdvancedChatAppRunner()
-                runner.run(
-                    application_generate_entity=application_generate_entity,
-                    queue_manager=queue_manager,
-                    conversation=conversation,
-                    message=message
-                )
+                if application_generate_entity.single_iteration_run:
+                    single_iteration_run = application_generate_entity.single_iteration_run
+                    runner.single_iteration_run(
+                        app_id=application_generate_entity.app_config.app_id,
+                        workflow_id=application_generate_entity.app_config.workflow_id,
+                        queue_manager=queue_manager,
+                        inputs=single_iteration_run.inputs,
+                        node_id=single_iteration_run.node_id,
+                        user_id=application_generate_entity.user_id
+                    )
+                else:
+                    # get conversation and message
+                    conversation = self._get_conversation(conversation_id)
+                    message = self._get_message(message_id)
+
+                    # chatbot app
+                    runner.run(
+                        application_generate_entity=application_generate_entity,
+                        queue_manager=queue_manager,
+                        conversation=conversation,
+                        message=message
+                    )
             except GenerateTaskStoppedException:
                 pass
             except InvokeAuthorizationError:
diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py
index d858dcac12..de3632894d 100644
--- a/api/core/app/apps/advanced_chat/app_runner.py
+++ b/api/core/app/apps/advanced_chat/app_runner.py
@@ -102,6 +102,7 @@ class AdvancedChatAppRunner(AppRunner):
             user_from=UserFrom.ACCOUNT
             if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
             else UserFrom.END_USER,
+            invoke_from=application_generate_entity.invoke_from,
             user_inputs=inputs,
             system_inputs={
                 SystemVariable.QUERY: query,
@@ -109,6 +110,35 @@ class AdvancedChatAppRunner(AppRunner):
                 SystemVariable.CONVERSATION_ID: conversation.id,
                 SystemVariable.USER_ID: user_id
             },
+            callbacks=workflow_callbacks,
+            call_depth=application_generate_entity.call_depth
+        )
+
+    def single_iteration_run(self, app_id: str, workflow_id: str,
+                             queue_manager: AppQueueManager,
+                             inputs: dict, node_id: str, user_id: str) -> None:
+        """
+        Single iteration run
+        """
+        app_record: App = db.session.query(App).filter(App.id == app_id).first()
+        if not app_record:
+            raise ValueError("App not found")
+
+        workflow = self.get_workflow(app_model=app_record, workflow_id=workflow_id)
+        if not workflow:
+            raise ValueError("Workflow not initialized")
+
+        workflow_callbacks = [WorkflowEventTriggerCallback(
+            queue_manager=queue_manager,
+            workflow=workflow
+        )]
+
+        workflow_engine_manager = WorkflowEngineManager()
+        workflow_engine_manager.single_step_run_iteration_workflow_node(
+            workflow=workflow,
+            node_id=node_id,
+            user_id=user_id,
+            user_inputs=inputs,
             callbacks=workflow_callbacks
         )
diff --git a/api/core/app/apps/advanced_chat/generate_response_converter.py b/api/core/app/apps/advanced_chat/generate_response_converter.py
index 80e8e22e88..08069332ba 100644
--- a/api/core/app/apps/advanced_chat/generate_response_converter.py
+++ b/api/core/app/apps/advanced_chat/generate_response_converter.py
@@ -8,6 +8,8 @@ from core.app.entities.task_entities import (
     ChatbotAppStreamResponse,
     ErrorStreamResponse,
     MessageEndStreamResponse,
+    NodeFinishStreamResponse,
+    NodeStartStreamResponse,
     PingStreamResponse,
 )
 
@@ -111,6 +113,8 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
             if isinstance(sub_stream_response, ErrorStreamResponse):
                 data = cls._error_to_stream_response(sub_stream_response.err)
                 response_chunk.update(data)
+            elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
+                response_chunk.update(sub_stream_response.to_ignore_detail_dict())
             else:
                 response_chunk.update(sub_stream_response.to_dict())
 
diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py
index 67d8fe5cb1..7c70afc2ae 100644
--- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py
+++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py
@@ -12,6 +12,9 @@ from core.app.entities.queue_entities import (
     QueueAdvancedChatMessageEndEvent,
     QueueAnnotationReplyEvent,
     QueueErrorEvent,
+    QueueIterationCompletedEvent,
+    QueueIterationNextEvent,
+    QueueIterationStartEvent,
     QueueMessageReplaceEvent,
     QueueNodeFailedEvent,
     QueueNodeStartedEvent,
@@ -64,6 +67,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
     _workflow: Workflow
     _user: Union[Account, EndUser]
     _workflow_system_variables: dict[SystemVariable, Any]
+    _iteration_nested_relations: dict[str, list[str]]
 
     def __init__(self, application_generate_entity: AdvancedChatAppGenerateEntity,
                  workflow: Workflow,
@@ -103,6 +107,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
             usage=LLMUsage.empty_usage()
         )
 
+        self._iteration_nested_relations = self._get_iteration_nested_relations(self._workflow.graph_dict)
         self._stream_generate_routes = self._get_stream_generate_routes()
         self._conversation_name_generate_thread = None
 
@@ -204,6 +209,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
             # search stream_generate_routes if node id is answer start at node
             if not self._task_state.current_stream_generate_state and event.node_id in self._stream_generate_routes:
                 self._task_state.current_stream_generate_state = self._stream_generate_routes[event.node_id]
+                # reset current route position to 0
+                self._task_state.current_stream_generate_state.current_route_position = 0
 
             # generate stream outputs when node started
             yield from self._generate_stream_outputs_when_node_started()
@@ -225,6 +232,22 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                 task_id=self._application_generate_entity.task_id,
                 workflow_node_execution=workflow_node_execution
             )
+
+            if isinstance(event, QueueNodeFailedEvent):
+                yield from self._handle_iteration_exception(
+                    task_id=self._application_generate_entity.task_id,
+                    error=f'Child node failed: {event.error}'
+                )
+        elif isinstance(event, QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent):
+            if isinstance(event, QueueIterationNextEvent):
+                # clear ran node execution infos of current iteration
+                iteration_relations = self._iteration_nested_relations.get(event.node_id)
+                if iteration_relations:
+                    for node_id in iteration_relations:
+                        self._task_state.ran_node_execution_infos.pop(node_id, None)
+
+            yield self._handle_iteration_to_stream_response(self._application_generate_entity.task_id, event)
+            self._handle_iteration_operation(event)
         elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent):
             workflow_run = self._handle_workflow_finished(event)
             if workflow_run:
@@ -263,10 +286,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                 self._handle_retriever_resources(event)
             elif isinstance(event, QueueAnnotationReplyEvent):
                 self._handle_annotation_reply(event)
-            # elif isinstance(event, QueueMessageFileEvent):
-            #     response = self._message_file_to_stream_response(event)
-            #     if response:
-            #         yield response
             elif isinstance(event, QueueTextChunkEvent):
                 delta_text = event.text
                 if delta_text is None:
@@ -342,7 +361,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
             id=self._message.id,
             **extras
         )
-    
+
     def _get_stream_generate_routes(self) -> dict[str, ChatflowStreamGenerateRoute]:
         """
        Get stream generate routes.
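A minimal sketch of the bookkeeping the QueueIterationNextEvent branch above performs, assuming the {iteration_node_id: [nested_node_ids]} mapping that _get_iteration_nested_relations builds further down (all values below are hypothetical):

    # hypothetical graph: one iteration node wrapping two child nodes
    iteration_nested_relations = {'iteration_1': ['llm_node', 'code_node']}
    ran_node_execution_infos = {'llm_node': 'exec-1', 'start_node': 'exec-0'}

    # when the iteration advances, forget the executions of its nested nodes
    # so the next round can record fresh ones; unrelated nodes are kept
    for node_id in iteration_nested_relations.get('iteration_1', []):
        ran_node_execution_infos.pop(node_id, None)

    assert ran_node_execution_infos == {'start_node': 'exec-0'}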
@@ -372,7 +391,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc ) return stream_generate_routes - + def _get_answer_start_at_node_ids(self, graph: dict, target_node_id: str) \ -> list[str]: """ @@ -391,6 +410,18 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc ingoing_edges.append(edge) if not ingoing_edges: + # check if it's the first node in the iteration + target_node = next((node for node in nodes if node.get('id') == target_node_id), None) + if not target_node: + return [] + + node_iteration_id = target_node.get('data', {}).get('iteration_id') + # get iteration start node id + for node in nodes: + if node.get('id') == node_iteration_id: + if node.get('data', {}).get('start_node_id') == target_node_id: + return [target_node_id] + return [] start_node_ids = [] @@ -401,14 +432,23 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc continue node_type = source_node.get('data', {}).get('type') + node_iteration_id = source_node.get('data', {}).get('iteration_id') + iteration_start_node_id = None + if node_iteration_id: + iteration_node = next((node for node in nodes if node.get('id') == node_iteration_id), None) + iteration_start_node_id = iteration_node.get('data', {}).get('start_node_id') + if node_type in [ NodeType.ANSWER.value, NodeType.IF_ELSE.value, - NodeType.QUESTION_CLASSIFIER.value + NodeType.QUESTION_CLASSIFIER.value, + NodeType.ITERATION.value, + NodeType.LOOP.value ]: start_node_id = target_node_id start_node_ids.append(start_node_id) - elif node_type == NodeType.START.value: + elif node_type == NodeType.START.value or \ + node_iteration_id is not None and iteration_start_node_id == source_node.get('id'): start_node_id = source_node_id start_node_ids.append(start_node_id) else: @@ -417,7 +457,27 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc start_node_ids.extend(sub_start_node_ids) return start_node_ids + + def _get_iteration_nested_relations(self, graph: dict) -> dict[str, list[str]]: + """ + Get iteration nested relations. + :param graph: graph + :return: + """ + nodes = graph.get('nodes') + iteration_ids = [node.get('id') for node in nodes + if node.get('data', {}).get('type') in [ + NodeType.ITERATION.value, + NodeType.LOOP.value, + ]] + + return { + iteration_id: [ + node.get('id') for node in nodes if node.get('data', {}).get('iteration_id') == iteration_id + ] for iteration_id in iteration_ids + } + def _generate_stream_outputs_when_node_started(self) -> Generator: """ Generate stream outputs. 
@@ -425,7 +485,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
         """
         if self._task_state.current_stream_generate_state:
             route_chunks = self._task_state.current_stream_generate_state.generate_route[
-                self._task_state.current_stream_generate_state.current_route_position:]
+                self._task_state.current_stream_generate_state.current_route_position:
+            ]
 
             for route_chunk in route_chunks:
                 if route_chunk.type == 'text':
@@ -458,13 +519,14 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
 
         route_chunks = self._task_state.current_stream_generate_state.generate_route[
             self._task_state.current_stream_generate_state.current_route_position:]
-        
+
         for route_chunk in route_chunks:
             if route_chunk.type == 'text':
                 route_chunk = cast(TextGenerateRouteChunk, route_chunk)
                 self._task_state.answer += route_chunk.text
                 yield self._message_to_stream_response(route_chunk.text, self._message.id)
             else:
+                value = None
                 route_chunk = cast(VarGenerateRouteChunk, route_chunk)
                 value_selector = route_chunk.value_selector
                 if not value_selector:
@@ -476,6 +538,20 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                 if route_chunk_node_id == 'sys':
                     # system variable
                     value = self._workflow_system_variables.get(SystemVariable.value_of(value_selector[1]))
+                elif route_chunk_node_id in self._iteration_nested_relations:
+                    # it's an iteration variable
+                    if not self._iteration_state or route_chunk_node_id not in self._iteration_state.current_iterations:
+                        continue
+                    iteration_state = self._iteration_state.current_iterations[route_chunk_node_id]
+                    iterator = iteration_state.inputs
+                    if not iterator:
+                        continue
+                    iterator_selector = iterator.get('iterator_selector', [])
+                    if value_selector[1] == 'index':
+                        value = iteration_state.current_index
+                    elif value_selector[1] == 'item':
+                        value = iterator_selector[iteration_state.current_index] if iteration_state.current_index < len(
+                            iterator_selector) else None
                 else:
                     # check chunk node id is before current node id or equal to current node id
                     if route_chunk_node_id not in self._task_state.ran_node_execution_infos:
@@ -505,7 +581,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                     else:
                         value = value.get(key)
 
-                if value:
+                if value is not None:
                     text = ''
                     if isinstance(value, str | int | float):
                         text = str(value)
diff --git a/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py b/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py
index fef719a086..78fe077e6b 100644
--- a/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py
+++ b/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py
@@ -1,8 +1,11 @@
-from typing import Optional
+from typing import Any, Optional
 
 from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
 from core.app.entities.queue_entities import (
     AppQueueEvent,
+    QueueIterationCompletedEvent,
+    QueueIterationNextEvent,
+    QueueIterationStartEvent,
     QueueNodeFailedEvent,
     QueueNodeStartedEvent,
     QueueNodeSucceededEvent,
@@ -130,6 +133,66 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback):
             ),
             PublishFrom.APPLICATION_MANAGER
         )
 
+    def on_workflow_iteration_started(self,
+                                      node_id: str,
+                                      node_type: NodeType,
+                                      node_run_index: int = 1,
+                                      node_data: Optional[BaseNodeData] = None,
+                                      inputs: dict = None,
+                                      predecessor_node_id: Optional[str] = None,
+                                      metadata: Optional[dict] = None) -> None:
+        """
+        Publish iteration started
+        """
+        self._queue_manager.publish(
+            QueueIterationStartEvent(
+                node_id=node_id,
+                node_type=node_type,
+                node_run_index=node_run_index,
+                node_data=node_data,
+                inputs=inputs,
+                predecessor_node_id=predecessor_node_id,
+                metadata=metadata
+            ),
+            PublishFrom.APPLICATION_MANAGER
+        )
+
+    def on_workflow_iteration_next(self, node_id: str,
+                                   node_type: NodeType,
+                                   index: int,
+                                   node_run_index: int,
+                                   output: Optional[Any]) -> None:
+        """
+        Publish iteration next
+        """
+        self._queue_manager.publish(
+            QueueIterationNextEvent(
+                node_id=node_id,
+                node_type=node_type,
+                index=index,
+                node_run_index=node_run_index,
+                output=output
+            ),
+            PublishFrom.APPLICATION_MANAGER
+        )
+
+    def on_workflow_iteration_completed(self, node_id: str,
+                                        node_type: NodeType,
+                                        node_run_index: int,
+                                        outputs: dict) -> None:
+        """
+        Publish iteration completed
+        """
+        self._queue_manager.publish(
+            QueueIterationCompletedEvent(
+                node_id=node_id,
+                node_type=node_type,
+                node_run_index=node_run_index,
+                outputs=outputs
+            ),
+            PublishFrom.APPLICATION_MANAGER
+        )
+
     def on_event(self, event: AppQueueEvent) -> None:
         """
         Publish event
diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py
index 847d314409..fb4c28a855 100644
--- a/api/core/app/apps/agent_chat/app_generator.py
+++ b/api/core/app/apps/agent_chat/app_generator.py
@@ -83,7 +83,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
         )
 
         # parse files
-        files = args['files'] if 'files' in args and args['files'] else []
+        files = args['files'] if args.get('files') else []
         message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
         file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
         if file_extra_config:
@@ -115,7 +115,8 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
             user_id=user.id,
             stream=stream,
             invoke_from=invoke_from,
-            extras=extras
+            extras=extras,
+            call_depth=0
         )
 
         # init generate records
diff --git a/api/core/app/apps/base_app_generator.py b/api/core/app/apps/base_app_generator.py
index 9d88c834e6..20ae6ff676 100644
--- a/api/core/app/apps/base_app_generator.py
+++ b/api/core/app/apps/base_app_generator.py
@@ -13,7 +13,9 @@ class BaseAppGenerator:
         for variable_config in variables:
             variable = variable_config.variable
 
-            if variable not in user_inputs or not user_inputs[variable]:
+            if (variable not in user_inputs
+                    or user_inputs[variable] is None
+                    or (isinstance(user_inputs[variable], str) and user_inputs[variable] == '')):
                 if variable_config.required:
                     raise ValueError(f"{variable} is required in input form")
                 else:
@@ -22,7 +24,7 @@ class BaseAppGenerator:
 
             value = user_inputs[variable]
 
-            if value:
+            if value is not None:
                 if variable_config.type != VariableEntity.Type.NUMBER and not isinstance(value, str):
                     raise ValueError(f"{variable} in input form must be a string")
                 elif variable_config.type == VariableEntity.Type.NUMBER and isinstance(value, str):
@@ -44,7 +46,7 @@ class BaseAppGenerator:
             if value and isinstance(value, str):
                 filtered_inputs[variable] = value.replace('\x00', '')
             else:
-                filtered_inputs[variable] = value if value else None
+                filtered_inputs[variable] = value
 
         return filtered_inputs
 
diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py
index f1f426b27e..545463c8bd 100644
--- a/api/core/app/apps/base_app_runner.py
+++ b/api/core/app/apps/base_app_runner.py
@@ -1,6 +1,6 @@
 import time
 from collections.abc import Generator
-from typing import Optional, Union, cast
+from typing import Optional, Union
 
 from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity
 from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
@@ -16,11 +16,11 @@ from core.app.features.hosting_moderation.hosting_moderation import HostingModer
 from core.external_data_tool.external_data_fetch import ExternalDataFetch
 from core.file.file_obj import FileVar
 from core.memory.token_buffer_memory import TokenBufferMemory
+from core.model_manager import ModelInstance
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
 from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage
 from core.model_runtime.entities.model_entities import ModelPropertyKey
 from core.model_runtime.errors.invoke import InvokeBadRequestError
-from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.moderation.input_moderation import InputModeration
 from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
 from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
@@ -45,8 +45,11 @@ class AppRunner:
         :param query: query
         :return:
         """
-        model_type_instance = model_config.provider_model_bundle.model_type_instance
-        model_type_instance = cast(LargeLanguageModel, model_type_instance)
+        # Invoke model
+        model_instance = ModelInstance(
+            provider_model_bundle=model_config.provider_model_bundle,
+            model=model_config.model
+        )
 
         model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
 
@@ -73,9 +76,7 @@ class AppRunner:
                 query=query
             )
 
-        prompt_tokens = model_type_instance.get_num_tokens(
-            model_config.model,
-            model_config.credentials,
+        prompt_tokens = model_instance.get_llm_num_tokens(
             prompt_messages
         )
 
@@ -89,8 +90,10 @@ class AppRunner:
     def recalc_llm_max_tokens(self, model_config: ModelConfigWithCredentialsEntity,
                               prompt_messages: list[PromptMessage]):
         # recalc max_tokens if sum(prompt_token +  max_tokens) over model token limit
-        model_type_instance = model_config.provider_model_bundle.model_type_instance
-        model_type_instance = cast(LargeLanguageModel, model_type_instance)
+        model_instance = ModelInstance(
+            provider_model_bundle=model_config.provider_model_bundle,
+            model=model_config.model
+        )
 
         model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
 
@@ -107,9 +110,7 @@ class AppRunner:
         if max_tokens is None:
             max_tokens = 0
 
-        prompt_tokens = model_type_instance.get_num_tokens(
-            model_config.model,
-            model_config.credentials,
+        prompt_tokens = model_instance.get_llm_num_tokens(
             prompt_messages
         )
 
diff --git a/api/core/app/apps/chat/app_generator.py b/api/core/app/apps/chat/app_generator.py
index e67901cca8..4ad26e8506 100644
--- a/api/core/app/apps/chat/app_generator.py
+++ b/api/core/app/apps/chat/app_generator.py
@@ -80,7 +80,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
         )
 
         # parse files
-        files = args['files'] if 'files' in args and args['files'] else []
+        files = args['files'] if args.get('files') else []
         message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
         file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
         if file_extra_config:
diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py
index 5f93afcad7..31ce4d70b6 100644
--- a/api/core/app/apps/completion/app_generator.py
+++ b/api/core/app/apps/completion/app_generator.py
@@ -75,7 +75,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
         )
 
         # parse files
-        files = args['files'] if 'files' in args and args['files'] else []
+        files = args['files'] if args.get('files') else []
         message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
         file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
         if file_extra_config:
diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py
index a9b038ab51..c4324978d8 100644
--- a/api/core/app/apps/workflow/app_generator.py
+++ b/api/core/app/apps/workflow/app_generator.py
@@ -34,7 +34,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
                  user: Union[Account, EndUser],
                  args: dict,
                  invoke_from: InvokeFrom,
-                 stream: bool = True) \
+                 stream: bool = True,
+                 call_depth: int = 0) \
             -> Union[dict, Generator[dict, None, None]]:
         """
         Generate App response.
@@ -49,7 +50,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
         inputs = args['inputs']
 
         # parse files
-        files = args['files'] if 'files' in args and args['files'] else []
+        files = args['files'] if args.get('files') else []
         message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
         file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
         if file_extra_config:
@@ -75,9 +76,38 @@ class WorkflowAppGenerator(BaseAppGenerator):
             files=file_objs,
             user_id=user.id,
             stream=stream,
-            invoke_from=invoke_from
+            invoke_from=invoke_from,
+            call_depth=call_depth
         )
 
+        return self._generate(
+            app_model=app_model,
+            workflow=workflow,
+            user=user,
+            application_generate_entity=application_generate_entity,
+            invoke_from=invoke_from,
+            stream=stream,
+            call_depth=call_depth
+        )
+
+    def _generate(self, app_model: App,
+                  workflow: Workflow,
+                  user: Union[Account, EndUser],
+                  application_generate_entity: WorkflowAppGenerateEntity,
+                  invoke_from: InvokeFrom,
+                  stream: bool = True,
+                  call_depth: int = 0) \
+            -> Union[dict, Generator[dict, None, None]]:
+        """
+        Generate App response.
+
+        :param app_model: App
+        :param workflow: Workflow
+        :param user: account or end user
+        :param application_generate_entity: application generate entity
+        :param invoke_from: invoke from source
+        :param stream: is stream
+        """
         # init queue manager
         queue_manager = WorkflowAppQueueManager(
             task_id=application_generate_entity.task_id,
@@ -109,6 +139,64 @@ class WorkflowAppGenerator(BaseAppGenerator):
             invoke_from=invoke_from
         )
 
+    def single_iteration_generate(self, app_model: App,
+                                  workflow: Workflow,
+                                  node_id: str,
+                                  user: Account,
+                                  args: dict,
+                                  stream: bool = True) \
+            -> Union[dict, Generator[dict, None, None]]:
+        """
+        Generate App response.
+
+        :param app_model: App
+        :param workflow: Workflow
+        :param node_id: the id of the iteration node to run
+        :param user: account
+        :param args: request args
+        :param stream: is stream
+        """
+        if not node_id:
+            raise ValueError('node_id is required')
+
+        if args.get('inputs') is None:
+            raise ValueError('inputs is required')
+
+        extras = {
+            "auto_generate_conversation_name": False
+        }
+
+        # convert to app config
+        app_config = WorkflowAppConfigManager.get_app_config(
+            app_model=app_model,
+            workflow=workflow
+        )
+
+        # init application generate entity
+        application_generate_entity = WorkflowAppGenerateEntity(
+            task_id=str(uuid.uuid4()),
+            app_config=app_config,
+            inputs={},
+            files=[],
+            user_id=user.id,
+            stream=stream,
+            invoke_from=InvokeFrom.DEBUGGER,
+            extras=extras,
+            single_iteration_run=WorkflowAppGenerateEntity.SingleIterationRunEntity(
+                node_id=node_id,
+                inputs=args['inputs']
+            )
+        )
+
+        return self._generate(
+            app_model=app_model,
+            workflow=workflow,
+            user=user,
+            invoke_from=InvokeFrom.DEBUGGER,
+            application_generate_entity=application_generate_entity,
+            stream=stream
+        )
+
     def _generate_worker(self, flask_app: Flask,
                          application_generate_entity: WorkflowAppGenerateEntity,
                          queue_manager: AppQueueManager) -> None:
@@ -123,10 +211,21 @@ class WorkflowAppGenerator(BaseAppGenerator):
             try:
                 # workflow app
                 runner = WorkflowAppRunner()
-                runner.run(
-                    application_generate_entity=application_generate_entity,
-                    queue_manager=queue_manager
-                )
+                if application_generate_entity.single_iteration_run:
+                    single_iteration_run = application_generate_entity.single_iteration_run
+                    runner.single_iteration_run(
+                        app_id=application_generate_entity.app_config.app_id,
+                        workflow_id=application_generate_entity.app_config.workflow_id,
+                        queue_manager=queue_manager,
+                        inputs=single_iteration_run.inputs,
+                        node_id=single_iteration_run.node_id,
+                        user_id=application_generate_entity.user_id
+                    )
+                else:
+                    runner.run(
+                        application_generate_entity=application_generate_entity,
+                        queue_manager=queue_manager
+                    )
             except GenerateTaskStoppedException:
                 pass
             except InvokeAuthorizationError:
diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py
index 9d854afe35..050319e552 100644
--- a/api/core/app/apps/workflow/app_runner.py
+++ b/api/core/app/apps/workflow/app_runner.py
@@ -73,11 +73,44 @@ class WorkflowAppRunner:
             user_from=UserFrom.ACCOUNT
             if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
             else UserFrom.END_USER,
+            invoke_from=application_generate_entity.invoke_from,
             user_inputs=inputs,
             system_inputs={
                 SystemVariable.FILES: files,
                 SystemVariable.USER_ID: user_id
             },
+            callbacks=workflow_callbacks,
+            call_depth=application_generate_entity.call_depth
+        )
+
+    def single_iteration_run(self, app_id: str, workflow_id: str,
+                             queue_manager: AppQueueManager,
+                             inputs: dict, node_id: str, user_id: str) -> None:
+        """
+        Single iteration run
+        """
+        app_record: App = db.session.query(App).filter(App.id == app_id).first()
+        if not app_record:
+            raise ValueError("App not found")
+
+        if not app_record.workflow_id:
+            raise ValueError("Workflow not initialized")
+
+        workflow = self.get_workflow(app_model=app_record, workflow_id=workflow_id)
+        if not workflow:
+            raise ValueError("Workflow not initialized")
+
+        workflow_callbacks = [WorkflowEventTriggerCallback(
+            queue_manager=queue_manager,
+            workflow=workflow
+        )]
+
+        workflow_engine_manager = WorkflowEngineManager()
+        workflow_engine_manager.single_step_run_iteration_workflow_node(
workflow=workflow, + node_id=node_id, + user_id=user_id, + user_inputs=inputs, callbacks=workflow_callbacks ) diff --git a/api/core/app/apps/workflow/generate_response_converter.py b/api/core/app/apps/workflow/generate_response_converter.py index d907b82c99..88bde58ba0 100644 --- a/api/core/app/apps/workflow/generate_response_converter.py +++ b/api/core/app/apps/workflow/generate_response_converter.py @@ -5,6 +5,8 @@ from typing import cast from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter from core.app.entities.task_entities import ( ErrorStreamResponse, + NodeFinishStreamResponse, + NodeStartStreamResponse, PingStreamResponse, WorkflowAppBlockingResponse, WorkflowAppStreamResponse, @@ -68,4 +70,24 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): :param stream_response: stream response :return: """ - return cls.convert_stream_full_response(stream_response) + for chunk in stream_response: + chunk = cast(WorkflowAppStreamResponse, chunk) + sub_stream_response = chunk.stream_response + + if isinstance(sub_stream_response, PingStreamResponse): + yield 'ping' + continue + + response_chunk = { + 'event': sub_stream_response.event.value, + 'workflow_run_id': chunk.workflow_run_id, + } + + if isinstance(sub_stream_response, ErrorStreamResponse): + data = cls._error_to_stream_response(sub_stream_response.err) + response_chunk.update(data) + elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse): + response_chunk.update(sub_stream_response.to_ignore_detail_dict()) + else: + response_chunk.update(sub_stream_response.to_dict()) + yield json.dumps(response_chunk) diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index a7061a77bb..8d961e0993 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -9,6 +9,9 @@ from core.app.entities.app_invoke_entities import ( ) from core.app.entities.queue_entities import ( QueueErrorEvent, + QueueIterationCompletedEvent, + QueueIterationNextEvent, + QueueIterationStartEvent, QueueMessageReplaceEvent, QueueNodeFailedEvent, QueueNodeStartedEvent, @@ -58,6 +61,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa _task_state: WorkflowTaskState _application_generate_entity: WorkflowAppGenerateEntity _workflow_system_variables: dict[SystemVariable, Any] + _iteration_nested_relations: dict[str, list[str]] def __init__(self, application_generate_entity: WorkflowAppGenerateEntity, workflow: Workflow, @@ -85,8 +89,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa SystemVariable.USER_ID: user_id } - self._task_state = WorkflowTaskState() + self._task_state = WorkflowTaskState( + iteration_nested_node_ids=[] + ) self._stream_generate_nodes = self._get_stream_generate_nodes() + self._iteration_nested_relations = self._get_iteration_nested_relations(self._workflow.graph_dict) def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]: """ @@ -191,6 +198,22 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution ) + + if isinstance(event, QueueNodeFailedEvent): + yield from self._handle_iteration_exception( + task_id=self._application_generate_entity.task_id, + error=f'Child node failed: 
{event.error}' + ) + elif isinstance(event, QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent): + if isinstance(event, QueueIterationNextEvent): + # clear ran node execution infos of current iteration + iteration_relations = self._iteration_nested_relations.get(event.node_id) + if iteration_relations: + for node_id in iteration_relations: + self._task_state.ran_node_execution_infos.pop(node_id, None) + + yield self._handle_iteration_to_stream_response(self._application_generate_entity.task_id, event) + self._handle_iteration_operation(event) elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent): workflow_run = self._handle_workflow_finished(event) @@ -331,13 +354,20 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa continue node_type = source_node.get('data', {}).get('type') + node_iteration_id = source_node.get('data', {}).get('iteration_id') + iteration_start_node_id = None + if node_iteration_id: + iteration_node = next((node for node in nodes if node.get('id') == node_iteration_id), None) + iteration_start_node_id = iteration_node.get('data', {}).get('start_node_id') + if node_type in [ NodeType.IF_ELSE.value, NodeType.QUESTION_CLASSIFIER.value ]: start_node_id = target_node_id start_node_ids.append(start_node_id) - elif node_type == NodeType.START.value: + elif node_type == NodeType.START.value or \ + node_iteration_id is not None and iteration_start_node_id == source_node.get('id'): start_node_id = source_node_id start_node_ids.append(start_node_id) else: @@ -411,3 +441,24 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa return False return True + + def _get_iteration_nested_relations(self, graph: dict) -> dict[str, list[str]]: + """ + Get iteration nested relations. 
+        :param graph: graph
+        :return:
+        """
+        nodes = graph.get('nodes')
+
+        iteration_ids = [node.get('id') for node in nodes
+                         if node.get('data', {}).get('type') in [
+                             NodeType.ITERATION.value,
+                             NodeType.LOOP.value,
+                         ]]
+
+        return {
+            iteration_id: [
+                node.get('id') for node in nodes if node.get('data', {}).get('iteration_id') == iteration_id
+            ] for iteration_id in iteration_ids
+        }
+
\ No newline at end of file
diff --git a/api/core/app/apps/workflow/workflow_event_trigger_callback.py b/api/core/app/apps/workflow/workflow_event_trigger_callback.py
index 2048abe464..e423a40bcb 100644
--- a/api/core/app/apps/workflow/workflow_event_trigger_callback.py
+++ b/api/core/app/apps/workflow/workflow_event_trigger_callback.py
@@ -1,8 +1,11 @@
-from typing import Optional
+from typing import Any, Optional
 
 from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
 from core.app.entities.queue_entities import (
     AppQueueEvent,
+    QueueIterationCompletedEvent,
+    QueueIterationNextEvent,
+    QueueIterationStartEvent,
     QueueNodeFailedEvent,
     QueueNodeStartedEvent,
     QueueNodeSucceededEvent,
@@ -130,6 +133,66 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback):
             ),
             PublishFrom.APPLICATION_MANAGER
         )
 
+    def on_workflow_iteration_started(self,
+                                      node_id: str,
+                                      node_type: NodeType,
+                                      node_run_index: int = 1,
+                                      node_data: Optional[BaseNodeData] = None,
+                                      inputs: dict = None,
+                                      predecessor_node_id: Optional[str] = None,
+                                      metadata: Optional[dict] = None) -> None:
+        """
+        Publish iteration started
+        """
+        self._queue_manager.publish(
+            QueueIterationStartEvent(
+                node_id=node_id,
+                node_type=node_type,
+                node_run_index=node_run_index,
+                node_data=node_data,
+                inputs=inputs,
+                predecessor_node_id=predecessor_node_id,
+                metadata=metadata
+            ),
+            PublishFrom.APPLICATION_MANAGER
+        )
+
+    def on_workflow_iteration_next(self, node_id: str,
+                                   node_type: NodeType,
+                                   index: int,
+                                   node_run_index: int,
+                                   output: Optional[Any]) -> None:
+        """
+        Publish iteration next
+        """
+        self._queue_manager.publish(
+            QueueIterationNextEvent(
+                node_id=node_id,
+                node_type=node_type,
+                index=index,
+                node_run_index=node_run_index,
+                output=output
+            ),
+            PublishFrom.APPLICATION_MANAGER
+        )
+
+    def on_workflow_iteration_completed(self, node_id: str,
+                                        node_type: NodeType,
+                                        node_run_index: int,
+                                        outputs: dict) -> None:
+        """
+        Publish iteration completed
+        """
+        self._queue_manager.publish(
+            QueueIterationCompletedEvent(
+                node_id=node_id,
+                node_type=node_type,
+                node_run_index=node_run_index,
+                outputs=outputs
+            ),
+            PublishFrom.APPLICATION_MANAGER
+        )
+
     def on_event(self, event: AppQueueEvent) -> None:
         """
         Publish event
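Both event trigger callbacks route every iteration signal through the queue manager's public publish() entry point. A rough sketch of that relationship, assuming the publish/_publish split used by AppQueueManager in core.app.apps.base_app_queue_manager (bodies simplified for illustration):

    class QueueManagerSketch:
        # public entry point: shared handling, then delegation
        def publish(self, event, pub_from):
            self._publish(event, pub_from)

        # subclass hook: actual delivery to the listening task pipeline
        def _publish(self, event, pub_from):
            raise NotImplementedError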
diff --git a/api/core/app/apps/workflow_logging_callback.py b/api/core/app/apps/workflow_logging_callback.py
index 4627c21c7a..f617c671e9 100644
--- a/api/core/app/apps/workflow_logging_callback.py
+++ b/api/core/app/apps/workflow_logging_callback.py
@@ -102,6 +102,39 @@ class WorkflowLoggingCallback(BaseWorkflowCallback):
 
             self.print_text(text, color="pink", end="")
 
+    def on_workflow_iteration_started(self,
+                                      node_id: str,
+                                      node_type: NodeType,
+                                      node_run_index: int = 1,
+                                      node_data: Optional[BaseNodeData] = None,
+                                      inputs: dict = None,
+                                      predecessor_node_id: Optional[str] = None,
+                                      metadata: Optional[dict] = None) -> None:
+        """
+        Print iteration started
+        """
+        self.print_text("\n[on_workflow_iteration_started]", color='blue')
+        self.print_text(f"Node ID: {node_id}", color='blue')
+
+    def on_workflow_iteration_next(self, node_id: str,
+                                   node_type: NodeType,
+                                   index: int,
+                                   node_run_index: int,
+                                   output: Optional[dict]) -> None:
+        """
+        Print iteration next
+        """
+        self.print_text("\n[on_workflow_iteration_next]", color='blue')
+
+    def on_workflow_iteration_completed(self, node_id: str,
+                                        node_type: NodeType,
+                                        node_run_index: int,
+                                        outputs: dict) -> None:
+        """
+        Print iteration completed
+        """
+        self.print_text("\n[on_workflow_iteration_completed]", color='blue')
+
     def on_event(self, event: AppQueueEvent) -> None:
         """
         Publish event
diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py
index 09c62c802c..cc63fa4684 100644
--- a/api/core/app/entities/app_invoke_entities.py
+++ b/api/core/app/entities/app_invoke_entities.py
@@ -80,6 +80,9 @@ class AppGenerateEntity(BaseModel):
     stream: bool
     invoke_from: InvokeFrom
 
+    # invoke call depth
+    call_depth: int = 0
+
     # extra parameters, like: auto_generate_conversation_name
     extras: dict[str, Any] = {}
 
@@ -126,6 +129,14 @@ class AdvancedChatAppGenerateEntity(AppGenerateEntity):
     conversation_id: Optional[str] = None
     query: Optional[str] = None
 
+    class SingleIterationRunEntity(BaseModel):
+        """
+        Single Iteration Run Entity.
+        """
+        node_id: str
+        inputs: dict
+
+    single_iteration_run: Optional[SingleIterationRunEntity] = None
 
 class WorkflowAppGenerateEntity(AppGenerateEntity):
     """
@@ -133,3 +144,12 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
     """
     # app config
     app_config: WorkflowUIBasedAppConfig
+
+    class SingleIterationRunEntity(BaseModel):
+        """
+        Single Iteration Run Entity.
+        """
+        node_id: str
+        inputs: dict
+
+    single_iteration_run: Optional[SingleIterationRunEntity] = None
\ No newline at end of file
diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py
index bf174e30e1..47fa2ac19d 100644
--- a/api/core/app/entities/queue_entities.py
+++ b/api/core/app/entities/queue_entities.py
@@ -1,7 +1,7 @@
 from enum import Enum
 from typing import Any, Optional
 
-from pydantic import BaseModel
+from pydantic import BaseModel, validator
 
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
 from core.workflow.entities.base_node_data_entities import BaseNodeData
@@ -21,6 +21,9 @@ class QueueEvent(Enum):
     WORKFLOW_STARTED = "workflow_started"
     WORKFLOW_SUCCEEDED = "workflow_succeeded"
     WORKFLOW_FAILED = "workflow_failed"
+    ITERATION_START = "iteration_start"
+    ITERATION_NEXT = "iteration_next"
+    ITERATION_COMPLETED = "iteration_completed"
     NODE_STARTED = "node_started"
     NODE_SUCCEEDED = "node_succeeded"
     NODE_FAILED = "node_failed"
@@ -47,6 +50,55 @@ class QueueLLMChunkEvent(AppQueueEvent):
     event = QueueEvent.LLM_CHUNK
     chunk: LLMResultChunk
 
+
+class QueueIterationStartEvent(AppQueueEvent):
+    """
+    QueueIterationStartEvent entity
+    """
+    event = QueueEvent.ITERATION_START
+    node_id: str
+    node_type: NodeType
+    node_data: BaseNodeData
+
+    node_run_index: int
+    inputs: dict = None
+    predecessor_node_id: Optional[str] = None
+    metadata: Optional[dict] = None
+
+
+class QueueIterationNextEvent(AppQueueEvent):
+    """
+    QueueIterationNextEvent entity
+    """
+    event = QueueEvent.ITERATION_NEXT
+
+    index: int
+    node_id: str
+    node_type: NodeType
+
+    node_run_index: int
+    output: Optional[Any]  # output for the current iteration
+
+    @validator('output', pre=True, always=True)
+    def set_output(cls, v):
+        """
+        Set output
+        """
+        if v is None:
+            return None
+        if isinstance(v, int | float | str | bool | dict | list):
+            return v
+        raise ValueError('output must be a valid type')
+
+
+class QueueIterationCompletedEvent(AppQueueEvent):
+    """
QueueIterationCompletedEvent entity + """ + event = QueueEvent.ITERATION_COMPLETED + + node_id: str + node_type: NodeType + + node_run_index: int + outputs: dict class QueueTextChunkEvent(AppQueueEvent): """ diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index 4994efe2e9..5956bc35fa 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -1,12 +1,14 @@ from enum import Enum -from typing import Optional +from typing import Any, Optional from pydantic import BaseModel from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage from core.model_runtime.utils.encoders import jsonable_encoder +from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeType from core.workflow.nodes.answer.entities import GenerateRouteChunk +from models.workflow import WorkflowNodeExecutionStatus class WorkflowStreamGenerateNodes(BaseModel): @@ -65,6 +67,7 @@ class WorkflowTaskState(TaskState): current_stream_generate_state: Optional[WorkflowStreamGenerateNodes] = None + iteration_nested_node_ids: list[str] = None class AdvancedChatTaskState(WorkflowTaskState): """ @@ -91,6 +94,9 @@ class StreamEvent(Enum): WORKFLOW_FINISHED = "workflow_finished" NODE_STARTED = "node_started" NODE_FINISHED = "node_finished" + ITERATION_STARTED = "iteration_started" + ITERATION_NEXT = "iteration_next" + ITERATION_COMPLETED = "iteration_completed" TEXT_CHUNK = "text_chunk" TEXT_REPLACE = "text_replace" @@ -246,6 +252,24 @@ class NodeStartStreamResponse(StreamResponse): workflow_run_id: str data: Data + def to_ignore_detail_dict(self): + return { + "event": self.event.value, + "task_id": self.task_id, + "workflow_run_id": self.workflow_run_id, + "data": { + "id": self.data.id, + "node_id": self.data.node_id, + "node_type": self.data.node_type, + "title": self.data.title, + "index": self.data.index, + "predecessor_node_id": self.data.predecessor_node_id, + "inputs": None, + "created_at": self.data.created_at, + "extras": {} + } + } + class NodeFinishStreamResponse(StreamResponse): """ @@ -276,6 +300,99 @@ class NodeFinishStreamResponse(StreamResponse): workflow_run_id: str data: Data + def to_ignore_detail_dict(self): + return { + "event": self.event.value, + "task_id": self.task_id, + "workflow_run_id": self.workflow_run_id, + "data": { + "id": self.data.id, + "node_id": self.data.node_id, + "node_type": self.data.node_type, + "title": self.data.title, + "index": self.data.index, + "predecessor_node_id": self.data.predecessor_node_id, + "inputs": None, + "process_data": None, + "outputs": None, + "status": self.data.status, + "error": None, + "elapsed_time": self.data.elapsed_time, + "execution_metadata": None, + "created_at": self.data.created_at, + "finished_at": self.data.finished_at, + "files": [] + } + } + +class IterationNodeStartStreamResponse(StreamResponse): + """ + NodeStartStreamResponse entity + """ + class Data(BaseModel): + """ + Data entity + """ + id: str + node_id: str + node_type: str + title: str + created_at: int + extras: dict = {} + metadata: dict = {} + inputs: dict = {} + + event: StreamEvent = StreamEvent.ITERATION_STARTED + workflow_run_id: str + data: Data + +class IterationNodeNextStreamResponse(StreamResponse): + """ + NodeStartStreamResponse entity + """ + class Data(BaseModel): + """ + Data entity + """ + id: str + node_id: str + node_type: str + title: str + index: int + created_at: int + pre_iteration_output: Optional[Any] + extras: 
dict = {} + + event: StreamEvent = StreamEvent.ITERATION_NEXT + workflow_run_id: str + data: Data + +class IterationNodeCompletedStreamResponse(StreamResponse): + """ + NodeStartStreamResponse entity + """ + class Data(BaseModel): + """ + Data entity + """ + id: str + node_id: str + node_type: str + title: str + outputs: Optional[dict] + created_at: int + extras: dict = None + inputs: dict = None + status: WorkflowNodeExecutionStatus + error: Optional[str] + elapsed_time: float + total_tokens: int + finished_at: int + steps: int + + event: StreamEvent = StreamEvent.ITERATION_COMPLETED + workflow_run_id: str + data: Data class TextChunkStreamResponse(StreamResponse): """ @@ -411,3 +528,23 @@ class WorkflowAppBlockingResponse(AppBlockingResponse): workflow_run_id: str data: Data + +class WorkflowIterationState(BaseModel): + """ + WorkflowIterationState entity + """ + class Data(BaseModel): + """ + Data entity + """ + parent_iteration_id: Optional[str] = None + iteration_id: str + current_index: int + iteration_steps_boundary: list[int] = None + node_execution_id: str + started_at: float + inputs: dict = None + total_tokens: int = 0 + node_data: BaseNodeData + + current_iterations: dict[str, Data] = None \ No newline at end of file diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index a7dbb4754c..f71470edb2 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -37,6 +37,7 @@ from core.app.entities.task_entities import ( ) from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline from core.app.task_pipeline.message_cycle_manage import MessageCycleManage +from core.model_manager import ModelInstance from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, @@ -317,29 +318,30 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan """ model_config = self._model_config model = model_config.model - model_type_instance = model_config.provider_model_bundle.model_type_instance - model_type_instance = cast(LargeLanguageModel, model_type_instance) + + model_instance = ModelInstance( + provider_model_bundle=model_config.provider_model_bundle, + model=model_config.model + ) # calculate num tokens prompt_tokens = 0 if event.stopped_by != QueueStopEvent.StopBy.ANNOTATION_REPLY: - prompt_tokens = model_type_instance.get_num_tokens( - model, - model_config.credentials, + prompt_tokens = model_instance.get_llm_num_tokens( self._task_state.llm_result.prompt_messages ) completion_tokens = 0 if event.stopped_by == QueueStopEvent.StopBy.USER_MANUAL: - completion_tokens = model_type_instance.get_num_tokens( - model, - model_config.credentials, + completion_tokens = model_instance.get_llm_num_tokens( [self._task_state.llm_result.message] ) credentials = model_config.credentials # transform usage + model_type_instance = model_config.provider_model_bundle.model_type_instance + model_type_instance = cast(LargeLanguageModel, model_type_instance) self._task_state.llm_result.usage = model_type_instance._calc_response_usage( model, credentials, diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py index 48ff34fef9..978a318279 100644 --- 
a/api/core/app/task_pipeline/workflow_cycle_manage.py +++ b/api/core/app/task_pipeline/workflow_cycle_manage.py @@ -1,9 +1,9 @@ import json import time from datetime import datetime, timezone -from typing import Any, Optional, Union, cast +from typing import Optional, Union, cast -from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity +from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import ( QueueNodeFailedEvent, QueueNodeStartedEvent, @@ -13,18 +13,17 @@ from core.app.entities.queue_entities import ( QueueWorkflowSucceededEvent, ) from core.app.entities.task_entities import ( - AdvancedChatTaskState, NodeExecutionInfo, NodeFinishStreamResponse, NodeStartStreamResponse, WorkflowFinishStreamResponse, WorkflowStartStreamResponse, - WorkflowTaskState, ) +from core.app.task_pipeline.workflow_iteration_cycle_manage import WorkflowIterationCycleManage from core.file.file_obj import FileVar from core.model_runtime.utils.encoders import jsonable_encoder from core.tools.tool_manager import ToolManager -from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType, SystemVariable +from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType from core.workflow.nodes.tool.entities import ToolNodeData from core.workflow.workflow_engine_manager import WorkflowEngineManager from extensions.ext_database import db @@ -42,13 +41,7 @@ from models.workflow import ( ) -class WorkflowCycleManage: - _application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity] - _workflow: Workflow - _user: Union[Account, EndUser] - _task_state: Union[AdvancedChatTaskState, WorkflowTaskState] - _workflow_system_variables: dict[SystemVariable, Any] - +class WorkflowCycleManage(WorkflowIterationCycleManage): def _init_workflow_run(self, workflow: Workflow, triggered_from: WorkflowRunTriggeredFrom, user: Union[Account, EndUser], @@ -237,6 +230,7 @@ class WorkflowCycleManage: inputs: Optional[dict] = None, process_data: Optional[dict] = None, outputs: Optional[dict] = None, + execution_metadata: Optional[dict] = None ) -> WorkflowNodeExecution: """ Workflow node execution failed @@ -255,6 +249,8 @@ class WorkflowCycleManage: workflow_node_execution.inputs = json.dumps(inputs) if inputs else None workflow_node_execution.process_data = json.dumps(process_data) if process_data else None workflow_node_execution.outputs = json.dumps(outputs) if outputs else None + workflow_node_execution.execution_metadata = json.dumps(jsonable_encoder(execution_metadata)) \ + if execution_metadata else None db.session.commit() db.session.refresh(workflow_node_execution) @@ -444,6 +440,23 @@ class WorkflowCycleManage: current_node_execution = self._task_state.ran_node_execution_infos[event.node_id] workflow_node_execution = db.session.query(WorkflowNodeExecution).filter( WorkflowNodeExecution.id == current_node_execution.workflow_node_execution_id).first() + + execution_metadata = event.execution_metadata if isinstance(event, QueueNodeSucceededEvent) else None + + if self._iteration_state and self._iteration_state.current_iterations: + if not execution_metadata: + execution_metadata = {} + current_iteration_data = None + for iteration_node_id in self._iteration_state.current_iterations: + data = self._iteration_state.current_iterations[iteration_node_id] + if data.parent_iteration_id == None: + current_iteration_data = data + break + + if current_iteration_data: + 
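+                    # While an iteration is open, every finishing node execution is stamped
+                    # with its owning iteration and round, so the run log can group nested
+                    # nodes per round. Illustrative result (keys from NodeRunMetadataKey;
+                    # values assumed for illustration):
+                    #   execution_metadata == {..., 'iteration_id': 'iter-node-1', 'iteration_index': 2}
+                    # The owner chosen is the outermost open iteration (parent_iteration_id is None).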
execution_metadata[NodeRunMetadataKey.ITERATION_ID] = current_iteration_data.iteration_id + execution_metadata[NodeRunMetadataKey.ITERATION_INDEX] = current_iteration_data.current_index + if isinstance(event, QueueNodeSucceededEvent): workflow_node_execution = self._workflow_node_execution_success( workflow_node_execution=workflow_node_execution, @@ -451,12 +464,18 @@ class WorkflowCycleManage: inputs=event.inputs, process_data=event.process_data, outputs=event.outputs, - execution_metadata=event.execution_metadata + execution_metadata=execution_metadata ) - if event.execution_metadata and event.execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS): + if execution_metadata and execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS): self._task_state.total_tokens += ( - int(event.execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS))) + int(execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS))) + + if self._iteration_state: + for iteration_node_id in self._iteration_state.current_iterations: + data = self._iteration_state.current_iterations[iteration_node_id] + if execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS): + data.total_tokens += int(execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS)) if workflow_node_execution.node_type == NodeType.LLM.value: outputs = workflow_node_execution.outputs_dict @@ -469,7 +488,8 @@ class WorkflowCycleManage: error=event.error, inputs=event.inputs, process_data=event.process_data, - outputs=event.outputs + outputs=event.outputs, + execution_metadata=execution_metadata ) db.session.close() diff --git a/api/core/app/task_pipeline/workflow_cycle_state_manager.py b/api/core/app/task_pipeline/workflow_cycle_state_manager.py new file mode 100644 index 0000000000..545f31fddf --- /dev/null +++ b/api/core/app/task_pipeline/workflow_cycle_state_manager.py @@ -0,0 +1,16 @@ +from typing import Any, Union + +from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity +from core.app.entities.task_entities import AdvancedChatTaskState, WorkflowTaskState +from core.workflow.entities.node_entities import SystemVariable +from models.account import Account +from models.model import EndUser +from models.workflow import Workflow + + +class WorkflowCycleStateManager: + _application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity] + _workflow: Workflow + _user: Union[Account, EndUser] + _task_state: Union[AdvancedChatTaskState, WorkflowTaskState] + _workflow_system_variables: dict[SystemVariable, Any] \ No newline at end of file diff --git a/api/core/app/task_pipeline/workflow_iteration_cycle_manage.py b/api/core/app/task_pipeline/workflow_iteration_cycle_manage.py new file mode 100644 index 0000000000..55e3e03173 --- /dev/null +++ b/api/core/app/task_pipeline/workflow_iteration_cycle_manage.py @@ -0,0 +1,281 @@ +import json +import time +from collections.abc import Generator +from typing import Optional, Union + +from core.app.entities.queue_entities import ( + QueueIterationCompletedEvent, + QueueIterationNextEvent, + QueueIterationStartEvent, +) +from core.app.entities.task_entities import ( + IterationNodeCompletedStreamResponse, + IterationNodeNextStreamResponse, + IterationNodeStartStreamResponse, + NodeExecutionInfo, + WorkflowIterationState, +) +from core.app.task_pipeline.workflow_cycle_state_manager import WorkflowCycleStateManager +from core.workflow.entities.node_entities import NodeType +from extensions.ext_database import db +from models.workflow import ( + 
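+# WorkflowIterationCycleManage (below) keeps one state entry per running iteration
+# node, keyed by the iteration node's id — roughly (a sketch, field values illustrative):
+#   current_iterations = {'iter-node-1': WorkflowIterationState.Data(
+#       iteration_id='iter-node-1', current_index=0, iteration_steps_boundary=[],
+#       total_tokens=0, ...)}
+# and converts the Queue* iteration events into the Iteration*StreamResponse
+# entities defined in task_entities.py.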
WorkflowNodeExecution, + WorkflowNodeExecutionStatus, + WorkflowNodeExecutionTriggeredFrom, + WorkflowRun, +) + + +class WorkflowIterationCycleManage(WorkflowCycleStateManager): + _iteration_state: WorkflowIterationState = None + + def _init_iteration_state(self) -> WorkflowIterationState: + if not self._iteration_state: + self._iteration_state = WorkflowIterationState( + current_iterations={} + ) + + def _handle_iteration_to_stream_response(self, task_id: str, event: QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent) \ + -> Union[IterationNodeStartStreamResponse, IterationNodeNextStreamResponse, IterationNodeCompletedStreamResponse]: + """ + Handle iteration to stream response + :param task_id: task id + :param event: iteration event + :return: + """ + if isinstance(event, QueueIterationStartEvent): + return IterationNodeStartStreamResponse( + task_id=task_id, + workflow_run_id=self._task_state.workflow_run_id, + data=IterationNodeStartStreamResponse.Data( + id=event.node_id, + node_id=event.node_id, + node_type=event.node_type.value, + title=event.node_data.title, + created_at=int(time.time()), + extras={}, + inputs=event.inputs, + metadata=event.metadata + ) + ) + elif isinstance(event, QueueIterationNextEvent): + current_iteration = self._iteration_state.current_iterations[event.node_id] + + return IterationNodeNextStreamResponse( + task_id=task_id, + workflow_run_id=self._task_state.workflow_run_id, + data=IterationNodeNextStreamResponse.Data( + id=event.node_id, + node_id=event.node_id, + node_type=event.node_type.value, + title=current_iteration.node_data.title, + index=event.index, + pre_iteration_output=event.output, + created_at=int(time.time()), + extras={} + ) + ) + elif isinstance(event, QueueIterationCompletedEvent): + current_iteration = self._iteration_state.current_iterations[event.node_id] + + return IterationNodeCompletedStreamResponse( + task_id=task_id, + workflow_run_id=self._task_state.workflow_run_id, + data=IterationNodeCompletedStreamResponse.Data( + id=event.node_id, + node_id=event.node_id, + node_type=event.node_type.value, + title=current_iteration.node_data.title, + outputs=event.outputs, + created_at=int(time.time()), + extras={}, + inputs=current_iteration.inputs, + status=WorkflowNodeExecutionStatus.SUCCEEDED, + error=None, + elapsed_time=time.perf_counter() - current_iteration.started_at, + total_tokens=current_iteration.total_tokens, + finished_at=int(time.time()), + steps=current_iteration.current_index + ) + ) + + def _init_iteration_execution_from_workflow_run(self, + workflow_run: WorkflowRun, + node_id: str, + node_type: NodeType, + node_title: str, + node_run_index: int = 1, + inputs: Optional[dict] = None, + predecessor_node_id: Optional[str] = None + ) -> WorkflowNodeExecution: + workflow_node_execution = WorkflowNodeExecution( + tenant_id=workflow_run.tenant_id, + app_id=workflow_run.app_id, + workflow_id=workflow_run.workflow_id, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, + workflow_run_id=workflow_run.id, + predecessor_node_id=predecessor_node_id, + index=node_run_index, + node_id=node_id, + node_type=node_type.value, + inputs=json.dumps(inputs) if inputs else None, + title=node_title, + status=WorkflowNodeExecutionStatus.RUNNING.value, + created_by_role=workflow_run.created_by_role, + created_by=workflow_run.created_by, + execution_metadata=json.dumps({ + 'started_run_index': node_run_index + 1, + 'current_index': 0, + 'steps_boundary': [], + }) + ) + + 
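+        # Bookkeeping kept in the iteration row's execution_metadata (as maintained
+        # by _handle_iteration_next / _handle_iteration_completed below):
+        #   started_run_index — node_run_index + 1, the run index the first nested
+        #                       node gets (the iteration node itself holds node_run_index)
+        #   current_index     — the round currently executing, advanced on ITERATION_NEXT
+        #   steps_boundary    — node_run_index captured at each ITERATION_NEXT, which
+        #                       presumably lets a trace viewer slice nested executions
+        #                       into per-round groups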
db.session.add(workflow_node_execution) + db.session.commit() + db.session.refresh(workflow_node_execution) + db.session.close() + + return workflow_node_execution + + def _handle_iteration_operation(self, event: QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent) -> WorkflowNodeExecution: + if isinstance(event, QueueIterationStartEvent): + return self._handle_iteration_started(event) + elif isinstance(event, QueueIterationNextEvent): + return self._handle_iteration_next(event) + elif isinstance(event, QueueIterationCompletedEvent): + return self._handle_iteration_completed(event) + + def _handle_iteration_started(self, event: QueueIterationStartEvent) -> WorkflowNodeExecution: + self._init_iteration_state() + + workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self._task_state.workflow_run_id).first() + workflow_node_execution = self._init_iteration_execution_from_workflow_run( + workflow_run=workflow_run, + node_id=event.node_id, + node_type=NodeType.ITERATION, + node_title=event.node_data.title, + node_run_index=event.node_run_index, + inputs=event.inputs, + predecessor_node_id=event.predecessor_node_id + ) + + latest_node_execution_info = NodeExecutionInfo( + workflow_node_execution_id=workflow_node_execution.id, + node_type=NodeType.ITERATION, + start_at=time.perf_counter() + ) + + self._task_state.ran_node_execution_infos[event.node_id] = latest_node_execution_info + self._task_state.latest_node_execution_info = latest_node_execution_info + + self._iteration_state.current_iterations[event.node_id] = WorkflowIterationState.Data( + parent_iteration_id=None, + iteration_id=event.node_id, + current_index=0, + iteration_steps_boundary=[], + node_execution_id=workflow_node_execution.id, + started_at=time.perf_counter(), + inputs=event.inputs, + total_tokens=0, + node_data=event.node_data + ) + + db.session.close() + + return workflow_node_execution + + def _handle_iteration_next(self, event: QueueIterationNextEvent) -> WorkflowNodeExecution: + if event.node_id not in self._iteration_state.current_iterations: + return + current_iteration = self._iteration_state.current_iterations[event.node_id] + current_iteration.current_index = event.index + current_iteration.iteration_steps_boundary.append(event.node_run_index) + workflow_node_execution: WorkflowNodeExecution = db.session.query(WorkflowNodeExecution).filter( + WorkflowNodeExecution.id == current_iteration.node_execution_id + ).first() + + original_node_execution_metadata = workflow_node_execution.execution_metadata_dict + if original_node_execution_metadata: + original_node_execution_metadata['current_index'] = event.index + original_node_execution_metadata['steps_boundary'] = current_iteration.iteration_steps_boundary + original_node_execution_metadata['total_tokens'] = current_iteration.total_tokens + workflow_node_execution.execution_metadata = json.dumps(original_node_execution_metadata) + + db.session.commit() + + db.session.close() + + def _handle_iteration_completed(self, event: QueueIterationCompletedEvent) -> WorkflowNodeExecution: + if event.node_id not in self._iteration_state.current_iterations: + return + + current_iteration = self._iteration_state.current_iterations[event.node_id] + workflow_node_execution: WorkflowNodeExecution = db.session.query(WorkflowNodeExecution).filter( + WorkflowNodeExecution.id == current_iteration.node_execution_id + ).first() + + workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value + workflow_node_execution.outputs = 
json.dumps(event.outputs) if event.outputs else None + workflow_node_execution.elapsed_time = time.perf_counter() - current_iteration.started_at + + original_node_execution_metadata = workflow_node_execution.execution_metadata_dict + if original_node_execution_metadata: + original_node_execution_metadata['steps_boundary'] = current_iteration.iteration_steps_boundary + original_node_execution_metadata['total_tokens'] = current_iteration.total_tokens + workflow_node_execution.execution_metadata = json.dumps(original_node_execution_metadata) + + db.session.commit() + + # remove current iteration + self._iteration_state.current_iterations.pop(event.node_id, None) + + # set latest node execution info + latest_node_execution_info = NodeExecutionInfo( + workflow_node_execution_id=workflow_node_execution.id, + node_type=NodeType.ITERATION, + start_at=time.perf_counter() + ) + + self._task_state.latest_node_execution_info = latest_node_execution_info + + db.session.close() + + def _handle_iteration_exception(self, task_id: str, error: str) -> Generator[IterationNodeCompletedStreamResponse, None, None]: + """ + Handle iteration exception + """ + if not self._iteration_state or not self._iteration_state.current_iterations: + return + + for node_id, current_iteration in self._iteration_state.current_iterations.items(): + workflow_node_execution: WorkflowNodeExecution = db.session.query(WorkflowNodeExecution).filter( + WorkflowNodeExecution.id == current_iteration.node_execution_id + ).first() + + workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value + workflow_node_execution.error = error + workflow_node_execution.elapsed_time = time.perf_counter() - current_iteration.started_at + + db.session.commit() + db.session.close() + + yield IterationNodeCompletedStreamResponse( + task_id=task_id, + workflow_run_id=self._task_state.workflow_run_id, + data=IterationNodeCompletedStreamResponse.Data( + id=node_id, + node_id=node_id, + node_type=NodeType.ITERATION.value, + title=current_iteration.node_data.title, + outputs={}, + created_at=int(time.time()), + extras={}, + inputs=current_iteration.inputs, + status=WorkflowNodeExecutionStatus.FAILED, + error=error, + elapsed_time=time.perf_counter() - current_iteration.started_at, + total_tokens=current_iteration.total_tokens, + finished_at=int(time.time()), + steps=current_iteration.current_index + ) + ) \ No newline at end of file diff --git a/api/core/entities/model_entities.py b/api/core/entities/model_entities.py index 05719e5b8d..9a797c1c95 100644 --- a/api/core/entities/model_entities.py +++ b/api/core/entities/model_entities.py @@ -16,6 +16,7 @@ class ModelStatus(Enum): NO_CONFIGURE = "no-configure" QUOTA_EXCEEDED = "quota-exceeded" NO_PERMISSION = "no-permission" + DISABLED = "disabled" class SimpleModelProviderEntity(BaseModel): @@ -43,12 +44,19 @@ class SimpleModelProviderEntity(BaseModel): ) -class ModelWithProviderEntity(ProviderModel): +class ProviderModelWithStatusEntity(ProviderModel): + """ + Model class for model response. + """ + status: ModelStatus + load_balancing_enabled: bool = False + + +class ModelWithProviderEntity(ProviderModelWithStatusEntity): """ Model with provider entity. 
""" provider: SimpleModelProviderEntity - status: ModelStatus class DefaultModelProviderEntity(BaseModel): diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index 303034693d..ec1b2d0d48 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -1,6 +1,7 @@ import datetime import json import logging +from collections import defaultdict from collections.abc import Iterator from json import JSONDecodeError from typing import Optional @@ -8,7 +9,12 @@ from typing import Optional from pydantic import BaseModel from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity -from core.entities.provider_entities import CustomConfiguration, SystemConfiguration, SystemConfigurationStatus +from core.entities.provider_entities import ( + CustomConfiguration, + ModelSettings, + SystemConfiguration, + SystemConfigurationStatus, +) from core.helper import encrypter from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType from core.model_runtime.entities.model_entities import FetchFrom, ModelType @@ -22,7 +28,14 @@ from core.model_runtime.model_providers import model_provider_factory from core.model_runtime.model_providers.__base.ai_model import AIModel from core.model_runtime.model_providers.__base.model_provider import ModelProvider from extensions.ext_database import db -from models.provider import Provider, ProviderModel, ProviderType, TenantPreferredModelProvider +from models.provider import ( + LoadBalancingModelConfig, + Provider, + ProviderModel, + ProviderModelSetting, + ProviderType, + TenantPreferredModelProvider, +) logger = logging.getLogger(__name__) @@ -39,6 +52,7 @@ class ProviderConfiguration(BaseModel): using_provider_type: ProviderType system_configuration: SystemConfiguration custom_configuration: CustomConfiguration + model_settings: list[ModelSettings] def __init__(self, **data): super().__init__(**data) @@ -62,6 +76,14 @@ class ProviderConfiguration(BaseModel): :param model: model name :return: """ + if self.model_settings: + # check if model is disabled by admin + for model_setting in self.model_settings: + if (model_setting.model_type == model_type + and model_setting.model == model): + if not model_setting.enabled: + raise ValueError(f'Model {model} is disabled.') + if self.using_provider_type == ProviderType.SYSTEM: restrict_models = [] for quota_configuration in self.system_configuration.quota_configurations: @@ -80,15 +102,17 @@ class ProviderConfiguration(BaseModel): return copy_credentials else: + credentials = None if self.custom_configuration.models: for model_configuration in self.custom_configuration.models: if model_configuration.model_type == model_type and model_configuration.model == model: - return model_configuration.credentials + credentials = model_configuration.credentials + break if self.custom_configuration.provider: - return self.custom_configuration.provider.credentials - else: - return None + credentials = self.custom_configuration.provider.credentials + + return credentials def get_system_configuration_status(self) -> SystemConfigurationStatus: """ @@ -130,7 +154,7 @@ class ProviderConfiguration(BaseModel): return credentials # Obfuscate credentials - return self._obfuscated_credentials( + return self.obfuscated_credentials( credentials=credentials, credential_form_schemas=self.provider.provider_credential_schema.credential_form_schemas if 
self.provider.provider_credential_schema else [] @@ -151,7 +175,7 @@ class ProviderConfiguration(BaseModel): ).first() # Get provider credential secret variables - provider_credential_secret_variables = self._extract_secret_variables( + provider_credential_secret_variables = self.extract_secret_variables( self.provider.provider_credential_schema.credential_form_schemas if self.provider.provider_credential_schema else [] ) @@ -274,7 +298,7 @@ class ProviderConfiguration(BaseModel): return credentials # Obfuscate credentials - return self._obfuscated_credentials( + return self.obfuscated_credentials( credentials=credentials, credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas if self.provider.model_credential_schema else [] @@ -302,7 +326,7 @@ class ProviderConfiguration(BaseModel): ).first() # Get provider credential secret variables - provider_credential_secret_variables = self._extract_secret_variables( + provider_credential_secret_variables = self.extract_secret_variables( self.provider.model_credential_schema.credential_form_schemas if self.provider.model_credential_schema else [] ) @@ -402,6 +426,160 @@ class ProviderConfiguration(BaseModel): provider_model_credentials_cache.delete() + def enable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting: + """ + Enable model. + :param model_type: model type + :param model: model name + :return: + """ + model_setting = db.session.query(ProviderModelSetting) \ + .filter( + ProviderModelSetting.tenant_id == self.tenant_id, + ProviderModelSetting.provider_name == self.provider.provider, + ProviderModelSetting.model_type == model_type.to_origin_model_type(), + ProviderModelSetting.model_name == model + ).first() + + if model_setting: + model_setting.enabled = True + model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + db.session.commit() + else: + model_setting = ProviderModelSetting( + tenant_id=self.tenant_id, + provider_name=self.provider.provider, + model_type=model_type.to_origin_model_type(), + model_name=model, + enabled=True + ) + db.session.add(model_setting) + db.session.commit() + + return model_setting + + def disable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting: + """ + Disable model. + :param model_type: model type + :param model: model name + :return: + """ + model_setting = db.session.query(ProviderModelSetting) \ + .filter( + ProviderModelSetting.tenant_id == self.tenant_id, + ProviderModelSetting.provider_name == self.provider.provider, + ProviderModelSetting.model_type == model_type.to_origin_model_type(), + ProviderModelSetting.model_name == model + ).first() + + if model_setting: + model_setting.enabled = False + model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + db.session.commit() + else: + model_setting = ProviderModelSetting( + tenant_id=self.tenant_id, + provider_name=self.provider.provider, + model_type=model_type.to_origin_model_type(), + model_name=model, + enabled=False + ) + db.session.add(model_setting) + db.session.commit() + + return model_setting + + def get_provider_model_setting(self, model_type: ModelType, model: str) -> Optional[ProviderModelSetting]: + """ + Get provider model setting. 
+ :param model_type: model type + :param model: model name + :return: + """ + return db.session.query(ProviderModelSetting) \ + .filter( + ProviderModelSetting.tenant_id == self.tenant_id, + ProviderModelSetting.provider_name == self.provider.provider, + ProviderModelSetting.model_type == model_type.to_origin_model_type(), + ProviderModelSetting.model_name == model + ).first() + + def enable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting: + """ + Enable model load balancing. + :param model_type: model type + :param model: model name + :return: + """ + load_balancing_config_count = db.session.query(LoadBalancingModelConfig) \ + .filter( + LoadBalancingModelConfig.tenant_id == self.tenant_id, + LoadBalancingModelConfig.provider_name == self.provider.provider, + LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), + LoadBalancingModelConfig.model_name == model + ).count() + + if load_balancing_config_count <= 1: + raise ValueError('Model load balancing configuration must be more than 1.') + + model_setting = db.session.query(ProviderModelSetting) \ + .filter( + ProviderModelSetting.tenant_id == self.tenant_id, + ProviderModelSetting.provider_name == self.provider.provider, + ProviderModelSetting.model_type == model_type.to_origin_model_type(), + ProviderModelSetting.model_name == model + ).first() + + if model_setting: + model_setting.load_balancing_enabled = True + model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + db.session.commit() + else: + model_setting = ProviderModelSetting( + tenant_id=self.tenant_id, + provider_name=self.provider.provider, + model_type=model_type.to_origin_model_type(), + model_name=model, + load_balancing_enabled=True + ) + db.session.add(model_setting) + db.session.commit() + + return model_setting + + def disable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting: + """ + Disable model load balancing. + :param model_type: model type + :param model: model name + :return: + """ + model_setting = db.session.query(ProviderModelSetting) \ + .filter( + ProviderModelSetting.tenant_id == self.tenant_id, + ProviderModelSetting.provider_name == self.provider.provider, + ProviderModelSetting.model_type == model_type.to_origin_model_type(), + ProviderModelSetting.model_name == model + ).first() + + if model_setting: + model_setting.load_balancing_enabled = False + model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + db.session.commit() + else: + model_setting = ProviderModelSetting( + tenant_id=self.tenant_id, + provider_name=self.provider.provider, + model_type=model_type.to_origin_model_type(), + model_name=model, + load_balancing_enabled=False + ) + db.session.add(model_setting) + db.session.commit() + + return model_setting + def get_provider_instance(self) -> ModelProvider: """ Get provider instance. @@ -453,7 +631,7 @@ class ProviderConfiguration(BaseModel): db.session.commit() - def _extract_secret_variables(self, credential_form_schemas: list[CredentialFormSchema]) -> list[str]: + def extract_secret_variables(self, credential_form_schemas: list[CredentialFormSchema]) -> list[str]: """ Extract secret input form variables. 
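Note: the two enable/disable pairs above (enable_model / disable_model and the load-balancing toggles) repeat one find-or-create pattern on ProviderModelSetting. A hypothetical consolidation — _upsert_model_setting is not part of this PR, and it reuses only the session calls already present in this file — might look like:

    def _upsert_model_setting(self, model_type: ModelType, model: str,
                              **changes) -> ProviderModelSetting:
        # Find the per-tenant setting row for this provider/model pair ...
        model_setting = db.session.query(ProviderModelSetting) \
            .filter(
                ProviderModelSetting.tenant_id == self.tenant_id,
                ProviderModelSetting.provider_name == self.provider.provider,
                ProviderModelSetting.model_type == model_type.to_origin_model_type(),
                ProviderModelSetting.model_name == model
            ).first()

        if model_setting:
            # ... update the changed flags in place ...
            for field, value in changes.items():
                setattr(model_setting, field, value)
            model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
        else:
            # ... or create the row on first use.
            model_setting = ProviderModelSetting(
                tenant_id=self.tenant_id,
                provider_name=self.provider.provider,
                model_type=model_type.to_origin_model_type(),
                model_name=model,
                **changes
            )
            db.session.add(model_setting)

        db.session.commit()
        return model_setting

With such a helper, enable_model(model_type, model) would reduce to self._upsert_model_setting(model_type, model, enabled=True), and the load-balancing toggles to load_balancing_enabled=True / False.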
@@ -467,7 +645,7 @@ class ProviderConfiguration(BaseModel): return secret_input_form_variables - def _obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]) -> dict: + def obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]) -> dict: """ Obfuscated credentials. @@ -476,7 +654,7 @@ class ProviderConfiguration(BaseModel): :return: """ # Get provider credential secret variables - credential_secret_variables = self._extract_secret_variables( + credential_secret_variables = self.extract_secret_variables( credential_form_schemas ) @@ -522,15 +700,22 @@ class ProviderConfiguration(BaseModel): else: model_types = provider_instance.get_provider_schema().supported_model_types + # Group model settings by model type and model + model_setting_map = defaultdict(dict) + for model_setting in self.model_settings: + model_setting_map[model_setting.model_type][model_setting.model] = model_setting + if self.using_provider_type == ProviderType.SYSTEM: provider_models = self._get_system_provider_models( model_types=model_types, - provider_instance=provider_instance + provider_instance=provider_instance, + model_setting_map=model_setting_map ) else: provider_models = self._get_custom_provider_models( model_types=model_types, - provider_instance=provider_instance + provider_instance=provider_instance, + model_setting_map=model_setting_map ) if only_active: @@ -541,18 +726,27 @@ class ProviderConfiguration(BaseModel): def _get_system_provider_models(self, model_types: list[ModelType], - provider_instance: ModelProvider) -> list[ModelWithProviderEntity]: + provider_instance: ModelProvider, + model_setting_map: dict[ModelType, dict[str, ModelSettings]]) \ + -> list[ModelWithProviderEntity]: """ Get system provider models. 
:param model_types: model types :param provider_instance: provider instance + :param model_setting_map: model setting map :return: """ provider_models = [] for model_type in model_types: - provider_models.extend( - [ + for m in provider_instance.models(model_type): + status = ModelStatus.ACTIVE + if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]: + model_setting = model_setting_map[m.model_type][m.model] + if model_setting.enabled is False: + status = ModelStatus.DISABLED + + provider_models.append( ModelWithProviderEntity( model=m.model, label=m.label, @@ -562,11 +756,9 @@ class ProviderConfiguration(BaseModel): model_properties=m.model_properties, deprecated=m.deprecated, provider=SimpleModelProviderEntity(self.provider), - status=ModelStatus.ACTIVE + status=status ) - for m in provider_instance.models(model_type) - ] - ) + ) if self.provider.provider not in original_provider_configurate_methods: original_provider_configurate_methods[self.provider.provider] = [] @@ -586,7 +778,8 @@ class ProviderConfiguration(BaseModel): break if should_use_custom_model: - if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]: + if original_provider_configurate_methods[self.provider.provider] == [ + ConfigurateMethod.CUSTOMIZABLE_MODEL]: # only customizable model for restrict_model in restrict_models: copy_credentials = self.system_configuration.credentials.copy() @@ -611,6 +804,13 @@ class ProviderConfiguration(BaseModel): if custom_model_schema.model_type not in model_types: continue + status = ModelStatus.ACTIVE + if (custom_model_schema.model_type in model_setting_map + and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]): + model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model] + if model_setting.enabled is False: + status = ModelStatus.DISABLED + provider_models.append( ModelWithProviderEntity( model=custom_model_schema.model, @@ -621,7 +821,7 @@ class ProviderConfiguration(BaseModel): model_properties=custom_model_schema.model_properties, deprecated=custom_model_schema.deprecated, provider=SimpleModelProviderEntity(self.provider), - status=ModelStatus.ACTIVE + status=status ) ) @@ -632,16 +832,20 @@ class ProviderConfiguration(BaseModel): m.status = ModelStatus.NO_PERMISSION elif not quota_configuration.is_valid: m.status = ModelStatus.QUOTA_EXCEEDED + return provider_models def _get_custom_provider_models(self, model_types: list[ModelType], - provider_instance: ModelProvider) -> list[ModelWithProviderEntity]: + provider_instance: ModelProvider, + model_setting_map: dict[ModelType, dict[str, ModelSettings]]) \ + -> list[ModelWithProviderEntity]: """ Get custom provider models. 
:param model_types: model types :param provider_instance: provider instance + :param model_setting_map: model setting map :return: """ provider_models = [] @@ -656,6 +860,16 @@ class ProviderConfiguration(BaseModel): models = provider_instance.models(model_type) for m in models: + status = ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE + load_balancing_enabled = False + if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]: + model_setting = model_setting_map[m.model_type][m.model] + if model_setting.enabled is False: + status = ModelStatus.DISABLED + + if len(model_setting.load_balancing_configs) > 1: + load_balancing_enabled = True + provider_models.append( ModelWithProviderEntity( model=m.model, @@ -666,7 +880,8 @@ class ProviderConfiguration(BaseModel): model_properties=m.model_properties, deprecated=m.deprecated, provider=SimpleModelProviderEntity(self.provider), - status=ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE + status=status, + load_balancing_enabled=load_balancing_enabled ) ) @@ -690,6 +905,17 @@ class ProviderConfiguration(BaseModel): if not custom_model_schema: continue + status = ModelStatus.ACTIVE + load_balancing_enabled = False + if (custom_model_schema.model_type in model_setting_map + and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]): + model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model] + if model_setting.enabled is False: + status = ModelStatus.DISABLED + + if len(model_setting.load_balancing_configs) > 1: + load_balancing_enabled = True + provider_models.append( ModelWithProviderEntity( model=custom_model_schema.model, @@ -700,7 +926,8 @@ class ProviderConfiguration(BaseModel): model_properties=custom_model_schema.model_properties, deprecated=custom_model_schema.deprecated, provider=SimpleModelProviderEntity(self.provider), - status=ModelStatus.ACTIVE + status=status, + load_balancing_enabled=load_balancing_enabled ) ) diff --git a/api/core/entities/provider_entities.py b/api/core/entities/provider_entities.py index 114dfaf911..1eaa6ea02c 100644 --- a/api/core/entities/provider_entities.py +++ b/api/core/entities/provider_entities.py @@ -72,3 +72,22 @@ class CustomConfiguration(BaseModel): """ provider: Optional[CustomProviderConfiguration] = None models: list[CustomModelConfiguration] = [] + + +class ModelLoadBalancingConfiguration(BaseModel): + """ + Class for model load balancing configuration. + """ + id: str + name: str + credentials: dict + + +class ModelSettings(BaseModel): + """ + Model class for model settings. 
+ """ + model: str + model_type: ModelType + enabled: bool = True + load_balancing_configs: list[ModelLoadBalancingConfiguration] = [] diff --git a/api/core/extension/extensible.py b/api/core/extension/extensible.py index d2ec555d6c..3a37c6492e 100644 --- a/api/core/extension/extensible.py +++ b/api/core/extension/extensible.py @@ -7,7 +7,7 @@ from typing import Any, Optional from pydantic import BaseModel -from core.utils.position_helper import sort_to_dict_by_position_map +from core.helper.position_helper import sort_to_dict_by_position_map class ExtensionModule(enum.Enum): diff --git a/api/core/file/message_file_parser.py b/api/core/file/message_file_parser.py index 06f21c880a..52498eb871 100644 --- a/api/core/file/message_file_parser.py +++ b/api/core/file/message_file_parser.py @@ -42,6 +42,8 @@ class MessageFileParser: raise ValueError('Invalid file url') if file.get('transfer_method') == FileTransferMethod.LOCAL_FILE.value and not file.get('upload_file_id'): raise ValueError('Missing file upload_file_id') + if file.get('transform_method') == FileTransferMethod.TOOL_FILE.value and not file.get('tool_file_id'): + raise ValueError('Missing file tool_file_id') # transform files to file objs type_file_objs = self._to_file_objs(files, file_extra_config) @@ -149,12 +151,21 @@ class MessageFileParser: """ if isinstance(file, dict): transfer_method = FileTransferMethod.value_of(file.get('transfer_method')) + if transfer_method != FileTransferMethod.TOOL_FILE: + return FileVar( + tenant_id=self.tenant_id, + type=FileType.value_of(file.get('type')), + transfer_method=transfer_method, + url=file.get('url') if transfer_method == FileTransferMethod.REMOTE_URL else None, + related_id=file.get('upload_file_id') if transfer_method == FileTransferMethod.LOCAL_FILE else None, + extra_config=file_extra_config + ) return FileVar( tenant_id=self.tenant_id, type=FileType.value_of(file.get('type')), transfer_method=transfer_method, - url=file.get('url') if transfer_method == FileTransferMethod.REMOTE_URL else None, - related_id=file.get('upload_file_id') if transfer_method == FileTransferMethod.LOCAL_FILE else None, + url=None, + related_id=file.get('tool_file_id'), extra_config=file_extra_config ) else: diff --git a/api/core/file/upload_file_parser.py b/api/core/file/upload_file_parser.py index 974fde178b..9e454f08d4 100644 --- a/api/core/file/upload_file_parser.py +++ b/api/core/file/upload_file_parser.py @@ -77,4 +77,4 @@ class UploadFileParser: return False current_time = int(time.time()) - return current_time - int(timestamp) <= 300 # expired after 5 minutes + return current_time - int(timestamp) <= current_app.config.get('FILES_ACCESS_TIMEOUT') diff --git a/api/controllers/console/enterprise/__init__.py b/api/core/helper/code_executor/__init__.py similarity index 100% rename from api/controllers/console/enterprise/__init__.py rename to api/core/helper/code_executor/__init__.py diff --git a/api/core/helper/code_executor/code_executor.py b/api/core/helper/code_executor/code_executor.py index 063a21b192..37573c88ce 100644 --- a/api/core/helper/code_executor/code_executor.py +++ b/api/core/helper/code_executor/code_executor.py @@ -1,13 +1,21 @@ +import logging +import time +from enum import Enum +from threading import Lock from typing import Literal, Optional -from httpx import post +from httpx import get, post from pydantic import BaseModel from yarl import URL from config import get_env -from core.helper.code_executor.javascript_transformer import NodeJsTemplateTransformer -from 
core.helper.code_executor.jinja2_transformer import Jinja2TemplateTransformer -from core.helper.code_executor.python_transformer import PythonTemplateTransformer +from core.helper.code_executor.entities import CodeDependency +from core.helper.code_executor.javascript.javascript_transformer import NodeJsTemplateTransformer +from core.helper.code_executor.jinja2.jinja2_transformer import Jinja2TemplateTransformer +from core.helper.code_executor.python3.python3_transformer import Python3TemplateTransformer +from core.helper.code_executor.template_transformer import TemplateTransformer + +logger = logging.getLogger(__name__) # Code Executor CODE_EXECUTION_ENDPOINT = get_env('CODE_EXECUTION_ENDPOINT') @@ -28,9 +36,38 @@ class CodeExecutionResponse(BaseModel): data: Data +class CodeLanguage(str, Enum): + PYTHON3 = 'python3' + JINJA2 = 'jinja2' + JAVASCRIPT = 'javascript' + + class CodeExecutor: + dependencies_cache = {} + dependencies_cache_lock = Lock() + + code_template_transformers: dict[CodeLanguage, type[TemplateTransformer]] = { + CodeLanguage.PYTHON3: Python3TemplateTransformer, + CodeLanguage.JINJA2: Jinja2TemplateTransformer, + CodeLanguage.JAVASCRIPT: NodeJsTemplateTransformer, + } + + code_language_to_running_language = { + CodeLanguage.JAVASCRIPT: 'nodejs', + CodeLanguage.JINJA2: CodeLanguage.PYTHON3, + CodeLanguage.PYTHON3: CodeLanguage.PYTHON3, + } + + supported_dependencies_languages: set[CodeLanguage] = { + CodeLanguage.PYTHON3 + } + @classmethod - def execute_code(cls, language: Literal['python3', 'javascript', 'jinja2'], preload: str, code: str) -> str: + def execute_code(cls, + language: Literal['python3', 'javascript', 'jinja2'], + preload: str, + code: str, + dependencies: Optional[list[CodeDependency]] = None) -> str: """ Execute code :param language: code language @@ -44,13 +81,15 @@ class CodeExecutor: } data = { - 'language': 'python3' if language == 'jinja2' else - 'nodejs' if language == 'javascript' else - 'python3' if language == 'python3' else None, + 'language': cls.code_language_to_running_language.get(language), 'code': code, - 'preload': preload + 'preload': preload, + 'enable_network': True } + if dependencies: + data['dependencies'] = [dependency.dict() for dependency in dependencies] + try: response = post(str(url), json=data, headers=headers, timeout=CODE_EXECUTION_TIMEOUT) if response.status_code == 503: @@ -60,7 +99,9 @@ class CodeExecutor: except CodeExecutionException as e: raise e except Exception as e: - raise CodeExecutionException('Failed to execute code, this is likely a network issue, please check if the sandbox service is running') + raise CodeExecutionException('Failed to execute code, which is likely a network issue,' + ' please check if the sandbox service is running.' 
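+        # Shape of the payload posted to the sandbox execution endpoint, as assembled
+        # above ('dependencies' is attached only when provided; values illustrative):
+        #   {'language': 'python3' | 'nodejs', 'code': <runner script>,
+        #    'preload': <preload script>, 'enable_network': True,
+        #    'dependencies': [{'name': 'requests', 'version': '2.31.0'}]}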
+ f' ( Error: {str(e)} )') try: response = response.json() @@ -78,7 +119,7 @@ class CodeExecutor: return response.data.stdout @classmethod - def execute_workflow_code_template(cls, language: Literal['python3', 'javascript', 'jinja2'], code: str, inputs: dict) -> dict: + def execute_workflow_code_template(cls, language: Literal['python3', 'javascript', 'jinja2'], code: str, inputs: dict, dependencies: Optional[list[CodeDependency]] = None) -> dict: """ Execute code :param language: code language @@ -86,21 +127,71 @@ class CodeExecutor: :param inputs: inputs :return: """ - template_transformer = None - if language == 'python3': - template_transformer = PythonTemplateTransformer - elif language == 'jinja2': - template_transformer = Jinja2TemplateTransformer - elif language == 'javascript': - template_transformer = NodeJsTemplateTransformer - else: - raise CodeExecutionException('Unsupported language') + template_transformer = cls.code_template_transformers.get(language) + if not template_transformer: + raise CodeExecutionException(f'Unsupported language {language}') - runner, preload = template_transformer.transform_caller(code, inputs) + runner, preload, dependencies = template_transformer.transform_caller(code, inputs, dependencies) try: - response = cls.execute_code(language, preload, runner) + response = cls.execute_code(language, preload, runner, dependencies) except CodeExecutionException as e: raise e - return template_transformer.transform_response(response) \ No newline at end of file + return template_transformer.transform_response(response) + + @classmethod + def list_dependencies(cls, language: str) -> list[CodeDependency]: + if language not in cls.supported_dependencies_languages: + return [] + + with cls.dependencies_cache_lock: + if language in cls.dependencies_cache: + # check expiration + dependencies = cls.dependencies_cache[language] + if dependencies['expiration'] > time.time(): + return dependencies['data'] + # remove expired cache + del cls.dependencies_cache[language] + + dependencies = cls._get_dependencies(language) + with cls.dependencies_cache_lock: + cls.dependencies_cache[language] = { + 'data': dependencies, + 'expiration': time.time() + 60 + } + + return dependencies + + @classmethod + def _get_dependencies(cls, language: Literal['python3']) -> list[CodeDependency]: + """ + List dependencies + """ + url = URL(CODE_EXECUTION_ENDPOINT) / 'v1' / 'sandbox' / 'dependencies' + + headers = { + 'X-Api-Key': CODE_EXECUTION_API_KEY + } + + running_language = cls.code_language_to_running_language.get(language) + if isinstance(running_language, Enum): + running_language = running_language.value + + data = { + 'language': running_language, + } + + try: + response = get(str(url), params=data, headers=headers, timeout=CODE_EXECUTION_TIMEOUT) + if response.status_code != 200: + raise Exception(f'Failed to list dependencies, got status code {response.status_code}, please check if the sandbox service is running') + response = response.json() + dependencies = response.get('data', {}).get('dependencies', []) + return [ + CodeDependency(**dependency) for dependency in dependencies + if dependency.get('name') not in Python3TemplateTransformer.get_standard_packages() + ] + except Exception as e: + logger.exception(f'Failed to list dependencies: {e}') + return [] \ No newline at end of file diff --git a/api/core/helper/code_executor/code_node_provider.py b/api/core/helper/code_executor/code_node_provider.py new file mode 100644 index 0000000000..b76e15eeab --- /dev/null +++ 
b/api/core/helper/code_executor/code_node_provider.py @@ -0,0 +1,55 @@ +from abc import abstractmethod + +from pydantic import BaseModel + +from core.helper.code_executor.code_executor import CodeExecutor + + +class CodeNodeProvider(BaseModel): + @staticmethod + @abstractmethod + def get_language() -> str: + pass + + @classmethod + def is_accept_language(cls, language: str) -> bool: + return language == cls.get_language() + + @classmethod + @abstractmethod + def get_default_code(cls) -> str: + """ + get default code in specific programming language for the code node + """ + pass + + @classmethod + def get_default_available_packages(cls) -> list[dict]: + return [p.dict() for p in CodeExecutor.list_dependencies(cls.get_language())] + + @classmethod + def get_default_config(cls) -> dict: + return { + "type": "code", + "config": { + "variables": [ + { + "variable": "arg1", + "value_selector": [] + }, + { + "variable": "arg2", + "value_selector": [] + } + ], + "code_language": cls.get_language(), + "code": cls.get_default_code(), + "outputs": { + "result": { + "type": "string", + "children": None + } + } + }, + "available_dependencies": cls.get_default_available_packages(), + } diff --git a/api/core/helper/code_executor/entities.py b/api/core/helper/code_executor/entities.py new file mode 100644 index 0000000000..cc10288521 --- /dev/null +++ b/api/core/helper/code_executor/entities.py @@ -0,0 +1,6 @@ +from pydantic import BaseModel + + +class CodeDependency(BaseModel): + name: str + version: str diff --git a/api/core/rerank/__init__.py b/api/core/helper/code_executor/javascript/__init__.py similarity index 100% rename from api/core/rerank/__init__.py rename to api/core/helper/code_executor/javascript/__init__.py diff --git a/api/core/helper/code_executor/javascript/javascript_code_provider.py b/api/core/helper/code_executor/javascript/javascript_code_provider.py new file mode 100644 index 0000000000..a157fcc6d1 --- /dev/null +++ b/api/core/helper/code_executor/javascript/javascript_code_provider.py @@ -0,0 +1,21 @@ +from textwrap import dedent + +from core.helper.code_executor.code_executor import CodeLanguage +from core.helper.code_executor.code_node_provider import CodeNodeProvider + + +class JavascriptCodeProvider(CodeNodeProvider): + @staticmethod + def get_language() -> str: + return CodeLanguage.JAVASCRIPT + + @classmethod + def get_default_code(cls) -> str: + return dedent( + """ + function main({arg1, arg2}) { + return { + result: arg1 + arg2 + } + } + """) diff --git a/api/core/helper/code_executor/javascript/javascript_transformer.py b/api/core/helper/code_executor/javascript/javascript_transformer.py new file mode 100644 index 0000000000..a4d2551972 --- /dev/null +++ b/api/core/helper/code_executor/javascript/javascript_transformer.py @@ -0,0 +1,25 @@ +from textwrap import dedent + +from core.helper.code_executor.template_transformer import TemplateTransformer + + +class NodeJsTemplateTransformer(TemplateTransformer): + @classmethod + def get_runner_script(cls) -> str: + runner_script = dedent( + f""" + // declare main function + {cls._code_placeholder} + + // decode and prepare input object + var inputs_obj = JSON.parse(Buffer.from('{cls._inputs_placeholder}', 'base64').toString('utf-8')) + + // execute main function + var output_obj = main(inputs_obj) + + // convert output to json and print + var output_json = JSON.stringify(output_obj) + var result = `<>${{output_json}}<>` + console.log(result) + """) + return runner_script diff --git 
a/api/core/helper/code_executor/javascript_transformer.py b/api/core/helper/code_executor/javascript_transformer.py deleted file mode 100644 index 29b8e06e86..0000000000 --- a/api/core/helper/code_executor/javascript_transformer.py +++ /dev/null @@ -1,54 +0,0 @@ -import json -import re - -from core.helper.code_executor.template_transformer import TemplateTransformer - -NODEJS_RUNNER = """// declare main function here -{{code}} - -// execute main function, and return the result -// inputs is a dict, unstructured inputs -output = main({{inputs}}) - -// convert output to json and print -output = JSON.stringify(output) - -result = `<>${output}<>` - -console.log(result) -""" - -NODEJS_PRELOAD = """""" - -class NodeJsTemplateTransformer(TemplateTransformer): - @classmethod - def transform_caller(cls, code: str, inputs: dict) -> tuple[str, str]: - """ - Transform code to python runner - :param code: code - :param inputs: inputs - :return: - """ - - # transform inputs to json string - inputs_str = json.dumps(inputs, indent=4, ensure_ascii=False) - - # replace code and inputs - runner = NODEJS_RUNNER.replace('{{code}}', code) - runner = runner.replace('{{inputs}}', inputs_str) - - return runner, NODEJS_PRELOAD - - @classmethod - def transform_response(cls, response: str) -> dict: - """ - Transform response to dict - :param response: response - :return: - """ - # extract result - result = re.search(r'<>(.*)<>', response, re.DOTALL) - if not result: - raise ValueError('Failed to parse result') - result = result.group(1) - return json.loads(result) diff --git a/api/core/workflow/nodes/variable_assigner/__init__.py b/api/core/helper/code_executor/jinja2/__init__.py similarity index 100% rename from api/core/workflow/nodes/variable_assigner/__init__.py rename to api/core/helper/code_executor/jinja2/__init__.py diff --git a/api/core/helper/code_executor/jinja2/jinja2_formatter.py b/api/core/helper/code_executor/jinja2/jinja2_formatter.py new file mode 100644 index 0000000000..63f48a56e2 --- /dev/null +++ b/api/core/helper/code_executor/jinja2/jinja2_formatter.py @@ -0,0 +1,17 @@ +from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage + + +class Jinja2Formatter: + @classmethod + def format(cls, template: str, inputs: str) -> str: + """ + Format template + :param template: template + :param inputs: inputs + :return: + """ + result = CodeExecutor.execute_workflow_code_template( + language=CodeLanguage.JINJA2, code=template, inputs=inputs + ) + + return result['result'] diff --git a/api/core/helper/code_executor/jinja2/jinja2_transformer.py b/api/core/helper/code_executor/jinja2/jinja2_transformer.py new file mode 100644 index 0000000000..a8f8095d52 --- /dev/null +++ b/api/core/helper/code_executor/jinja2/jinja2_transformer.py @@ -0,0 +1,64 @@ +from textwrap import dedent + +from core.helper.code_executor.python3.python3_transformer import Python3TemplateTransformer +from core.helper.code_executor.template_transformer import TemplateTransformer + + +class Jinja2TemplateTransformer(TemplateTransformer): + @classmethod + def get_standard_packages(cls) -> set[str]: + return {'jinja2'} | Python3TemplateTransformer.get_standard_packages() + + @classmethod + def transform_response(cls, response: str) -> dict: + """ + Transform response to dict + :param response: response + :return: + """ + return { + 'result': cls.extract_result_str_from_response(response) + } + + @classmethod + def get_runner_script(cls) -> str: + runner_script = dedent(f""" + # declare main function + def main(**inputs): 
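+            # Transformer protocol shared with the Python3 and JavaScript runners:
+            # the base TemplateTransformer substitutes the _code_placeholder /
+            # _inputs_placeholder markers, inputs travel base64-encoded so that
+            # user-supplied strings cannot break out of the generated source, and
+            # the printed output is fenced by sentinel markers that
+            # transform_response / extract_result_str_from_response pulls back
+            # out with a regex.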
+                import jinja2
+                template = jinja2.Template('''{cls._code_placeholder}''')
+                return template.render(**inputs)
+
+            import json
+            from base64 import b64decode
+
+            # decode and prepare input dict
+            inputs_obj = json.loads(b64decode('{cls._inputs_placeholder}').decode('utf-8'))
+
+            # execute main function
+            output = main(**inputs_obj)
+
+            # convert output and print
+            result = f'''<>{{output}}<>'''
+            print(result)
+
+            """)
+        return runner_script
+
+    @classmethod
+    def get_preload_script(cls) -> str:
+        preload_script = dedent("""
+            import jinja2
+            from base64 import b64decode
+
+            def _jinja2_preload_():
+                # prepare the jinja2 environment: load a template and render it beforehand to avoid sandbox issues
+                template = jinja2.Template('{{s}}')
+                template.render(s='a')
+
+            if __name__ == '__main__':
+                _jinja2_preload_()
+
+            """)
+
+        return preload_script
diff --git a/api/core/helper/code_executor/jinja2_transformer.py b/api/core/helper/code_executor/jinja2_transformer.py
deleted file mode 100644
index 27a3579493..0000000000
--- a/api/core/helper/code_executor/jinja2_transformer.py
+++ /dev/null
@@ -1,92 +0,0 @@
-import json
-import re
-from base64 import b64encode
-
-from core.helper.code_executor.template_transformer import TemplateTransformer
-
-PYTHON_RUNNER = """
-import jinja2
-from json import loads
-from base64 import b64decode
-
-template = jinja2.Template('''{{code}}''')
-
-def main(**inputs):
-    return template.render(**inputs)
-
-# execute main function, and return the result
-inputs = b64decode('{{inputs}}').decode('utf-8')
-output = main(**loads(inputs))
-
-result = f'''<>{output}<>'''
-
-print(result)
-
-"""
-
-JINJA2_PRELOAD_TEMPLATE = """{% set fruits = ['Apple'] %}
-{{ 'a' }}
-{% for fruit in fruits %}
-
  • {{ fruit }}
  • -{% endfor %} -{% if fruits|length > 1 %} -1 -{% endif %} -{% for i in range(5) %} - {% if i == 3 %}{{ i }}{% else %}{% endif %} -{% endfor %} - {% for i in range(3) %} - {{ i + 1 }} - {% endfor %} -{% macro say_hello() %}a{{ 'b' }}{% endmacro %} -{{ s }}{{ say_hello() }}""" - -JINJA2_PRELOAD = f""" -import jinja2 -from base64 import b64decode - -def _jinja2_preload_(): - # prepare jinja2 environment, load template and render before to avoid sandbox issue - template = jinja2.Template('''{JINJA2_PRELOAD_TEMPLATE}''') - template.render(s='a') - -if __name__ == '__main__': - _jinja2_preload_() - -""" - - -class Jinja2TemplateTransformer(TemplateTransformer): - @classmethod - def transform_caller(cls, code: str, inputs: dict) -> tuple[str, str]: - """ - Transform code to python runner - :param code: code - :param inputs: inputs - :return: - """ - - inputs_str = b64encode(json.dumps(inputs, ensure_ascii=False).encode()).decode('utf-8') - - # transform jinja2 template to python code - runner = PYTHON_RUNNER.replace('{{code}}', code) - runner = runner.replace('{{inputs}}', inputs_str) - - return runner, JINJA2_PRELOAD - - @classmethod - def transform_response(cls, response: str) -> dict: - """ - Transform response to dict - :param response: response - :return: - """ - # extract result - result = re.search(r'<>(.*)<>', response, re.DOTALL) - if not result: - raise ValueError('Failed to parse result') - result = result.group(1) - - return { - 'result': result - } diff --git a/api/core/application_manager.py b/api/core/helper/code_executor/python3/__init__.py similarity index 100% rename from api/core/application_manager.py rename to api/core/helper/code_executor/python3/__init__.py diff --git a/api/core/helper/code_executor/python3/python3_code_provider.py b/api/core/helper/code_executor/python3/python3_code_provider.py new file mode 100644 index 0000000000..efcb8a9d1e --- /dev/null +++ b/api/core/helper/code_executor/python3/python3_code_provider.py @@ -0,0 +1,20 @@ +from textwrap import dedent + +from core.helper.code_executor.code_executor import CodeLanguage +from core.helper.code_executor.code_node_provider import CodeNodeProvider + + +class Python3CodeProvider(CodeNodeProvider): + @staticmethod + def get_language() -> str: + return CodeLanguage.PYTHON3 + + @classmethod + def get_default_code(cls) -> str: + return dedent( + """ + def main(arg1: int, arg2: int) -> dict: + return { + "result": arg1 + arg2, + } + """) diff --git a/api/core/helper/code_executor/python3/python3_transformer.py b/api/core/helper/code_executor/python3/python3_transformer.py new file mode 100644 index 0000000000..4a5fa35093 --- /dev/null +++ b/api/core/helper/code_executor/python3/python3_transformer.py @@ -0,0 +1,51 @@ +from textwrap import dedent + +from core.helper.code_executor.template_transformer import TemplateTransformer + + +class Python3TemplateTransformer(TemplateTransformer): + @classmethod + def get_standard_packages(cls) -> set[str]: + return { + 'base64', + 'binascii', + 'collections', + 'datetime', + 'functools', + 'hashlib', + 'hmac', + 'itertools', + 'json', + 'math', + 'operator', + 'os', + 'random', + 're', + 'string', + 'sys', + 'time', + 'traceback', + 'uuid', + } + + @classmethod + def get_runner_script(cls) -> str: + runner_script = dedent(f""" + # declare main function + {cls._code_placeholder} + + import json + from base64 import b64decode + + # decode and prepare input dict + inputs_obj = json.loads(b64decode('{cls._inputs_placeholder}').decode('utf-8')) + + # execute main function + 
output_obj = main(**inputs_obj) + + # convert output to json and print + output_json = json.dumps(output_obj, indent=4) + result = f'''<>{{output_json}}<>''' + print(result) + """) + return runner_script diff --git a/api/core/helper/code_executor/python_transformer.py b/api/core/helper/code_executor/python_transformer.py deleted file mode 100644 index f44acbb9bf..0000000000 --- a/api/core/helper/code_executor/python_transformer.py +++ /dev/null @@ -1,82 +0,0 @@ -import json -import re -from base64 import b64encode - -from core.helper.code_executor.template_transformer import TemplateTransformer - -PYTHON_RUNNER = """# declare main function here -{{code}} - -from json import loads, dumps -from base64 import b64decode - -# execute main function, and return the result -# inputs is a dict, and it -inputs = b64decode('{{inputs}}').decode('utf-8') -output = main(**json.loads(inputs)) - -# convert output to json and print -output = dumps(output, indent=4) - -result = f'''<> -{output} -<>''' - -print(result) -""" - -PYTHON_PRELOAD = """ -# prepare general imports -import json -import datetime -import math -import random -import re -import string -import sys -import time -import traceback -import uuid -import os -import base64 -import hashlib -import hmac -import binascii -import collections -import functools -import operator -import itertools -""" - -class PythonTemplateTransformer(TemplateTransformer): - @classmethod - def transform_caller(cls, code: str, inputs: dict) -> tuple[str, str]: - """ - Transform code to python runner - :param code: code - :param inputs: inputs - :return: - """ - - # transform inputs to json string - inputs_str = b64encode(json.dumps(inputs, ensure_ascii=False).encode()).decode('utf-8') - - # replace code and inputs - runner = PYTHON_RUNNER.replace('{{code}}', code) - runner = runner.replace('{{inputs}}', inputs_str) - - return runner, PYTHON_PRELOAD - - @classmethod - def transform_response(cls, response: str) -> dict: - """ - Transform response to dict - :param response: response - :return: - """ - # extract result - result = re.search(r'<>(.*?)<>', response, re.DOTALL) - if not result: - raise ValueError('Failed to parse result') - result = result.group(1) - return json.loads(result) diff --git a/api/core/helper/code_executor/template_transformer.py b/api/core/helper/code_executor/template_transformer.py index c3564afd04..39af803f6e 100644 --- a/api/core/helper/code_executor/template_transformer.py +++ b/api/core/helper/code_executor/template_transformer.py @@ -1,24 +1,87 @@ +import json +import re from abc import ABC, abstractmethod +from base64 import b64encode +from typing import Optional + +from pydantic import BaseModel + +from core.helper.code_executor.entities import CodeDependency -class TemplateTransformer(ABC): +class TemplateTransformer(ABC, BaseModel): + _code_placeholder: str = '{{code}}' + _inputs_placeholder: str = '{{inputs}}' + _result_tag: str = '<>' + @classmethod - @abstractmethod - def transform_caller(cls, code: str, inputs: dict) -> tuple[str, str]: + def get_standard_packages(cls) -> set[str]: + return set() + + @classmethod + def transform_caller(cls, code: str, inputs: dict, + dependencies: Optional[list[CodeDependency]] = None) -> tuple[str, str, list[CodeDependency]]: """ Transform code to python runner :param code: code :param inputs: inputs :return: runner, preload """ - pass - + runner_script = cls.assemble_runner_script(code, inputs) + preload_script = cls.get_preload_script() + + packages = dependencies or [] + standard_packages = 
cls.get_standard_packages() + for package in standard_packages: + if package not in packages: + packages.append(CodeDependency(name=package, version='')) + packages = list({dep.name: dep for dep in packages if dep.name}.values()) + + return runner_script, preload_script, packages + + @classmethod + def extract_result_str_from_response(cls, response: str) -> str: + result = re.search(rf'{cls._result_tag}(.*){cls._result_tag}', response, re.DOTALL) + if not result: + raise ValueError('Failed to parse result') + result = result.group(1) + return result + @classmethod - @abstractmethod def transform_response(cls, response: str) -> dict: """ Transform response to dict :param response: response :return: """ - pass \ No newline at end of file + return json.loads(cls.extract_result_str_from_response(response)) + + @classmethod + @abstractmethod + def get_runner_script(cls) -> str: + """ + Get runner script + """ + pass + + @classmethod + def serialize_inputs(cls, inputs: dict) -> str: + inputs_json_str = json.dumps(inputs, ensure_ascii=False).encode() + input_base64_encoded = b64encode(inputs_json_str).decode('utf-8') + return input_base64_encoded + + @classmethod + def assemble_runner_script(cls, code: str, inputs: dict) -> str: + # assemble runner script + script = cls.get_runner_script() + script = script.replace(cls._code_placeholder, code) + inputs_str = cls.serialize_inputs(inputs) + script = script.replace(cls._inputs_placeholder, inputs_str) + return script + + @classmethod + def get_preload_script(cls) -> str: + """ + Get preload script + """ + return '' diff --git a/api/core/helper/model_provider_cache.py b/api/core/helper/model_provider_cache.py index 81e589f65b..29cb4acc7d 100644 --- a/api/core/helper/model_provider_cache.py +++ b/api/core/helper/model_provider_cache.py @@ -9,6 +9,7 @@ from extensions.ext_redis import redis_client class ProviderCredentialsCacheType(Enum): PROVIDER = "provider" MODEL = "provider_model" + LOAD_BALANCING_MODEL = "load_balancing_provider_model" class ProviderCredentialsCache: diff --git a/api/core/utils/module_import_helper.py b/api/core/helper/module_import_helper.py similarity index 100% rename from api/core/utils/module_import_helper.py rename to api/core/helper/module_import_helper.py diff --git a/api/core/utils/position_helper.py b/api/core/helper/position_helper.py similarity index 75% rename from api/core/utils/position_helper.py rename to api/core/helper/position_helper.py index e038390e09..689ab194a7 100644 --- a/api/core/utils/position_helper.py +++ b/api/core/helper/position_helper.py @@ -1,10 +1,9 @@ -import logging import os from collections import OrderedDict from collections.abc import Callable from typing import Any, AnyStr -import yaml +from core.tools.utils.yaml_utils import load_yaml_file def get_position_map( @@ -17,21 +16,15 @@ def get_position_map( :param file_name: the YAML file name, default to '_position.yaml' :return: a dict with name as key and index as value """ - try: - position_file_name = os.path.join(folder_path, file_name) - if not os.path.exists(position_file_name): - return {} - - with open(position_file_name, encoding='utf-8') as f: - positions = yaml.safe_load(f) - position_map = {} - for index, name in enumerate(positions): - if name and isinstance(name, str): - position_map[name.strip()] = index - return position_map - except: - logging.warning(f'Failed to load the YAML position file {folder_path}/{file_name}.') - return {} + position_file_name = os.path.join(folder_path, file_name) + positions = 
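The refactored TemplateTransformer above is the heart of this change: every language runner now shares one pipeline that splices the user code into a runner script, injects the inputs as a base64-encoded JSON string, and later pulls the JSON result back out from between `<>` tags. A self-contained sketch of that round trip, assuming nothing beyond the standard library (names like `assemble` and `parse` are illustrative, not the diff's own):

import contextlib
import io
import json
import re
from base64 import b64decode, b64encode

CODE_PLACEHOLDER = '{{code}}'
INPUTS_PLACEHOLDER = '{{inputs}}'
RESULT_TAG = '<>'

# skeleton runner: user code on top, then decode inputs, call main(), print tagged JSON
RUNNER = f"""
{CODE_PLACEHOLDER}

import json
from base64 import b64decode

inputs_obj = json.loads(b64decode('{INPUTS_PLACEHOLDER}').decode('utf-8'))
output = main(**inputs_obj)
print(f'''{RESULT_TAG}{{json.dumps(output)}}{RESULT_TAG}''')
"""


def assemble(code: str, inputs: dict) -> str:
    # mirrors assemble_runner_script: splice in the code, then the b64-encoded JSON inputs
    script = RUNNER.replace(CODE_PLACEHOLDER, code)
    inputs_str = b64encode(json.dumps(inputs, ensure_ascii=False).encode()).decode('utf-8')
    return script.replace(INPUTS_PLACEHOLDER, inputs_str)


def parse(stdout: str) -> dict:
    # mirrors extract_result_str_from_response: everything between the result tags
    match = re.search(rf'{RESULT_TAG}(.*){RESULT_TAG}', stdout, re.DOTALL)
    if not match:
        raise ValueError('Failed to parse result')
    return json.loads(match.group(1))


script = assemble('def main(a, b):\n    return {"sum": a + b}', {'a': 1, 'b': 2})
buf = io.StringIO()
with contextlib.redirect_stdout(buf):
    exec(script)  # in Dify the script runs inside the sandbox; exec() here only demonstrates the round trip
print(parse(buf.getvalue()))  # {'sum': 3}

Shipping the inputs as base64 keeps arbitrary JSON safe to splice into source code as a string literal, which is why both the Python3 and Jinja2 runners above decode with `b64decode` first.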
load_yaml_file(position_file_name, ignore_error=True) + position_map = {} + index = 0 + for _, name in enumerate(positions): + if name and isinstance(name, str): + position_map[name.strip()] = index + index += 1 + return position_map def sort_by_position_map( diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py index c44d4717e6..276c8a34e7 100644 --- a/api/core/helper/ssrf_proxy.py +++ b/api/core/helper/ssrf_proxy.py @@ -42,6 +42,20 @@ def delete(url, *args, **kwargs): if kwargs['follow_redirects']: kwargs['allow_redirects'] = kwargs['follow_redirects'] kwargs.pop('follow_redirects') + if 'timeout' in kwargs: + timeout = kwargs['timeout'] + if timeout is None: + kwargs.pop('timeout') + elif isinstance(timeout, tuple): + # check length of tuple + if len(timeout) == 2: + kwargs['timeout'] = timeout + elif len(timeout) == 1: + kwargs['timeout'] = timeout[0] + elif len(timeout) > 2: + kwargs['timeout'] = (timeout[0], timeout[1]) + else: + kwargs['timeout'] = (timeout, timeout) return _delete(url=url, *args, proxies=requests_proxies, **kwargs) def head(url, *args, **kwargs): diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index 51e77393d6..7fa1d7d4b9 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -12,7 +12,6 @@ from flask import Flask, current_app from flask_login import current_user from sqlalchemy.orm.exc import ObjectDeletedError -from core.docstore.dataset_docstore import DatasetDocumentStore from core.errors.error import ProviderTokenNotInitError from core.llm_generator.llm_generator import LLMGenerator from core.model_manager import ModelInstance, ModelManager @@ -20,12 +19,16 @@ from core.model_runtime.entities.model_entities import ModelType, PriceType from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from core.rag.datasource.keyword.keyword_factory import Keyword +from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.models.document import Document -from core.splitter.fixed_text_splitter import EnhanceRecursiveCharacterTextSplitter, FixedRecursiveCharacterTextSplitter -from core.splitter.text_splitter import TextSplitter +from core.rag.splitter.fixed_text_splitter import ( + EnhanceRecursiveCharacterTextSplitter, + FixedRecursiveCharacterTextSplitter, +) +from core.rag.splitter.text_splitter import TextSplitter from extensions.ext_database import db from extensions.ext_redis import redis_client from extensions.ext_storage import storage @@ -283,11 +286,7 @@ class IndexingRunner: if len(preview_texts) < 5: preview_texts.append(document.page_content) if indexing_technique == 'high_quality' or embedding_model_instance: - embedding_model_type_instance = embedding_model_instance.model_type_instance - embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance) - tokens += embedding_model_type_instance.get_num_tokens( - model=embedding_model_instance.model, - credentials=embedding_model_instance.credentials, + tokens += embedding_model_instance.get_text_embedding_num_tokens( texts=[self.filter_string(document.page_content)] ) @@ -411,14 +410,15 @@ class IndexingRunner: # The user-defined segmentation 
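The `ssrf_proxy` hunk above normalizes the `timeout` keyword before delegating to `requests`, since callers may pass `None`, a bare number, or a tuple of one or more values. A standalone sketch of the same normalization (the `normalize_timeout` helper is a stand-in for the inline logic in `delete()`):

def normalize_timeout(kwargs: dict) -> dict:
    """Coerce a timeout kwarg into the scalar or (connect, read) shape requests accepts."""
    if 'timeout' not in kwargs:
        return kwargs
    timeout = kwargs['timeout']
    if timeout is None:
        kwargs.pop('timeout')                             # fall back to the library default
    elif isinstance(timeout, tuple):
        if len(timeout) == 1:
            kwargs['timeout'] = timeout[0]                # one value covers both phases
        elif len(timeout) > 2:
            kwargs['timeout'] = (timeout[0], timeout[1])  # drop any extra members
        # a 2-tuple is already (connect, read) and passes through unchanged
    else:
        kwargs['timeout'] = (timeout, timeout)            # scalar -> same budget for both phases
    return kwargs


print(normalize_timeout({'timeout': None}))        # {}
print(normalize_timeout({'timeout': (3,)}))        # {'timeout': 3}
print(normalize_timeout({'timeout': (3, 10, 5)}))  # {'timeout': (3, 10)}
print(normalize_timeout({'timeout': 7}))           # {'timeout': (7, 7)}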
rule rules = json.loads(processing_rule.rules) segmentation = rules["segmentation"] - if segmentation["max_tokens"] < 50 or segmentation["max_tokens"] > 1000: - raise ValueError("Custom segment length should be between 50 and 1000.") + max_segmentation_tokens_length = int(current_app.config['INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH']) + if segmentation["max_tokens"] < 50 or segmentation["max_tokens"] > max_segmentation_tokens_length: + raise ValueError(f"Custom segment length should be between 50 and {max_segmentation_tokens_length}.") separator = segmentation["separator"] if separator: separator = separator.replace('\\n', '\n') - if 'chunk_overlap' in segmentation and segmentation['chunk_overlap']: + if segmentation.get('chunk_overlap'): chunk_overlap = segmentation['chunk_overlap'] else: chunk_overlap = 0 @@ -427,7 +427,7 @@ class IndexingRunner: chunk_size=segmentation["max_tokens"], chunk_overlap=chunk_overlap, fixed_separator=separator, - separators=["\n\n", "。", ".", " ", ""], + separators=["\n\n", "。", ". ", " ", ""], embedding_model_instance=embedding_model_instance ) else: @@ -435,7 +435,7 @@ class IndexingRunner: character_splitter = EnhanceRecursiveCharacterTextSplitter.from_encoder( chunk_size=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens'], chunk_overlap=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['chunk_overlap'], - separators=["\n\n", "。", ".", " ", ""], + separators=["\n\n", "。", ". ", " ", ""], embedding_model_instance=embedding_model_instance ) @@ -654,10 +654,6 @@ class IndexingRunner: tokens = 0 chunk_size = 10 - embedding_model_type_instance = None - if embedding_model_instance: - embedding_model_type_instance = embedding_model_instance.model_type_instance - embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance) # create keyword index create_keyword_thread = threading.Thread(target=self._process_keyword_index, args=(current_app._get_current_object(), @@ -670,8 +666,7 @@ class IndexingRunner: chunk_documents = documents[i:i + chunk_size] futures.append(executor.submit(self._process_chunk, current_app._get_current_object(), index_processor, chunk_documents, dataset, - dataset_document, embedding_model_instance, - embedding_model_type_instance)) + dataset_document, embedding_model_instance)) for future in futures: tokens += future.result() @@ -712,7 +707,7 @@ class IndexingRunner: db.session.commit() def _process_chunk(self, flask_app, index_processor, chunk_documents, dataset, dataset_document, - embedding_model_instance, embedding_model_type_instance): + embedding_model_instance): with flask_app.app_context(): # check document is paused self._check_document_paused_status(dataset_document.id) @@ -720,9 +715,7 @@ class IndexingRunner: tokens = 0 if dataset.indexing_technique == 'high_quality' or embedding_model_type_instance: tokens += sum( - embedding_model_type_instance.get_num_tokens( - embedding_model_instance.model, - embedding_model_instance.credentials, + embedding_model_instance.get_text_embedding_num_tokens( [document.page_content] ) for document in chunk_documents diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py index cd0b2508d4..6b53104c70 100644 --- a/api/core/memory/token_buffer_memory.py +++ b/api/core/memory/token_buffer_memory.py @@ -9,8 +9,6 @@ from core.model_runtime.entities.message_entities import ( TextPromptMessageContent, UserPromptMessage, ) -from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.model_providers 
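In the indexing hunk above, the hard-coded 1000-token ceiling for custom segmentation becomes the configurable `INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH`. A minimal sketch of the resulting bounds check, with a made-up rules dict for illustration:

def validate_segmentation(rules: dict, max_tokens_limit: int) -> None:
    # same bounds check as the hunk: custom segment length must fall within [50, configured max]
    segmentation = rules["segmentation"]
    if segmentation["max_tokens"] < 50 or segmentation["max_tokens"] > max_tokens_limit:
        raise ValueError(f"Custom segment length should be between 50 and {max_tokens_limit}.")


validate_segmentation({"segmentation": {"max_tokens": 512}}, max_tokens_limit=1000)  # passes
try:
    validate_segmentation({"segmentation": {"max_tokens": 4096}}, max_tokens_limit=1000)
except ValueError as e:
    print(e)  # Custom segment length should be between 50 and 1000.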
import model_provider_factory from extensions.ext_database import db from models.model import AppMode, Conversation, Message @@ -78,12 +76,7 @@ class TokenBufferMemory: return [] # prune the chat message if it exceeds the max token limit - provider_instance = model_provider_factory.get_provider_instance(self.model_instance.provider) - model_type_instance = provider_instance.get_model_instance(ModelType.LLM) - - curr_message_tokens = model_type_instance.get_num_tokens( - self.model_instance.model, - self.model_instance.credentials, + curr_message_tokens = self.model_instance.get_llm_num_tokens( prompt_messages ) @@ -91,9 +84,7 @@ class TokenBufferMemory: pruned_memory = [] while curr_message_tokens > max_token_limit and prompt_messages: pruned_memory.append(prompt_messages.pop(0)) - curr_message_tokens = model_type_instance.get_num_tokens( - self.model_instance.model, - self.model_instance.credentials, + curr_message_tokens = self.model_instance.get_llm_num_tokens( prompt_messages ) diff --git a/api/core/model_manager.py b/api/core/model_manager.py index 8c06339927..8da8442e60 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -1,7 +1,10 @@ +import logging +import os from collections.abc import Generator from typing import IO, Optional, Union, cast -from core.entities.provider_configuration import ProviderModelBundle +from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle +from core.entities.provider_entities import ModelLoadBalancingConfiguration from core.errors.error import ProviderTokenNotInitError from core.model_runtime.callbacks.base_callback import Callback from core.model_runtime.entities.llm_entities import LLMResult @@ -9,6 +12,7 @@ from core.model_runtime.entities.message_entities import PromptMessage, PromptMe from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.rerank_entities import RerankResult from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult +from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.__base.moderation_model import ModerationModel from core.model_runtime.model_providers.__base.rerank_model import RerankModel @@ -16,6 +20,10 @@ from core.model_runtime.model_providers.__base.speech2text_model import Speech2T from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from core.model_runtime.model_providers.__base.tts_model import TTSModel from core.provider_manager import ProviderManager +from extensions.ext_redis import redis_client +from models.provider import ProviderType + +logger = logging.getLogger(__name__) class ModelInstance: @@ -29,6 +37,12 @@ class ModelInstance: self.provider = provider_model_bundle.configuration.provider.provider self.credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model) self.model_type_instance = self.provider_model_bundle.model_type_instance + self.load_balancing_manager = self._get_load_balancing_manager( + configuration=provider_model_bundle.configuration, + model_type=provider_model_bundle.model_type_instance.model_type, + model=model, + credentials=self.credentials + ) def _fetch_credentials_from_bundle(self, provider_model_bundle: ProviderModelBundle, model: str) -> dict: """ @@ -37,8 +51,10 @@ class ModelInstance: :param model: model name 
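The `token_buffer_memory` hunk above routes token counting through `ModelInstance.get_llm_num_tokens` while keeping the existing strategy of popping the oldest messages until the history fits the budget. A toy version of that pruning loop, with a word-count tokenizer standing in for the real model call (an assumption made purely so the example runs):

def prune_history(messages: list[str], max_tokens: int) -> tuple[list[str], list[str]]:
    """Pop the oldest messages until the remaining history fits the token budget."""
    def count(msgs):  # stand-in for ModelInstance.get_llm_num_tokens
        return sum(len(m.split()) for m in msgs)

    pruned = []
    while count(messages) > max_tokens and messages:
        pruned.append(messages.pop(0))  # oldest message goes first, as in the diff
    return messages, pruned


kept, dropped = prune_history(["hello there", "how are you doing", "fine thanks"], max_tokens=6)
print(kept)     # ['how are you doing', 'fine thanks']
print(dropped)  # ['hello there']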
:return: """ - credentials = provider_model_bundle.configuration.get_current_credentials( - model_type=provider_model_bundle.model_type_instance.model_type, + configuration = provider_model_bundle.configuration + model_type = provider_model_bundle.model_type_instance.model_type + credentials = configuration.get_current_credentials( + model_type=model_type, model=model ) @@ -47,6 +63,43 @@ class ModelInstance: return credentials + def _get_load_balancing_manager(self, configuration: ProviderConfiguration, + model_type: ModelType, + model: str, + credentials: dict) -> Optional["LBModelManager"]: + """ + Get load balancing model credentials + :param configuration: provider configuration + :param model_type: model type + :param model: model name + :param credentials: model credentials + :return: + """ + if configuration.model_settings and configuration.using_provider_type == ProviderType.CUSTOM: + current_model_setting = None + # check if model is disabled by admin + for model_setting in configuration.model_settings: + if (model_setting.model_type == model_type + and model_setting.model == model): + current_model_setting = model_setting + break + + # check if load balancing is enabled + if current_model_setting and current_model_setting.load_balancing_configs: + # use load balancing proxy to choose credentials + lb_model_manager = LBModelManager( + tenant_id=configuration.tenant_id, + provider=configuration.provider.provider, + model_type=model_type, + model=model, + load_balancing_configs=current_model_setting.load_balancing_configs, + managed_credentials=credentials if configuration.custom_configuration.provider else None + ) + + return lb_model_manager + + return None + def invoke_llm(self, prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None, tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, callbacks: list[Callback] = None) \ @@ -67,7 +120,8 @@ class ModelInstance: raise Exception("Model type instance is not LargeLanguageModel") self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance) - return self.model_type_instance.invoke( + return self._round_robin_invoke( + function=self.model_type_instance.invoke, model=self.model, credentials=self.credentials, prompt_messages=prompt_messages, @@ -79,6 +133,27 @@ class ModelInstance: callbacks=callbacks ) + def get_llm_num_tokens(self, prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None) -> int: + """ + Get number of tokens for llm + + :param prompt_messages: prompt messages + :param tools: tools for tool calling + :return: + """ + if not isinstance(self.model_type_instance, LargeLanguageModel): + raise Exception("Model type instance is not LargeLanguageModel") + + self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance) + return self._round_robin_invoke( + function=self.model_type_instance.get_num_tokens, + model=self.model, + credentials=self.credentials, + prompt_messages=prompt_messages, + tools=tools + ) + def invoke_text_embedding(self, texts: list[str], user: Optional[str] = None) \ -> TextEmbeddingResult: """ @@ -92,13 +167,32 @@ class ModelInstance: raise Exception("Model type instance is not TextEmbeddingModel") self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance) - return self.model_type_instance.invoke( + return self._round_robin_invoke( + function=self.model_type_instance.invoke, model=self.model, credentials=self.credentials, 
texts=texts, user=user ) + def get_text_embedding_num_tokens(self, texts: list[str]) -> int: + """ + Get number of tokens for text embedding + + :param texts: texts to embed + :return: + """ + if not isinstance(self.model_type_instance, TextEmbeddingModel): + raise Exception("Model type instance is not TextEmbeddingModel") + + self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance) + return self._round_robin_invoke( + function=self.model_type_instance.get_num_tokens, + model=self.model, + credentials=self.credentials, + texts=texts + ) + def invoke_rerank(self, query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None, user: Optional[str] = None) \ @@ -117,7 +211,8 @@ class ModelInstance: raise Exception("Model type instance is not RerankModel") self.model_type_instance = cast(RerankModel, self.model_type_instance) - return self.model_type_instance.invoke( + return self._round_robin_invoke( + function=self.model_type_instance.invoke, model=self.model, credentials=self.credentials, query=query, @@ -140,7 +235,8 @@ class ModelInstance: raise Exception("Model type instance is not ModerationModel") self.model_type_instance = cast(ModerationModel, self.model_type_instance) - return self.model_type_instance.invoke( + return self._round_robin_invoke( + function=self.model_type_instance.invoke, model=self.model, credentials=self.credentials, text=text, @@ -160,7 +256,8 @@ class ModelInstance: raise Exception("Model type instance is not Speech2TextModel") self.model_type_instance = cast(Speech2TextModel, self.model_type_instance) - return self.model_type_instance.invoke( + return self._round_robin_invoke( + function=self.model_type_instance.invoke, model=self.model, credentials=self.credentials, file=file, @@ -183,7 +280,8 @@ class ModelInstance: raise Exception("Model type instance is not TTSModel") self.model_type_instance = cast(TTSModel, self.model_type_instance) - return self.model_type_instance.invoke( + return self._round_robin_invoke( + function=self.model_type_instance.invoke, model=self.model, credentials=self.credentials, content_text=content_text, @@ -193,6 +291,43 @@ class ModelInstance: streaming=streaming ) + def _round_robin_invoke(self, function: callable, *args, **kwargs): + """ + Round-robin invoke + :param function: function to invoke + :param args: function args + :param kwargs: function kwargs + :return: + """ + if not self.load_balancing_manager: + return function(*args, **kwargs) + + last_exception = None + while True: + lb_config = self.load_balancing_manager.fetch_next() + if not lb_config: + if not last_exception: + raise ProviderTokenNotInitError("Model credentials is not initialized.") + else: + raise last_exception + + try: + if 'credentials' in kwargs: + del kwargs['credentials'] + return function(*args, **kwargs, credentials=lb_config.credentials) + except InvokeRateLimitError as e: + # expire in 60 seconds + self.load_balancing_manager.cooldown(lb_config, expire=60) + last_exception = e + continue + except (InvokeAuthorizationError, InvokeConnectionError) as e: + # expire in 10 seconds + self.load_balancing_manager.cooldown(lb_config, expire=10) + last_exception = e + continue + except Exception as e: + raise e + def get_tts_voices(self, language: str) -> list: """ Invoke large language tts model voices @@ -226,6 +361,7 @@ class ModelManager: """ if not provider: return self.get_default_model_instance(tenant_id, model_type) + provider_model_bundle = self._provider_manager.get_provider_model_bundle( 
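`_round_robin_invoke` above is the new failover wrapper: it walks the load-balancing configs in order, cools a config down for 60 seconds on a rate limit and 10 seconds on an auth or connection error, and re-raises the last error once every config is cooling down. A simplified, in-memory sketch of that control flow (`TinyLB`, `RateLimited`, and `ConnectFailed` are stand-ins for the Redis-backed manager and the SDK's error types):

import time


class RateLimited(Exception):
    pass


class ConnectFailed(Exception):
    pass


class TinyLB:
    """In-memory stand-in for LBModelManager: round robin plus per-config cooldowns."""

    def __init__(self, configs):
        self.configs, self.i, self.cooldowns = configs, 0, {}

    def fetch_next(self):
        # try each config once per call; None means everything is cooling down
        for _ in self.configs:
            cfg = self.configs[self.i % len(self.configs)]
            self.i += 1
            if self.cooldowns.get(cfg, 0.0) <= time.monotonic():
                return cfg
        return None

    def cooldown(self, cfg, expire):
        self.cooldowns[cfg] = time.monotonic() + expire


def round_robin_invoke(lb, fn):
    last_exception = None
    while True:
        cfg = lb.fetch_next()
        if cfg is None:
            raise last_exception or RuntimeError('Model credentials are not initialized.')
        try:
            return fn(cfg)
        except RateLimited as e:
            lb.cooldown(cfg, expire=60)  # rate limit -> 60s cooldown, as in the hunk
            last_exception = e
        except ConnectFailed as e:
            lb.cooldown(cfg, expire=10)  # auth/connection error -> 10s cooldown
            last_exception = e


calls = []


def flaky(cfg):
    calls.append(cfg)
    if cfg == 'key-a':
        raise RateLimited('429 from upstream')
    return f'ok via {cfg}'


print(round_robin_invoke(TinyLB(['key-a', 'key-b']), flaky), calls)  # ok via key-b ['key-a', 'key-b']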
tenant_id=tenant_id, provider=provider, @@ -255,3 +391,141 @@ class ModelManager: model_type=model_type, model=default_model_entity.model ) + + +class LBModelManager: + def __init__(self, tenant_id: str, + provider: str, + model_type: ModelType, + model: str, + load_balancing_configs: list[ModelLoadBalancingConfiguration], + managed_credentials: Optional[dict] = None) -> None: + """ + Load balancing model manager + :param load_balancing_configs: all load balancing configurations + :param managed_credentials: credentials if load balancing configuration name is __inherit__ + """ + self._tenant_id = tenant_id + self._provider = provider + self._model_type = model_type + self._model = model + self._load_balancing_configs = load_balancing_configs + + for load_balancing_config in self._load_balancing_configs: + if load_balancing_config.name == "__inherit__": + if not managed_credentials: + # remove __inherit__ if managed credentials is not provided + self._load_balancing_configs.remove(load_balancing_config) + else: + load_balancing_config.credentials = managed_credentials + + def fetch_next(self) -> Optional[ModelLoadBalancingConfiguration]: + """ + Get next model load balancing config + Strategy: Round Robin + :return: + """ + cache_key = "model_lb_index:{}:{}:{}:{}".format( + self._tenant_id, + self._provider, + self._model_type.value, + self._model + ) + + cooldown_load_balancing_configs = [] + max_index = len(self._load_balancing_configs) + + while True: + current_index = redis_client.incr(cache_key) + if current_index >= 10000000: + current_index = 1 + redis_client.set(cache_key, current_index) + + redis_client.expire(cache_key, 3600) + if current_index > max_index: + current_index = current_index % max_index + + real_index = current_index - 1 + if real_index > max_index: + real_index = 0 + + config = self._load_balancing_configs[real_index] + + if self.in_cooldown(config): + cooldown_load_balancing_configs.append(config) + if len(cooldown_load_balancing_configs) >= len(self._load_balancing_configs): + # all configs are in cooldown + return None + + continue + + if bool(os.environ.get("DEBUG", 'False').lower() == 'true'): + logger.info(f"Model LB\nid: {config.id}\nname:{config.name}\n" + f"tenant_id: {self._tenant_id}\nprovider: {self._provider}\n" + f"model_type: {self._model_type.value}\nmodel: {self._model}") + + return config + + return None + + def cooldown(self, config: ModelLoadBalancingConfiguration, expire: int = 60) -> None: + """ + Cooldown model load balancing config + :param config: model load balancing config + :param expire: cooldown time + :return: + """ + cooldown_cache_key = "model_lb_index:cooldown:{}:{}:{}:{}:{}".format( + self._tenant_id, + self._provider, + self._model_type.value, + self._model, + config.id + ) + + redis_client.setex(cooldown_cache_key, expire, 'true') + + def in_cooldown(self, config: ModelLoadBalancingConfiguration) -> bool: + """ + Check if model load balancing config is in cooldown + :param config: model load balancing config + :return: + """ + cooldown_cache_key = "model_lb_index:cooldown:{}:{}:{}:{}:{}".format( + self._tenant_id, + self._provider, + self._model_type.value, + self._model, + config.id + ) + + return redis_client.exists(cooldown_cache_key) + + @classmethod + def get_config_in_cooldown_and_ttl(cls, tenant_id: str, + provider: str, + model_type: ModelType, + model: str, + config_id: str) -> tuple[bool, int]: + """ + Get model load balancing config is in cooldown and ttl + :param tenant_id: workspace id + :param provider: provider 
name + :param model_type: model type + :param model: model name + :param config_id: model load balancing config id + :return: + """ + cooldown_cache_key = "model_lb_index:cooldown:{}:{}:{}:{}:{}".format( + tenant_id, + provider, + model_type.value, + model, + config_id + ) + + ttl = redis_client.ttl(cooldown_cache_key) + if ttl == -2: + return False, 0 + + return True, ttl diff --git a/api/core/model_runtime/README.md b/api/core/model_runtime/README.md index d7748a8c3c..b5de7ad412 100644 --- a/api/core/model_runtime/README.md +++ b/api/core/model_runtime/README.md @@ -20,7 +20,7 @@ This module provides the interface for invoking and authenticating various model ![image-20231210143654461](./docs/en_US/images/index/image-20231210143654461.png) - Displays a list of all supported providers, including provider names, icons, supported model types list, predefined model list, configuration method, and credentials form rules, etc. For detailed rule design, see: [Schema](./schema.md). + Displays a list of all supported providers, including provider names, icons, supported model types list, predefined model list, configuration method, and credentials form rules, etc. For detailed rule design, see: [Schema](./docs/en_US/schema.md). - Selectable model list display diff --git a/api/core/model_runtime/callbacks/base_callback.py b/api/core/model_runtime/callbacks/base_callback.py index 51af9786fd..bba004a32a 100644 --- a/api/core/model_runtime/callbacks/base_callback.py +++ b/api/core/model_runtime/callbacks/base_callback.py @@ -1,4 +1,3 @@ -from abc import ABC from typing import Optional from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk @@ -14,7 +13,7 @@ _TEXT_COLOR_MAPPING = { } -class Callback(ABC): +class Callback: """ Base class for callbacks. Only for LLM. 
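`LBModelManager` above keeps its rotation state in Redis so that all API workers share one pointer: `INCR` advances a counter (reset periodically to avoid unbounded growth), the counter maps onto a config index, and unhealthy configs are flagged with short-lived `SETEX` cooldown keys whose `TTL` is later reported back. A sketch of both patterns against a tiny in-memory stand-in that mimics only the Redis commands the diff uses:

import time


class MemoryRedis:
    """Tiny stand-in exposing only the Redis commands the diff relies on."""

    def __init__(self):
        self.store, self.expiries = {}, {}

    def _alive(self, key):
        expiry = self.expiries.get(key)
        return expiry is None or expiry > time.monotonic()

    def incr(self, key):
        if not self._alive(key):
            self.store.pop(key, None)
            self.expiries.pop(key, None)
        value = int(self.store.get(key, 0)) + 1
        self.store[key] = value
        return value

    def expire(self, key, seconds):
        self.expiries[key] = time.monotonic() + seconds

    def setex(self, key, seconds, value):
        self.store[key] = value
        self.expiries[key] = time.monotonic() + seconds

    def exists(self, key):
        return key in self.store and self._alive(key)

    def ttl(self, key):
        if not self.exists(key):
            return -2  # Redis convention: -2 means the key does not exist
        if key not in self.expiries:
            return -1  # -1 means the key exists but has no expiry
        return max(0, int(self.expiries[key] - time.monotonic()))


r = MemoryRedis()
configs = ['cfg-a', 'cfg-b', 'cfg-c']
cache_key = 'model_lb_index:tenant-1:openai:llm:gpt-4'


def fetch_index():
    # shared counter -> 1-based slot -> 0-based list index (a condensed form of fetch_next)
    current_index = r.incr(cache_key)
    r.expire(cache_key, 3600)
    return (current_index - 1) % len(configs)


print([configs[fetch_index()] for _ in range(5)])  # ['cfg-a', 'cfg-b', 'cfg-c', 'cfg-a', 'cfg-b']

cooldown_key = f'model_lb_index:cooldown:tenant-1:openai:llm:gpt-4:{configs[1]}'
r.setex(cooldown_key, 10, 'true')                         # cooldown(): mark cfg-b unhealthy for 10s
print(bool(r.exists(cooldown_key)), r.ttl(cooldown_key))  # True 9 (just under the 10s we set)
print(r.ttl('missing-key'))                               # -2, mapped to (False, 0) by get_config_in_cooldown_and_ttl

The `__inherit__` entry handled in `__init__` simply reuses the provider's managed credentials as one more rotation slot, or drops out of the rotation when no managed credentials exist.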
diff --git a/api/core/model_runtime/docs/en_US/schema.md b/api/core/model_runtime/docs/en_US/schema.md index 68f66330cd..2e55d05b0f 100644 --- a/api/core/model_runtime/docs/en_US/schema.md +++ b/api/core/model_runtime/docs/en_US/schema.md @@ -51,7 +51,7 @@ - `voices` (list) List of available voice.(available for model type `tts`) - `mode` (string) voice model.(available for model type `tts`) - `name` (string) voice model display name.(available for model type `tts`) - - `lanuage` (string) the voice model supports languages.(available for model type `tts`) + - `language` (string) the voice model supports languages.(available for model type `tts`) - `word_limit` (int) Single conversion word limit, paragraphwise by default(available for model type `tts`) - `audio_type` (string) Support audio file extension format, e.g.:mp3,wav(available for model type `tts`) - `max_workers` (int) Number of concurrent workers supporting text and audio conversion(available for model type`tts`) diff --git a/api/core/model_runtime/docs/zh_Hans/schema.md b/api/core/model_runtime/docs/zh_Hans/schema.md index fd672993bb..f40a3f8698 100644 --- a/api/core/model_runtime/docs/zh_Hans/schema.md +++ b/api/core/model_runtime/docs/zh_Hans/schema.md @@ -52,7 +52,7 @@ - `voices` (list) 可选音色列表。 - `mode` (string) 音色模型。(模型类型 `tts` 可用) - `name` (string) 音色模型显示名称。(模型类型 `tts` 可用) - - `lanuage` (string) 音色模型支持语言。(模型类型 `tts` 可用) + - `language` (string) 音色模型支持语言。(模型类型 `tts` 可用) - `word_limit` (int) 单次转换字数限制,默认按段落分段(模型类型 `tts` 可用) - `audio_type` (string) 支持音频文件扩展格式,如:mp3,wav(模型类型 `tts` 可用) - `max_workers` (int) 支持文字音频转换并发任务数(模型类型 `tts` 可用) diff --git a/api/core/model_runtime/model_providers/__base/ai_model.py b/api/core/model_runtime/model_providers/__base/ai_model.py index 34a7375493..919e72554c 100644 --- a/api/core/model_runtime/model_providers/__base/ai_model.py +++ b/api/core/model_runtime/model_providers/__base/ai_model.py @@ -3,8 +3,7 @@ import os from abc import ABC, abstractmethod from typing import Optional -import yaml - +from core.helper.position_helper import get_position_map, sort_by_position_map from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE from core.model_runtime.entities.model_entities import ( @@ -18,7 +17,7 @@ from core.model_runtime.entities.model_entities import ( ) from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer -from core.utils.position_helper import get_position_map, sort_by_position_map +from core.tools.utils.yaml_utils import load_yaml_file class AIModel(ABC): @@ -154,8 +153,7 @@ class AIModel(ABC): # traverse all model_schema_yaml_paths for model_schema_yaml_path in model_schema_yaml_paths: # read yaml data from yaml file - with open(model_schema_yaml_path, encoding='utf-8') as f: - yaml_data = yaml.safe_load(f) + yaml_data = load_yaml_file(model_schema_yaml_path, ignore_error=True) new_parameter_rules = [] for parameter_rule in yaml_data.get('parameter_rules', []): diff --git a/api/core/model_runtime/model_providers/__base/model_provider.py b/api/core/model_runtime/model_providers/__base/model_provider.py index 7c839a9672..a893d023c0 100644 --- a/api/core/model_runtime/model_providers/__base/model_provider.py +++ b/api/core/model_runtime/model_providers/__base/model_provider.py @@ -1,12 +1,11 @@ import os from abc import ABC, abstractmethod -import yaml - +from 
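Several hunks in this change (`position_helper` earlier, `ai_model` above, `model_provider` just below) replace repeated open/`yaml.safe_load`/except blocks with a shared `core.tools.utils.yaml_utils.load_yaml_file(path, ignore_error=True)`. The helper itself is outside this section; the sketch below is a plausible reading inferred from its call sites, not its actual source:

import logging
import os

import yaml  # PyYAML, already a dependency of the modules above


def load_yaml_file(file_path: str, ignore_error: bool = False) -> dict:
    # inferred contract: return parsed YAML, or {} on any failure when ignore_error is set
    try:
        if not file_path or not os.path.exists(file_path):
            raise FileNotFoundError(f'file {file_path} not found')
        with open(file_path, encoding='utf-8') as f:
            return yaml.safe_load(f) or {}
    except Exception:
        if ignore_error:
            logging.warning(f'Failed to load YAML file {file_path}.')
            return {}
        raise


print(load_yaml_file('does-not-exist.yaml', ignore_error=True))  # {}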
core.helper.module_import_helper import get_subclasses_from_module, import_module_from_source from core.model_runtime.entities.model_entities import AIModelEntity, ModelType from core.model_runtime.entities.provider_entities import ProviderEntity from core.model_runtime.model_providers.__base.ai_model import AIModel -from core.utils.module_import_helper import get_subclasses_from_module, import_module_from_source +from core.tools.utils.yaml_utils import load_yaml_file class ModelProvider(ABC): @@ -44,10 +43,7 @@ class ModelProvider(ABC): # read provider schema from yaml file yaml_path = os.path.join(current_path, f'{provider_name}.yaml') - yaml_data = {} - if os.path.exists(yaml_path): - with open(yaml_path, encoding='utf-8') as f: - yaml_data = yaml.safe_load(f) + yaml_data = load_yaml_file(yaml_path, ignore_error=True) try: # yaml_data to entity diff --git a/api/core/model_runtime/model_providers/_position.yaml b/api/core/model_runtime/model_providers/_position.yaml index c06f122984..b483303cad 100644 --- a/api/core/model_runtime/model_providers/_position.yaml +++ b/api/core/model_runtime/model_providers/_position.yaml @@ -2,7 +2,9 @@ - anthropic - azure_openai - google +- vertex_ai - nvidia +- nvidia_nim - cohere - bedrock - togetherai @@ -26,4 +28,6 @@ - yi - openllm - localai +- volcengine_maas - openai_api_compatible +- deepseek diff --git a/api/core/model_runtime/model_providers/anthropic/llm/claude-3-haiku-20240307.yaml b/api/core/model_runtime/model_providers/anthropic/llm/claude-3-haiku-20240307.yaml index 0dedb2ef38..cb2af1308a 100644 --- a/api/core/model_runtime/model_providers/anthropic/llm/claude-3-haiku-20240307.yaml +++ b/api/core/model_runtime/model_providers/anthropic/llm/claude-3-haiku-20240307.yaml @@ -6,6 +6,7 @@ features: - agent-thought - vision - tool-call + - stream-tool-call model_properties: mode: chat context_size: 200000 diff --git a/api/core/model_runtime/model_providers/anthropic/llm/claude-3-opus-20240229.yaml b/api/core/model_runtime/model_providers/anthropic/llm/claude-3-opus-20240229.yaml index 60e56452eb..101f54c3f8 100644 --- a/api/core/model_runtime/model_providers/anthropic/llm/claude-3-opus-20240229.yaml +++ b/api/core/model_runtime/model_providers/anthropic/llm/claude-3-opus-20240229.yaml @@ -6,6 +6,7 @@ features: - agent-thought - vision - tool-call + - stream-tool-call model_properties: mode: chat context_size: 200000 diff --git a/api/core/model_runtime/model_providers/anthropic/llm/claude-3-sonnet-20240229.yaml b/api/core/model_runtime/model_providers/anthropic/llm/claude-3-sonnet-20240229.yaml index 08c8375d45..daf55553f8 100644 --- a/api/core/model_runtime/model_providers/anthropic/llm/claude-3-sonnet-20240229.yaml +++ b/api/core/model_runtime/model_providers/anthropic/llm/claude-3-sonnet-20240229.yaml @@ -6,6 +6,7 @@ features: - agent-thought - vision - tool-call + - stream-tool-call model_properties: mode: chat context_size: 200000 diff --git a/api/core/model_runtime/model_providers/anthropic/llm/llm.py b/api/core/model_runtime/model_providers/anthropic/llm/llm.py index e0e3514fd1..fbc0b722b1 100644 --- a/api/core/model_runtime/model_providers/anthropic/llm/llm.py +++ b/api/core/model_runtime/model_providers/anthropic/llm/llm.py @@ -146,7 +146,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): """ Code block mode wrapper for invoking large language model """ - if 'response_format' in model_parameters and model_parameters['response_format']: + if model_parameters.get('response_format'): stop = stop or [] # chat model 
self._transform_chat_json_prompts( @@ -324,10 +324,32 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): output_tokens = 0 finish_reason = None index = 0 + + tool_calls: list[AssistantPromptMessage.ToolCall] = [] + for chunk in response: if isinstance(chunk, MessageStartEvent): - return_model = chunk.message.model - input_tokens = chunk.message.usage.input_tokens + if hasattr(chunk, 'content_block'): + content_block = chunk.content_block + if isinstance(content_block, dict): + if content_block.get('type') == 'tool_use': + tool_call = AssistantPromptMessage.ToolCall( + id=content_block.get('id'), + type='function', + function=AssistantPromptMessage.ToolCall.ToolCallFunction( + name=content_block.get('name'), + arguments='' + ) + ) + tool_calls.append(tool_call) + elif hasattr(chunk, 'delta'): + delta = chunk.delta + if isinstance(delta, dict) and len(tool_calls) > 0: + if delta.get('type') == 'input_json_delta': + tool_calls[-1].function.arguments += delta.get('partial_json', '') + elif chunk.message: + return_model = chunk.message.model + input_tokens = chunk.message.usage.input_tokens elif isinstance(chunk, MessageDeltaEvent): output_tokens = chunk.usage.output_tokens finish_reason = chunk.delta.stop_reason @@ -335,13 +357,19 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): # transform usage usage = self._calc_response_usage(model, credentials, input_tokens, output_tokens) + # transform empty tool call arguments to {} + for tool_call in tool_calls: + if not tool_call.function.arguments: + tool_call.function.arguments = '{}' + yield LLMResultChunk( model=return_model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=index + 1, message=AssistantPromptMessage( - content='' + content='', + tool_calls=tool_calls ), finish_reason=finish_reason, usage=usage @@ -380,7 +408,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): "max_retries": 1, } - if 'anthropic_api_url' in credentials and credentials['anthropic_api_url']: + if credentials.get('anthropic_api_url'): credentials['anthropic_api_url'] = credentials['anthropic_api_url'].rstrip('/') credentials_kwargs['base_url'] = credentials['anthropic_api_url'] diff --git a/api/core/model_runtime/model_providers/azure_openai/_constant.py b/api/core/model_runtime/model_providers/azure_openai/_constant.py index e81a120fa0..63a0b5c8be 100644 --- a/api/core/model_runtime/model_providers/azure_openai/_constant.py +++ b/api/core/model_runtime/model_providers/azure_openai/_constant.py @@ -49,7 +49,7 @@ LLM_BASE_MODELS = [ fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties={ ModelPropertyKey.MODE: LLMMode.CHAT.value, - ModelPropertyKey.CONTEXT_SIZE: 4096, + ModelPropertyKey.CONTEXT_SIZE: 16385, }, parameter_rules=[ ParameterRule( @@ -68,11 +68,25 @@ LLM_BASE_MODELS = [ name='frequency_penalty', **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), - _get_max_tokens(default=512, min_val=1, max_val=4096) + _get_max_tokens(default=512, min_val=1, max_val=4096), + ParameterRule( + name='response_format', + label=I18nObject( + zh_Hans='回复格式', + en_US='response_format' + ), + type='string', + help=I18nObject( + zh_Hans='指定模型必须输出的格式', + en_US='specifying the format that the model must output' + ), + required=False, + options=['text', 'json_object'] + ), ], pricing=PriceConfig( - input=0.001, - output=0.002, + input=0.0005, + output=0.0015, unit=0.001, currency='USD', ) @@ -482,6 +496,310 @@ LLM_BASE_MODELS = [ ) ) ), + AzureBaseModel( + base_model_name='gpt-4o', + entity=AIModelEntity( + 
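The Anthropic streaming hunk above assembles tool calls incrementally: a `tool_use` content block opens a call with an empty argument buffer, subsequent `input_json_delta` events append `partial_json` fragments to the most recent call, and empty argument strings are normalized to `'{}'` before the chunk is yielded. A minimal sketch of that accumulation over a fabricated event stream (plain dicts stand in for the SDK's event types):

import json

# fabricated stream: one tool call that receives two JSON fragments, one that receives none
events = [
    {'content_block': {'type': 'tool_use', 'id': 'call_1', 'name': 'get_weather'}},
    {'delta': {'type': 'input_json_delta', 'partial_json': '{"city": '}},
    {'delta': {'type': 'input_json_delta', 'partial_json': '"Paris"}'}},
    {'content_block': {'type': 'tool_use', 'id': 'call_2', 'name': 'noop'}},
]

tool_calls = []
for event in events:
    block = event.get('content_block')
    if block and block.get('type') == 'tool_use':
        # a tool_use block opens a new call with an empty argument buffer
        tool_calls.append({'id': block['id'], 'name': block['name'], 'arguments': ''})
    delta = event.get('delta')
    if delta and tool_calls and delta.get('type') == 'input_json_delta':
        # argument JSON arrives in fragments; append them to the latest call
        tool_calls[-1]['arguments'] += delta.get('partial_json', '')

for call in tool_calls:
    call['arguments'] = call['arguments'] or '{}'  # empty arguments become '{}', as in the hunk
    print(call['name'], json.loads(call['arguments']))
# get_weather {'city': 'Paris'}
# noop {}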
model='fake-deployment-name', + label=I18nObject( + en_US='fake-deployment-name-label', + ), + model_type=ModelType.LLM, + features=[ + ModelFeature.AGENT_THOUGHT, + ModelFeature.VISION, + ModelFeature.MULTI_TOOL_CALL, + ModelFeature.STREAM_TOOL_CALL, + ], + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_properties={ + ModelPropertyKey.MODE: LLMMode.CHAT.value, + ModelPropertyKey.CONTEXT_SIZE: 128000, + }, + parameter_rules=[ + ParameterRule( + name='temperature', + **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], + ), + ParameterRule( + name='top_p', + **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], + ), + ParameterRule( + name='presence_penalty', + **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], + ), + ParameterRule( + name='frequency_penalty', + **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], + ), + _get_max_tokens(default=512, min_val=1, max_val=4096), + ParameterRule( + name='seed', + label=I18nObject( + zh_Hans='种子', + en_US='Seed' + ), + type='int', + help=I18nObject( + zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。', + en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.' + ), + required=False, + precision=2, + min=0, + max=1, + ), + ParameterRule( + name='response_format', + label=I18nObject( + zh_Hans='回复格式', + en_US='response_format' + ), + type='string', + help=I18nObject( + zh_Hans='指定模型必须输出的格式', + en_US='specifying the format that the model must output' + ), + required=False, + options=['text', 'json_object'] + ), + ], + pricing=PriceConfig( + input=5.00, + output=15.00, + unit=0.000001, + currency='USD', + ) + ) + ), + AzureBaseModel( + base_model_name='gpt-4o-2024-05-13', + entity=AIModelEntity( + model='fake-deployment-name', + label=I18nObject( + en_US='fake-deployment-name-label', + ), + model_type=ModelType.LLM, + features=[ + ModelFeature.AGENT_THOUGHT, + ModelFeature.VISION, + ModelFeature.MULTI_TOOL_CALL, + ModelFeature.STREAM_TOOL_CALL, + ], + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_properties={ + ModelPropertyKey.MODE: LLMMode.CHAT.value, + ModelPropertyKey.CONTEXT_SIZE: 128000, + }, + parameter_rules=[ + ParameterRule( + name='temperature', + **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], + ), + ParameterRule( + name='top_p', + **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], + ), + ParameterRule( + name='presence_penalty', + **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], + ), + ParameterRule( + name='frequency_penalty', + **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], + ), + _get_max_tokens(default=512, min_val=1, max_val=4096), + ParameterRule( + name='seed', + label=I18nObject( + zh_Hans='种子', + en_US='Seed' + ), + type='int', + help=I18nObject( + zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。', + en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.' 
+ ), + required=False, + precision=2, + min=0, + max=1, + ), + ParameterRule( + name='response_format', + label=I18nObject( + zh_Hans='回复格式', + en_US='response_format' + ), + type='string', + help=I18nObject( + zh_Hans='指定模型必须输出的格式', + en_US='specifying the format that the model must output' + ), + required=False, + options=['text', 'json_object'] + ), + ], + pricing=PriceConfig( + input=5.00, + output=15.00, + unit=0.000001, + currency='USD', + ) + ) + ), + AzureBaseModel( + base_model_name='gpt-4-turbo', + entity=AIModelEntity( + model='fake-deployment-name', + label=I18nObject( + en_US='fake-deployment-name-label', + ), + model_type=ModelType.LLM, + features=[ + ModelFeature.AGENT_THOUGHT, + ModelFeature.VISION, + ModelFeature.MULTI_TOOL_CALL, + ModelFeature.STREAM_TOOL_CALL, + ], + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_properties={ + ModelPropertyKey.MODE: LLMMode.CHAT.value, + ModelPropertyKey.CONTEXT_SIZE: 128000, + }, + parameter_rules=[ + ParameterRule( + name='temperature', + **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], + ), + ParameterRule( + name='top_p', + **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], + ), + ParameterRule( + name='presence_penalty', + **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], + ), + ParameterRule( + name='frequency_penalty', + **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], + ), + _get_max_tokens(default=512, min_val=1, max_val=4096), + ParameterRule( + name='seed', + label=I18nObject( + zh_Hans='种子', + en_US='Seed' + ), + type='int', + help=I18nObject( + zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。', + en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.' 
+ ), + required=False, + precision=2, + min=0, + max=1, + ), + ParameterRule( + name='response_format', + label=I18nObject( + zh_Hans='回复格式', + en_US='response_format' + ), + type='string', + help=I18nObject( + zh_Hans='指定模型必须输出的格式', + en_US='specifying the format that the model must output' + ), + required=False, + options=['text', 'json_object'] + ), + ], + pricing=PriceConfig( + input=0.01, + output=0.03, + unit=0.001, + currency='USD', + ) + ) + ), + AzureBaseModel( + base_model_name='gpt-4-turbo-2024-04-09', + entity=AIModelEntity( + model='fake-deployment-name', + label=I18nObject( + en_US='fake-deployment-name-label', + ), + model_type=ModelType.LLM, + features=[ + ModelFeature.AGENT_THOUGHT, + ModelFeature.VISION, + ModelFeature.MULTI_TOOL_CALL, + ModelFeature.STREAM_TOOL_CALL, + ], + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_properties={ + ModelPropertyKey.MODE: LLMMode.CHAT.value, + ModelPropertyKey.CONTEXT_SIZE: 128000, + }, + parameter_rules=[ + ParameterRule( + name='temperature', + **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], + ), + ParameterRule( + name='top_p', + **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], + ), + ParameterRule( + name='presence_penalty', + **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], + ), + ParameterRule( + name='frequency_penalty', + **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], + ), + _get_max_tokens(default=512, min_val=1, max_val=4096), + ParameterRule( + name='seed', + label=I18nObject( + zh_Hans='种子', + en_US='Seed' + ), + type='int', + help=I18nObject( + zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。', + en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.' 
+ ), + required=False, + precision=2, + min=0, + max=1, + ), + ParameterRule( + name='response_format', + label=I18nObject( + zh_Hans='回复格式', + en_US='response_format' + ), + type='string', + help=I18nObject( + zh_Hans='指定模型必须输出的格式', + en_US='specifying the format that the model must output' + ), + required=False, + options=['text', 'json_object'] + ), + ], + pricing=PriceConfig( + input=0.01, + output=0.03, + unit=0.001, + currency='USD', + ) + ) + ), AzureBaseModel( base_model_name='gpt-4-vision-preview', entity=AIModelEntity( diff --git a/api/core/model_runtime/model_providers/azure_openai/azure_openai.yaml b/api/core/model_runtime/model_providers/azure_openai/azure_openai.yaml index 828698acc7..b9f33a8ff2 100644 --- a/api/core/model_runtime/model_providers/azure_openai/azure_openai.yaml +++ b/api/core/model_runtime/model_providers/azure_openai/azure_openai.yaml @@ -59,6 +59,9 @@ model_credential_schema: - label: en_US: 2023-12-01-preview value: 2023-12-01-preview + - label: + en_US: '2024-02-01' + value: '2024-02-01' placeholder: zh_Hans: 在此选择您的 API 版本 en_US: Select your API Version here @@ -99,6 +102,24 @@ model_credential_schema: show_on: - variable: __model_type value: llm + - label: + en_US: gpt-4o + value: gpt-4o + show_on: + - variable: __model_type + value: llm + - label: + en_US: gpt-4o-2024-05-13 + value: gpt-4o-2024-05-13 + show_on: + - variable: __model_type + value: llm + - label: + en_US: gpt-4-turbo + value: gpt-4-turbo + show_on: + - variable: __model_type + value: llm - label: en_US: gpt-4-turbo-2024-04-09 value: gpt-4-turbo-2024-04-09 diff --git a/api/core/model_runtime/model_providers/baichuan/llm/baichuan2-53b.yaml b/api/core/model_runtime/model_providers/baichuan/llm/baichuan2-53b.yaml index e7cf059ea5..04849500dc 100644 --- a/api/core/model_runtime/model_providers/baichuan/llm/baichuan2-53b.yaml +++ b/api/core/model_runtime/model_providers/baichuan/llm/baichuan2-53b.yaml @@ -6,7 +6,7 @@ features: - agent-thought model_properties: mode: chat - context_size: 4000 + context_size: 32000 parameter_rules: - name: temperature use_template: temperature diff --git a/api/core/model_runtime/model_providers/baichuan/llm/baichuan2-turbo.yaml b/api/core/model_runtime/model_providers/baichuan/llm/baichuan2-turbo.yaml index fb5d73b068..f91329c77a 100644 --- a/api/core/model_runtime/model_providers/baichuan/llm/baichuan2-turbo.yaml +++ b/api/core/model_runtime/model_providers/baichuan/llm/baichuan2-turbo.yaml @@ -6,7 +6,7 @@ features: - agent-thought model_properties: mode: chat - context_size: 192000 + context_size: 32000 parameter_rules: - name: temperature use_template: temperature diff --git a/api/core/model_runtime/model_providers/baichuan/llm/baichuan3-turbo-128k.yaml b/api/core/model_runtime/model_providers/baichuan/llm/baichuan3-turbo-128k.yaml new file mode 100644 index 0000000000..bf72e82296 --- /dev/null +++ b/api/core/model_runtime/model_providers/baichuan/llm/baichuan3-turbo-128k.yaml @@ -0,0 +1,45 @@ +model: baichuan3-turbo-128k +label: + en_US: Baichuan3-Turbo-128k +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 128000 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. 
+ required: false + - name: max_tokens + use_template: max_tokens + required: true + default: 8000 + min: 1 + max: 128000 + - name: presence_penalty + use_template: presence_penalty + - name: frequency_penalty + use_template: frequency_penalty + default: 1 + min: 1 + max: 2 + - name: with_search_enhance + label: + zh_Hans: 搜索增强 + en_US: Search Enhance + type: boolean + help: + zh_Hans: 允许模型自行进行外部搜索,以增强生成结果。 + en_US: Allow the model to perform external search to enhance the generation results. + required: false diff --git a/api/core/model_runtime/model_providers/baichuan/llm/baichuan3-turbo.yaml b/api/core/model_runtime/model_providers/baichuan/llm/baichuan3-turbo.yaml new file mode 100644 index 0000000000..85882519b8 --- /dev/null +++ b/api/core/model_runtime/model_providers/baichuan/llm/baichuan3-turbo.yaml @@ -0,0 +1,45 @@ +model: baichuan3-turbo +label: + en_US: Baichuan3-Turbo +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 32000 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false + - name: max_tokens + use_template: max_tokens + required: true + default: 8000 + min: 1 + max: 32000 + - name: presence_penalty + use_template: presence_penalty + - name: frequency_penalty + use_template: frequency_penalty + default: 1 + min: 1 + max: 2 + - name: with_search_enhance + label: + zh_Hans: 搜索增强 + en_US: Search Enhance + type: boolean + help: + zh_Hans: 允许模型自行进行外部搜索,以增强生成结果。 + en_US: Allow the model to perform external search to enhance the generation results. + required: false diff --git a/api/core/model_runtime/model_providers/baichuan/llm/baichuan4.yaml b/api/core/model_runtime/model_providers/baichuan/llm/baichuan4.yaml new file mode 100644 index 0000000000..f8c6566081 --- /dev/null +++ b/api/core/model_runtime/model_providers/baichuan/llm/baichuan4.yaml @@ -0,0 +1,45 @@ +model: baichuan4 +label: + en_US: Baichuan4 +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 32000 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false + - name: max_tokens + use_template: max_tokens + required: true + default: 8000 + min: 1 + max: 32000 + - name: presence_penalty + use_template: presence_penalty + - name: frequency_penalty + use_template: frequency_penalty + default: 1 + min: 1 + max: 2 + - name: with_search_enhance + label: + zh_Hans: 搜索增强 + en_US: Search Enhance + type: boolean + help: + zh_Hans: 允许模型自行进行外部搜索,以增强生成结果。 + en_US: Allow the model to perform external search to enhance the generation results. 
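The new Baichuan3 and Baichuan4 model cards in this diff all expose the same parameter surface (temperature, top_p, top_k, max_tokens, penalties, with_search_enhance). As a rough illustration of what those rules become at request time, here is a hypothetical payload for the OpenAI-compatible endpoint the client code below targets; the field names mirror the yaml rules, not a verified Baichuan wire schema:

import json

# Hypothetical request body; parameter names follow the yaml rules above.
body = {
    "model": "baichuan4",
    "messages": [{"role": "user", "content": "你好"}],
    "temperature": 0.3,
    "top_p": 0.85,
    "max_tokens": 8000,
    "with_search_enhance": True,
    "stream": False,
}
print(json.dumps(body, ensure_ascii=False, indent=2))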
+ required: false diff --git a/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo.py b/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo.py index 639f6a21ce..d7d8b7c91b 100644 --- a/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo.py +++ b/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo.py @@ -51,26 +51,29 @@ class BaichuanModel: 'baichuan2-turbo': 'Baichuan2-Turbo', 'baichuan2-turbo-192k': 'Baichuan2-Turbo-192k', 'baichuan2-53b': 'Baichuan2-53B', + 'baichuan3-turbo': 'Baichuan3-Turbo', + 'baichuan3-turbo-128k': 'Baichuan3-Turbo-128k', + 'baichuan4': 'Baichuan4', }[model] def _handle_chat_generate_response(self, response) -> BaichuanMessage: - resp = response.json() - choices = resp.get('choices', []) - message = BaichuanMessage(content='', role='assistant') - for choice in choices: - message.content += choice['message']['content'] - message.role = choice['message']['role'] - if choice['finish_reason']: - message.stop_reason = choice['finish_reason'] + resp = response.json() + choices = resp.get('choices', []) + message = BaichuanMessage(content='', role='assistant') + for choice in choices: + message.content += choice['message']['content'] + message.role = choice['message']['role'] + if choice['finish_reason']: + message.stop_reason = choice['finish_reason'] - if 'usage' in resp: - message.usage = { - 'prompt_tokens': resp['usage']['prompt_tokens'], - 'completion_tokens': resp['usage']['completion_tokens'], - 'total_tokens': resp['usage']['total_tokens'], - } - - return message + if 'usage' in resp: + message.usage = { + 'prompt_tokens': resp['usage']['prompt_tokens'], + 'completion_tokens': resp['usage']['completion_tokens'], + 'total_tokens': resp['usage']['total_tokens'], + } + + return message def _handle_chat_stream_generate_response(self, response) -> Generator: for line in response.iter_lines(): @@ -89,7 +92,7 @@ class BaichuanModel: # save stop reason temporarily stop_reason = '' for choice in choices: - if 'finish_reason' in choice and choice['finish_reason']: + if choice.get('finish_reason'): stop_reason = choice['finish_reason'] if len(choice['delta']['content']) == 0: @@ -110,7 +113,8 @@ class BaichuanModel: def _build_parameters(self, model: str, stream: bool, messages: list[BaichuanMessage], parameters: dict[str, Any]) \ -> dict[str, Any]: - if model == 'baichuan2-turbo' or model == 'baichuan2-turbo-192k' or model == 'baichuan2-53b': + if (model == 'baichuan2-turbo' or model == 'baichuan2-turbo-192k' or model == 'baichuan2-53b' + or model == 'baichuan3-turbo' or model == 'baichuan3-turbo-128k' or model == 'baichuan4'): prompt_messages = [] for message in messages: if message.role == BaichuanMessage.Role.USER.value or message.role == BaichuanMessage.Role._SYSTEM.value: @@ -143,7 +147,8 @@ class BaichuanModel: raise BadRequestError(f"Unknown model: {model}") def _build_headers(self, model: str, data: dict[str, Any]) -> dict[str, Any]: - if model == 'baichuan2-turbo' or model == 'baichuan2-turbo-192k' or model == 'baichuan2-53b': + if (model == 'baichuan2-turbo' or model == 'baichuan2-turbo-192k' or model == 'baichuan2-53b' + or model == 'baichuan3-turbo' or model == 'baichuan3-turbo-128k' or model == 'baichuan4'): # there is no secret key for turbo api return { 'Content-Type': 'application/json', @@ -160,7 +165,8 @@ class BaichuanModel: parameters: dict[str, Any], timeout: int) \ -> Union[Generator, BaichuanMessage]: - if model == 'baichuan2-turbo' or model == 'baichuan2-turbo-192k' or model == 
'baichuan2-53b': + if (model == 'baichuan2-turbo' or model == 'baichuan2-turbo-192k' or model == 'baichuan2-53b' + or model == 'baichuan3-turbo' or model == 'baichuan3-turbo-128k' or model == 'baichuan4'): api_base = 'https://api.baichuan-ai.com/v1/chat/completions' else: raise BadRequestError(f"Unknown model: {model}") diff --git a/api/core/model_runtime/model_providers/bedrock/llm/_position.yaml b/api/core/model_runtime/model_providers/bedrock/llm/_position.yaml index 7f4d2035cc..732d1c2a93 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/_position.yaml +++ b/api/core/model_runtime/model_providers/bedrock/llm/_position.yaml @@ -12,6 +12,7 @@ - meta.llama3-70b-instruct-v1:0 - meta.llama2-13b-chat-v1 - meta.llama2-70b-chat-v1 +- mistral.mistral-small-2402-v1:0 - mistral.mistral-large-2402-v1:0 - mistral.mixtral-8x7b-instruct-v0:1 - mistral.mistral-7b-instruct-v0:2 diff --git a/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-haiku-v1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-haiku-v1.yaml index 73fe5567fc..181b192769 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-haiku-v1.yaml +++ b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-haiku-v1.yaml @@ -51,7 +51,7 @@ parameter_rules: zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。 en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses. pricing: - input: '0.003' - output: '0.015' + input: '0.00025' + output: '0.00125' unit: '0.001' currency: USD diff --git a/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-sonnet-v1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-sonnet-v1.yaml index cb11df0b60..b782faddba 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-sonnet-v1.yaml +++ b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-sonnet-v1.yaml @@ -50,7 +50,7 @@ parameter_rules: zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。 en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses. 
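A side note on the dispatch pattern in the baichuan_turbo.py hunks above: _build_parameters, _build_headers, and the endpoint selection each repeat the same six-way model == ... chain. A minimal sketch of a shared membership test that would centralize the list (a hypothetical refactor, not part of this diff):

# Illustrative only; these names are not from the PR.
OPENAI_COMPATIBLE_MODELS = frozenset({
    'baichuan2-turbo', 'baichuan2-turbo-192k', 'baichuan2-53b',
    'baichuan3-turbo', 'baichuan3-turbo-128k', 'baichuan4',
})

def is_openai_compatible(model: str) -> bool:
    # One lookup replaces the repeated or-chains at all three call sites.
    return model in OPENAI_COMPATIBLE_MODELS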
pricing:
-  input: '0.00025'
-  output: '0.00125'
+  input: '0.003'
+  output: '0.015'
   unit: '0.001'
   currency: USD
diff --git a/api/core/model_runtime/model_providers/bedrock/llm/llm.py b/api/core/model_runtime/model_providers/bedrock/llm/llm.py
index 81a9ce2f00..1386d680a4 100644
--- a/api/core/model_runtime/model_providers/bedrock/llm/llm.py
+++ b/api/core/model_runtime/model_providers/bedrock/llm/llm.py
@@ -358,26 +358,25 @@ class BedrockLargeLanguageModel(LargeLanguageModel):

         return message_dict

-    def get_num_tokens(self, model: str, credentials: dict, messages: list[PromptMessage] | str,
+    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage] | str,
                        tools: Optional[list[PromptMessageTool]] = None) -> int:
         """
         Get number of tokens for given prompt messages

         :param model: model name
         :param credentials: model credentials
-        :param messages: prompt messages or message string
+        :param prompt_messages: prompt messages or message string
         :param tools: tools for tool calling
         :return: number of tokens
         """
         prefix = model.split('.')[0]
         model_name = model.split('.')[1]
-        if isinstance(messages, str):
-            prompt = messages
+        if isinstance(prompt_messages, str):
+            prompt = prompt_messages
         else:
-            prompt = self._convert_messages_to_prompt(messages, prefix, model_name)
+            prompt = self._convert_messages_to_prompt(prompt_messages, prefix, model_name)

         return self._get_num_tokens_by_gpt2(prompt)

-
     def validate_credentials(self, model: str, credentials: dict) -> None:
         """
diff --git a/api/core/model_runtime/model_providers/bedrock/llm/mistral.mistral-small-2402-v1.0.yaml b/api/core/model_runtime/model_providers/bedrock/llm/mistral.mistral-small-2402-v1.0.yaml
new file mode 100644
index 0000000000..582f4a6d9f
--- /dev/null
+++ b/api/core/model_runtime/model_providers/bedrock/llm/mistral.mistral-small-2402-v1.0.yaml
@@ -0,0 +1,27 @@
+model: mistral.mistral-small-2402-v1:0
+label:
+  en_US: Mistral Small
+model_type: llm
+model_properties:
+  mode: completion
+  context_size: 32000
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+    required: false
+    default: 0.7
+  - name: top_p
+    use_template: top_p
+    required: false
+    default: 1
+  - name: max_tokens
+    use_template: max_tokens
+    required: true
+    default: 512
+    min: 1
+    max: 4096
+pricing:
+  input: '0.001'
+  output: '0.03'
+  unit: '0.001'
+  currency: USD
diff --git a/api/core/model_runtime/model_providers/bedrock/text_embedding/_position.yaml b/api/core/model_runtime/model_providers/bedrock/text_embedding/_position.yaml
index 5419ff530b..afbea06a3e 100644
--- a/api/core/model_runtime/model_providers/bedrock/text_embedding/_position.yaml
+++ b/api/core/model_runtime/model_providers/bedrock/text_embedding/_position.yaml
@@ -1,3 +1,4 @@
 - amazon.titan-embed-text-v1
+- amazon.titan-embed-text-v2:0
 - cohere.embed-english-v3
 - cohere.embed-multilingual-v3
diff --git a/api/core/model_runtime/model_providers/bedrock/text_embedding/amazon.titan-embed-text-v1.yaml b/api/core/model_runtime/model_providers/bedrock/text_embedding/amazon.titan-embed-text-v1.yaml
index 6a1cf75be1..e5a55971a1 100644
--- a/api/core/model_runtime/model_providers/bedrock/text_embedding/amazon.titan-embed-text-v1.yaml
+++ b/api/core/model_runtime/model_providers/bedrock/text_embedding/amazon.titan-embed-text-v1.yaml
@@ -4,5 +4,5 @@ model_properties:
   context_size: 8192
 pricing:
   input: '0.0001'
-  unit: '0.001'
+  unit: '0.0001'
   currency: USD
diff --git
a/api/core/model_runtime/model_providers/bedrock/text_embedding/amazon.titan-embed-text-v2.yaml b/api/core/model_runtime/model_providers/bedrock/text_embedding/amazon.titan-embed-text-v2.yaml new file mode 100644 index 0000000000..5069efeb10 --- /dev/null +++ b/api/core/model_runtime/model_providers/bedrock/text_embedding/amazon.titan-embed-text-v2.yaml @@ -0,0 +1,8 @@ +model: amazon.titan-embed-text-v2:0 +model_type: text-embedding +model_properties: + context_size: 8192 +pricing: + input: '0.00002' + unit: '0.00001' + currency: USD diff --git a/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py index 69436cd737..84b23d4a27 100644 --- a/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py @@ -59,15 +59,15 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel): model_prefix = model.split('.')[0] if model_prefix == "amazon" : - for text in texts: - body = { + for text in texts: + body = { "inputText": text, - } - response_body = self._invoke_bedrock_embedding(model, bedrock_runtime, body) - embeddings.extend([response_body.get('embedding')]) - token_usage += response_body.get('inputTextTokenCount') - logger.warning(f'Total Tokens: {token_usage}') - result = TextEmbeddingResult( + } + response_body = self._invoke_bedrock_embedding(model, bedrock_runtime, body) + embeddings.extend([response_body.get('embedding')]) + token_usage += response_body.get('inputTextTokenCount') + logger.warning(f'Total Tokens: {token_usage}') + result = TextEmbeddingResult( model=model, embeddings=embeddings, usage=self._calc_response_usage( @@ -75,20 +75,20 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel): credentials=credentials, tokens=token_usage ) - ) - return result - + ) + return result + if model_prefix == "cohere" : - input_type = 'search_document' if len(texts) > 1 else 'search_query' - for text in texts: - body = { + input_type = 'search_document' if len(texts) > 1 else 'search_query' + for text in texts: + body = { "texts": [text], "input_type": input_type, - } - response_body = self._invoke_bedrock_embedding(model, bedrock_runtime, body) - embeddings.extend(response_body.get('embeddings')) - token_usage += len(text) - result = TextEmbeddingResult( + } + response_body = self._invoke_bedrock_embedding(model, bedrock_runtime, body) + embeddings.extend(response_body.get('embeddings')) + token_usage += len(text) + result = TextEmbeddingResult( model=model, embeddings=embeddings, usage=self._calc_response_usage( @@ -96,9 +96,9 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel): credentials=credentials, tokens=token_usage ) - ) - return result - + ) + return result + #others raise ValueError(f"Got unknown model prefix {model_prefix} when handling block response") @@ -183,7 +183,7 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel): ) return usage - + def _map_client_to_invoke_error(self, error_code: str, error_msg: str) -> type[InvokeError]: """ Map client error to invoke error @@ -212,9 +212,9 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel): content_type = 'application/json' try: response = bedrock_runtime.invoke_model( - body=json.dumps(body), - modelId=model, - accept=accept, + body=json.dumps(body), + modelId=model, + accept=accept, contentType=content_type ) response_body = json.loads(response.get('body').read().decode('utf-8')) diff --git 
a/api/core/model_runtime/model_providers/chatglm/llm/llm.py b/api/core/model_runtime/model_providers/chatglm/llm/llm.py index 12dc75aece..e83d08af71 100644 --- a/api/core/model_runtime/model_providers/chatglm/llm/llm.py +++ b/api/core/model_runtime/model_providers/chatglm/llm/llm.py @@ -1,6 +1,5 @@ import logging from collections.abc import Generator -from os.path import join from typing import Optional, cast from httpx import Timeout @@ -19,6 +18,7 @@ from openai import ( ) from openai.types.chat import ChatCompletion, ChatCompletionChunk from openai.types.chat.chat_completion_message import FunctionCall +from yarl import URL from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import ( @@ -265,7 +265,7 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel): client_kwargs = { "timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0), "api_key": "1", - "base_url": join(credentials['api_base'], 'v1') + "base_url": str(URL(credentials['api_base']) / 'v1') } return client_kwargs diff --git a/api/core/model_runtime/model_providers/cohere/cohere.yaml b/api/core/model_runtime/model_providers/cohere/cohere.yaml index c889a6bfe0..bd40057fe9 100644 --- a/api/core/model_runtime/model_providers/cohere/cohere.yaml +++ b/api/core/model_runtime/model_providers/cohere/cohere.yaml @@ -32,6 +32,15 @@ provider_credential_schema: zh_Hans: 在此输入您的 API Key en_US: Enter your API Key show_on: [ ] + - variable: base_url + label: + zh_Hans: API Base + en_US: API Base + type: text-input + required: false + placeholder: + zh_Hans: 在此输入您的 API Base,如 https://api.cohere.ai/v1 + en_US: Enter your API Base, e.g. https://api.cohere.ai/v1 model_credential_schema: model: label: @@ -70,3 +79,12 @@ model_credential_schema: placeholder: zh_Hans: 在此输入您的 API Key en_US: Enter your API Key + - variable: base_url + label: + zh_Hans: API Base + en_US: API Base + type: text-input + required: false + placeholder: + zh_Hans: 在此输入您的 API Base,如 https://api.cohere.ai/v1 + en_US: Enter your API Base, e.g. 
https://api.cohere.ai/v1 diff --git a/api/core/model_runtime/model_providers/cohere/llm/llm.py b/api/core/model_runtime/model_providers/cohere/llm/llm.py index 6ace77b813..f9fae5e8ca 100644 --- a/api/core/model_runtime/model_providers/cohere/llm/llm.py +++ b/api/core/model_runtime/model_providers/cohere/llm/llm.py @@ -173,7 +173,7 @@ class CohereLargeLanguageModel(LargeLanguageModel): :return: full response or stream response chunk generator result """ # initialize client - client = cohere.Client(credentials.get('api_key')) + client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url')) if stop: model_parameters['end_sequences'] = stop @@ -233,7 +233,8 @@ class CohereLargeLanguageModel(LargeLanguageModel): return response - def _handle_generate_stream_response(self, model: str, credentials: dict, response: Iterator[GenerateStreamedResponse], + def _handle_generate_stream_response(self, model: str, credentials: dict, + response: Iterator[GenerateStreamedResponse], prompt_messages: list[PromptMessage]) -> Generator: """ Handle llm stream response @@ -317,7 +318,7 @@ class CohereLargeLanguageModel(LargeLanguageModel): :return: full response or stream response chunk generator result """ # initialize client - client = cohere.Client(credentials.get('api_key')) + client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url')) if stop: model_parameters['stop_sequences'] = stop @@ -636,7 +637,7 @@ class CohereLargeLanguageModel(LargeLanguageModel): :return: number of tokens """ # initialize client - client = cohere.Client(credentials.get('api_key')) + client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url')) response = client.tokenize( text=text, diff --git a/api/core/model_runtime/model_providers/cohere/rerank/rerank.py b/api/core/model_runtime/model_providers/cohere/rerank/rerank.py index 4194f27eb9..d2fdb30c6f 100644 --- a/api/core/model_runtime/model_providers/cohere/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/cohere/rerank/rerank.py @@ -44,7 +44,7 @@ class CohereRerankModel(RerankModel): ) # initialize client - client = cohere.Client(credentials.get('api_key')) + client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url')) response = client.rerank( query=query, documents=docs, diff --git a/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py index 8269a41810..0540fb740f 100644 --- a/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py @@ -141,7 +141,7 @@ class CohereTextEmbeddingModel(TextEmbeddingModel): return [] # initialize client - client = cohere.Client(credentials.get('api_key')) + client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url')) response = client.tokenize( text=text, @@ -180,7 +180,7 @@ class CohereTextEmbeddingModel(TextEmbeddingModel): :return: embeddings and used tokens """ # initialize client - client = cohere.Client(credentials.get('api_key')) + client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url')) # call embedding model response = client.embed( diff --git a/api/core/model_runtime/model_providers/deepseek/__init__.py b/api/core/model_runtime/model_providers/deepseek/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git 
a/api/core/model_runtime/model_providers/deepseek/_assets/icon_l_en.svg b/api/core/model_runtime/model_providers/deepseek/_assets/icon_l_en.svg new file mode 100644 index 0000000000..425494404f --- /dev/null +++ b/api/core/model_runtime/model_providers/deepseek/_assets/icon_l_en.svg @@ -0,0 +1,22 @@ + + + Created with Pixso. + + + + + + + + + + + + + + + + + + + diff --git a/api/core/model_runtime/model_providers/deepseek/_assets/icon_s_en.svg b/api/core/model_runtime/model_providers/deepseek/_assets/icon_s_en.svg new file mode 100644 index 0000000000..aa854a7504 --- /dev/null +++ b/api/core/model_runtime/model_providers/deepseek/_assets/icon_s_en.svg @@ -0,0 +1,3 @@ + + + diff --git a/api/core/model_runtime/model_providers/deepseek/deepseek.py b/api/core/model_runtime/model_providers/deepseek/deepseek.py new file mode 100644 index 0000000000..d61fd4ddc8 --- /dev/null +++ b/api/core/model_runtime/model_providers/deepseek/deepseek.py @@ -0,0 +1,33 @@ +import logging + +from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.model_provider import ModelProvider + +logger = logging.getLogger(__name__) + + + +class DeepSeekProvider(ModelProvider): + + def validate_provider_credentials(self, credentials: dict) -> None: + """ + Validate provider credentials + if validate failed, raise exception + + :param credentials: provider credentials, credentials form defined in `provider_credential_schema`. + """ + try: + model_instance = self.get_model_instance(ModelType.LLM) + + # Use `deepseek-chat` model for validate, + # no matter what model you pass in, text completion model or chat model + model_instance.validate_credentials( + model='deepseek-chat', + credentials=credentials + ) + except CredentialsValidateFailedError as ex: + raise ex + except Exception as ex: + logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + raise ex diff --git a/api/core/model_runtime/model_providers/deepseek/deepseek.yaml b/api/core/model_runtime/model_providers/deepseek/deepseek.yaml new file mode 100644 index 0000000000..16abd358d6 --- /dev/null +++ b/api/core/model_runtime/model_providers/deepseek/deepseek.yaml @@ -0,0 +1,41 @@ +provider: deepseek +label: + en_US: deepseek + zh_Hans: 深度求索 +description: + en_US: Models provided by deepseek, such as deepseek-chat、deepseek-coder. + zh_Hans: 深度求索提供的模型,例如 deepseek-chat、deepseek-coder 。 +icon_small: + en_US: icon_s_en.svg +icon_large: + en_US: icon_l_en.svg +background: "#c0cdff" +help: + title: + en_US: Get your API Key from deepseek + zh_Hans: 从深度求索获取 API Key + url: + en_US: https://platform.deepseek.com/api_keys +supported_model_types: + - llm +configurate_methods: + - predefined-model +provider_credential_schema: + credential_form_schemas: + - variable: api_key + label: + en_US: API Key + type: secret-input + required: true + placeholder: + zh_Hans: 在此输入您的 API Key + en_US: Enter your API Key + - variable: endpoint_url + label: + zh_Hans: 自定义 API endpoint 地址 + en_US: Custom API endpoint URL + type: text-input + required: false + placeholder: + zh_Hans: Base URL, e.g. https://api.deepseek.com/v1 or https://api.deepseek.com + en_US: Base URL, e.g. 
https://api.deepseek.com/v1 or https://api.deepseek.com diff --git a/api/core/model_runtime/model_providers/deepseek/llm/__init__.py b/api/core/model_runtime/model_providers/deepseek/llm/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/deepseek/llm/_position.yaml b/api/core/model_runtime/model_providers/deepseek/llm/_position.yaml new file mode 100644 index 0000000000..43d03f2ee9 --- /dev/null +++ b/api/core/model_runtime/model_providers/deepseek/llm/_position.yaml @@ -0,0 +1,2 @@ +- deepseek-chat +- deepseek-coder diff --git a/api/core/model_runtime/model_providers/deepseek/llm/deepseek-chat.yaml b/api/core/model_runtime/model_providers/deepseek/llm/deepseek-chat.yaml new file mode 100644 index 0000000000..3a5a63fa61 --- /dev/null +++ b/api/core/model_runtime/model_providers/deepseek/llm/deepseek-chat.yaml @@ -0,0 +1,64 @@ +model: deepseek-chat +label: + zh_Hans: deepseek-chat + en_US: deepseek-chat +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 32000 +parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 1 + min: 0.0 + max: 2.0 + help: + zh_Hans: 控制生成结果的多样性和随机性。数值越小,越严谨;数值越大,越发散。 + en_US: Control the diversity and randomness of generated results. The smaller the value, the more rigorous it is; the larger the value, the more divergent it is. + - name: max_tokens + use_template: max_tokens + type: int + default: 4096 + min: 1 + max: 32000 + help: + zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。 + en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. + - name: top_p + use_template: top_p + type: float + default: 1 + min: 0.01 + max: 1.00 + help: + zh_Hans: 控制生成结果的随机性。数值越小,随机性越弱;数值越大,随机性越强。一般而言,top_p 和 temperature 两个参数选择一个进行调整即可。 + en_US: Control the randomness of generated results. The smaller the value, the weaker the randomness; the larger the value, the stronger the randomness. Generally speaking, you can adjust one of the two parameters top_p and temperature. + - name: logprobs + help: + zh_Hans: 是否返回所输出 token 的对数概率。如果为 true,则在 message 的 content 中返回每个输出 token 的对数概率。 + en_US: Whether to return the log probability of the output token. If true, returns the log probability of each output token in the content of message . + type: boolean + - name: top_logprobs + type: int + default: 0 + min: 0 + max: 20 + help: + zh_Hans: 一个介于 0 到 20 之间的整数 N,指定每个输出位置返回输出概率 top N 的 token,且返回这些 token 的对数概率。指定此参数时,logprobs 必须为 true。 + en_US: An integer N between 0 and 20, specifying that each output position returns the top N tokens with output probability, and returns the logarithmic probability of these tokens. When specifying this parameter, logprobs must be true. + - name: frequency_penalty + use_template: frequency_penalty + default: 0 + min: -2.0 + max: 2.0 + help: + zh_Hans: 介于 -2.0 和 2.0 之间的数字。如果该值为正,那么新 token 会根据其在已有文本中的出现频率受到相应的惩罚,降低模型重复相同内容的可能性。 + en_US: A number between -2.0 and 2.0. If the value is positive, new tokens are penalized based on their frequency of occurrence in existing text, reducing the likelihood that the model will repeat the same content. 
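Because the DeepSeek provider's llm.py (further below) subclasses OpenAILargeLanguageModel, the rules above map one-to-one onto OpenAI-style request fields. A minimal usage sketch, assuming the official openai SDK and a placeholder key:

from openai import OpenAI

# Placeholder key; the base_url mirrors the default applied in
# _add_custom_parameters below.
client = OpenAI(api_key="sk-...", base_url="https://api.deepseek.com")
response = client.chat.completions.create(
    model="deepseek-chat",
    messages=[{"role": "user", "content": "Hello"}],
    temperature=1.0,
    max_tokens=256,
    logprobs=True,
    top_logprobs=5,  # only meaningful when logprobs is true, per the help text above
)
print(response.choices[0].message.content)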
+pricing: + input: '1' + output: '2' + unit: '0.000001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/deepseek/llm/deepseek-coder.yaml b/api/core/model_runtime/model_providers/deepseek/llm/deepseek-coder.yaml new file mode 100644 index 0000000000..8f156be101 --- /dev/null +++ b/api/core/model_runtime/model_providers/deepseek/llm/deepseek-coder.yaml @@ -0,0 +1,26 @@ +model: deepseek-coder +label: + zh_Hans: deepseek-coder + en_US: deepseek-coder +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 16000 +parameter_rules: + - name: temperature + use_template: temperature + min: 0 + max: 1 + default: 0.5 + - name: top_p + use_template: top_p + min: 0 + max: 1 + default: 1 + - name: max_tokens + use_template: max_tokens + min: 1 + max: 32000 + default: 1024 diff --git a/api/core/model_runtime/model_providers/deepseek/llm/llm.py b/api/core/model_runtime/model_providers/deepseek/llm/llm.py new file mode 100644 index 0000000000..bdb3823b60 --- /dev/null +++ b/api/core/model_runtime/model_providers/deepseek/llm/llm.py @@ -0,0 +1,113 @@ +from collections.abc import Generator +from typing import Optional, Union +from urllib.parse import urlparse + +import tiktoken + +from core.model_runtime.entities.llm_entities import LLMResult +from core.model_runtime.entities.message_entities import ( + PromptMessage, + PromptMessageTool, +) +from core.model_runtime.model_providers.openai.llm.llm import OpenAILargeLanguageModel + + +class DeepSeekLargeLanguageModel(OpenAILargeLanguageModel): + + def _invoke(self, model: str, credentials: dict, + prompt_messages: list[PromptMessage], model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, + stream: bool = True, user: Optional[str] = None) \ + -> Union[LLMResult, Generator]: + self._add_custom_parameters(credentials) + + return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) + + def validate_credentials(self, model: str, credentials: dict) -> None: + self._add_custom_parameters(credentials) + super().validate_credentials(model, credentials) + + + # refactored from openai model runtime, use cl100k_base for calculate token number + def _num_tokens_from_string(self, model: str, text: str, + tools: Optional[list[PromptMessageTool]] = None) -> int: + """ + Calculate num tokens for text completion model with tiktoken package. + + :param model: model name + :param text: prompt text + :param tools: tools for tool calling + :return: number of tokens + """ + encoding = tiktoken.get_encoding("cl100k_base") + num_tokens = len(encoding.encode(text)) + + if tools: + num_tokens += self._num_tokens_for_tools(encoding, tools) + + return num_tokens + + # refactored from openai model runtime, use cl100k_base for calculate token number + def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None) -> int: + """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package. 
+ + Official documentation: https://github.com/openai/openai-cookbook/blob/ + main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb""" + encoding = tiktoken.get_encoding("cl100k_base") + tokens_per_message = 3 + tokens_per_name = 1 + + num_tokens = 0 + messages_dict = [self._convert_prompt_message_to_dict(m) for m in messages] + for message in messages_dict: + num_tokens += tokens_per_message + for key, value in message.items(): + # Cast str(value) in case the message value is not a string + # This occurs with function messages + # TODO: The current token calculation method for the image type is not implemented, + # which need to download the image and then get the resolution for calculation, + # and will increase the request delay + if isinstance(value, list): + text = '' + for item in value: + if isinstance(item, dict) and item['type'] == 'text': + text += item['text'] + + value = text + + if key == "tool_calls": + for tool_call in value: + for t_key, t_value in tool_call.items(): + num_tokens += len(encoding.encode(t_key)) + if t_key == "function": + for f_key, f_value in t_value.items(): + num_tokens += len(encoding.encode(f_key)) + num_tokens += len(encoding.encode(f_value)) + else: + num_tokens += len(encoding.encode(t_key)) + num_tokens += len(encoding.encode(t_value)) + else: + num_tokens += len(encoding.encode(str(value))) + + if key == "name": + num_tokens += tokens_per_name + + # every reply is primed with assistant + num_tokens += 3 + + if tools: + num_tokens += self._num_tokens_for_tools(encoding, tools) + + return num_tokens + + @staticmethod + def _add_custom_parameters(credentials: dict) -> None: + credentials['mode'] = 'chat' + credentials['openai_api_key']=credentials['api_key'] + if 'endpoint_url' not in credentials or credentials['endpoint_url'] == "": + credentials['openai_api_base']='https://api.deepseek.com' + else: + parsed_url = urlparse(credentials['endpoint_url']) + credentials['openai_api_base']=f"{parsed_url.scheme}://{parsed_url.netloc}" + diff --git a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-latest.yaml b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-latest.yaml new file mode 100644 index 0000000000..24b1c5af8a --- /dev/null +++ b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-latest.yaml @@ -0,0 +1,39 @@ +model: gemini-1.5-flash-latest +label: + en_US: Gemini 1.5 Flash +model_type: llm +features: + - agent-thought + - vision + - tool-call + - stream-tool-call +model_properties: + mode: chat + context_size: 1048576 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. 
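The Gemini 1.5 Flash card above is paired with a Google provider change below that adds a 600-second request timeout. A minimal sketch against the google-generativeai SDK this runtime wraps; the API key is a placeholder:

import google.generativeai as genai

genai.configure(api_key="...")  # placeholder credential
model = genai.GenerativeModel("gemini-1.5-flash-latest")
# The timeout matches the request_options value added in llm.py below.
response = model.generate_content("Hello", request_options={"timeout": 600})
print(response.text)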
+ required: false + - name: max_tokens_to_sample + use_template: max_tokens + required: true + default: 8192 + min: 1 + max: 8192 + - name: response_format + use_template: response_format +pricing: + input: '0.00' + output: '0.00' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/google/llm/llm.py b/api/core/model_runtime/model_providers/google/llm/llm.py index 27912b13cc..5a674fdeee 100644 --- a/api/core/model_runtime/model_providers/google/llm/llm.py +++ b/api/core/model_runtime/model_providers/google/llm/llm.py @@ -204,6 +204,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel): stream=stream, safety_settings=safety_settings, tools=self._convert_tools_to_glm_tool(tools) if tools else None, + request_options={"timeout": 600} ) if stream: diff --git a/api/core/model_runtime/model_providers/jina/jina.yaml b/api/core/model_runtime/model_providers/jina/jina.yaml index 935546234b..23e18ad75f 100644 --- a/api/core/model_runtime/model_providers/jina/jina.yaml +++ b/api/core/model_runtime/model_providers/jina/jina.yaml @@ -19,6 +19,7 @@ supported_model_types: - rerank configurate_methods: - predefined-model + - customizable-model provider_credential_schema: credential_form_schemas: - variable: api_key @@ -29,3 +30,40 @@ provider_credential_schema: placeholder: zh_Hans: 在此输入您的 API Key en_US: Enter your API Key +model_credential_schema: + model: + label: + en_US: Model Name + zh_Hans: 模型名称 + placeholder: + en_US: Enter your model name + zh_Hans: 输入模型名称 + credential_form_schemas: + - variable: api_key + label: + en_US: API Key + type: secret-input + required: true + placeholder: + zh_Hans: 在此输入您的 API Key + en_US: Enter your API Key + - variable: base_url + label: + zh_Hans: 服务器 URL + en_US: Base URL + type: text-input + required: true + placeholder: + zh_Hans: Base URL, e.g. https://api.jina.ai/v1 + en_US: Base URL, e.g. 
https://api.jina.ai/v1 + default: 'https://api.jina.ai/v1' + - variable: context_size + label: + zh_Hans: 上下文大小 + en_US: Context size + placeholder: + zh_Hans: 输入上下文大小 + en_US: Enter context size + required: false + type: text-input + default: '8192' diff --git a/api/core/model_runtime/model_providers/jina/rerank/rerank.py b/api/core/model_runtime/model_providers/jina/rerank/rerank.py index f644ea6512..de7e038b9f 100644 --- a/api/core/model_runtime/model_providers/jina/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/jina/rerank/rerank.py @@ -2,6 +2,8 @@ from typing import Optional import httpx +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult from core.model_runtime.errors.invoke import ( InvokeAuthorizationError, @@ -38,9 +40,13 @@ class JinaRerankModel(RerankModel): if len(docs) == 0: return RerankResult(model=model, docs=[]) + base_url = credentials.get('base_url', 'https://api.jina.ai/v1') + if base_url.endswith('/'): + base_url = base_url[:-1] + try: response = httpx.post( - "https://api.jina.ai/v1/rerank", + base_url + '/rerank', json={ "model": model, "query": query, @@ -103,3 +109,19 @@ class JinaRerankModel(RerankModel): InvokeAuthorizationError: [httpx.HTTPStatusError], InvokeBadRequestError: [httpx.RequestError] } + + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: + """ + generate custom model entities from credentials + """ + entity = AIModelEntity( + model=model, + label=I18nObject(en_US=model), + model_type=ModelType.RERANK, + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_properties={ + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size')) + } + ) + + return entity \ No newline at end of file diff --git a/api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py index da922232c0..74a1aabf7a 100644 --- a/api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py @@ -4,7 +4,8 @@ from typing import Optional from requests import post -from core.model_runtime.entities.model_entities import PriceType +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.invoke import ( InvokeAuthorizationError, @@ -23,8 +24,7 @@ class JinaTextEmbeddingModel(TextEmbeddingModel): """ Model class for Jina text embedding model. 
""" - api_base: str = 'https://api.jina.ai/v1/embeddings' - models: list[str] = ['jina-embeddings-v2-base-en', 'jina-embeddings-v2-small-en', 'jina-embeddings-v2-base-zh', 'jina-embeddings-v2-base-de'] + api_base: str = 'https://api.jina.ai/v1' def _invoke(self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None) \ @@ -39,11 +39,14 @@ class JinaTextEmbeddingModel(TextEmbeddingModel): :return: embeddings result """ api_key = credentials['api_key'] - if model not in self.models: - raise InvokeBadRequestError('Invalid model name') if not api_key: raise CredentialsValidateFailedError('api_key is required') - url = self.api_base + + base_url = credentials.get('base_url', self.api_base) + if base_url.endswith('/'): + base_url = base_url[:-1] + + url = base_url + '/embeddings' headers = { 'Authorization': 'Bearer ' + api_key, 'Content-Type': 'application/json' @@ -70,7 +73,7 @@ class JinaTextEmbeddingModel(TextEmbeddingModel): elif response.status_code == 500: raise InvokeServerUnavailableError(msg) else: - raise InvokeError(msg) + raise InvokeBadRequestError(msg) except JSONDecodeError as e: raise InvokeServerUnavailableError(f"Failed to convert response to json: {e} with text: {response.text}") @@ -118,8 +121,8 @@ class JinaTextEmbeddingModel(TextEmbeddingModel): """ try: self._invoke(model=model, credentials=credentials, texts=['ping']) - except InvokeAuthorizationError: - raise CredentialsValidateFailedError('Invalid api key') + except Exception as e: + raise CredentialsValidateFailedError(f'Credentials validation failed: {e}') @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: @@ -137,7 +140,8 @@ class JinaTextEmbeddingModel(TextEmbeddingModel): InvokeAuthorizationError ], InvokeBadRequestError: [ - KeyError + KeyError, + InvokeBadRequestError ] } @@ -170,3 +174,19 @@ class JinaTextEmbeddingModel(TextEmbeddingModel): ) return usage + + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: + """ + generate custom model entities from credentials + """ + entity = AIModelEntity( + model=model, + label=I18nObject(en_US=model), + model_type=ModelType.TEXT_EMBEDDING, + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_properties={ + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size')) + } + ) + + return entity diff --git a/api/core/model_runtime/model_providers/localai/localai.yaml b/api/core/model_runtime/model_providers/localai/localai.yaml index a870914632..864dd7a30c 100644 --- a/api/core/model_runtime/model_providers/localai/localai.yaml +++ b/api/core/model_runtime/model_providers/localai/localai.yaml @@ -15,6 +15,8 @@ help: supported_model_types: - llm - text-embedding + - rerank + - speech2text configurate_methods: - customizable-model model_credential_schema: @@ -57,6 +59,9 @@ model_credential_schema: zh_Hans: 在此输入LocalAI的服务器地址,如 http://192.168.1.100:8080 en_US: Enter the url of your LocalAI, e.g. 
http://192.168.1.100:8080
   - variable: context_size
+    show_on:
+      - variable: __model_type
+        value: llm
     label:
       zh_Hans: 上下文大小
       en_US: Context size
diff --git a/api/core/model_runtime/model_providers/localai/rerank/__init__.py b/api/core/model_runtime/model_providers/localai/rerank/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/model_runtime/model_providers/localai/rerank/rerank.py b/api/core/model_runtime/model_providers/localai/rerank/rerank.py
new file mode 100644
index 0000000000..c8ba9a6c7c
--- /dev/null
+++ b/api/core/model_runtime/model_providers/localai/rerank/rerank.py
@@ -0,0 +1,136 @@
+from json import dumps
+from typing import Optional
+
+import httpx
+from requests import HTTPError, post
+from yarl import URL
+
+from core.model_runtime.entities.common_entities import I18nObject
+from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
+from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
+from core.model_runtime.errors.invoke import (
+    InvokeAuthorizationError,
+    InvokeBadRequestError,
+    InvokeConnectionError,
+    InvokeError,
+    InvokeRateLimitError,
+    InvokeServerUnavailableError,
+)
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.rerank_model import RerankModel
+
+
+class LocalaiRerankModel(RerankModel):
+    """
+    LocalAI's rerank API is compatible with the Jina rerank API, so this class mirrors JinaRerankModel.
+    """
+
+    def _invoke(self, model: str, credentials: dict,
+                query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None,
+                user: Optional[str] = None) -> RerankResult:
+        """
+        Invoke rerank model
+
+        :param model: model name
+        :param credentials: model credentials
+        :param query: search query
+        :param docs: docs for reranking
+        :param score_threshold: score threshold
+        :param top_n: top n documents to return
+        :param user: unique user id
+        :return: rerank result
+        """
+        if len(docs) == 0:
+            return RerankResult(model=model, docs=[])
+
+        server_url = credentials['server_url']
+        model_name = model
+
+        if not server_url:
+            raise CredentialsValidateFailedError('server_url is required')
+        if not model_name:
+            raise CredentialsValidateFailedError('model_name is required')
+
+        url = server_url
+        headers = {
+            'Authorization': f"Bearer {credentials.get('api_key')}",
+            'Content-Type': 'application/json'
+        }
+
+        data = {
+            "model": model_name,
+            "query": query,
+            "documents": docs,
+            "top_n": top_n
+        }
+
+        try:
+            response = post(str(URL(url) / 'rerank'), headers=headers, data=dumps(data), timeout=10)
+            response.raise_for_status()
+            results = response.json()
+
+            rerank_documents = []
+            for result in results['results']:
+                rerank_document = RerankDocument(
+                    index=result['index'],
+                    text=result['document']['text'],
+                    score=result['relevance_score'],
+                )
+                if score_threshold is None or result['relevance_score'] >= score_threshold:
+                    rerank_documents.append(rerank_document)
+
+            return RerankResult(model=model, docs=rerank_documents)
+        except HTTPError as e:
+            # `post` comes from requests, so catch requests' HTTPError rather than httpx's.
+            raise InvokeServerUnavailableError(str(e))
+
+    def validate_credentials(self, model: str, credentials: dict) -> None:
+        """
+        Validate model credentials
+
+        :param model: model name
+        :param credentials: model credentials
+        :return:
+        """
+        try:
+            self._invoke(
+                model=model,
+                credentials=credentials,
+                query="What is the capital of the United States?",
+                docs=[
+                    "Carson City is the capital city of the American state of Nevada. At the 2010 United States "
+                    "Census, Carson City had a population of 55,274.",
+                    "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
+                    "are a political division controlled by the United States. Its capital is Saipan.",
+                ],
+                score_threshold=0.8
+            )
+        except Exception as ex:
+            raise CredentialsValidateFailedError(str(ex))
+
+    @property
+    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+        """
+        Map model invoke error to unified error
+        """
+        return {
+            InvokeConnectionError: [httpx.ConnectError],
+            InvokeServerUnavailableError: [httpx.RemoteProtocolError],
+            InvokeRateLimitError: [],
+            InvokeAuthorizationError: [httpx.HTTPStatusError],
+            InvokeBadRequestError: [httpx.RequestError]
+        }
+
+    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
+        """
+        generate custom model entities from credentials
+        """
+        entity = AIModelEntity(
+            model=model,
+            label=I18nObject(en_US=model),
+            model_type=ModelType.RERANK,
+            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+            model_properties={}
+        )
+
+        return entity
diff --git a/api/core/model_runtime/model_providers/localai/speech2text/__init__.py b/api/core/model_runtime/model_providers/localai/speech2text/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/model_runtime/model_providers/localai/speech2text/speech2text.py b/api/core/model_runtime/model_providers/localai/speech2text/speech2text.py
new file mode 100644
index 0000000000..d7403aff4f
--- /dev/null
+++ b/api/core/model_runtime/model_providers/localai/speech2text/speech2text.py
@@ -0,0 +1,101 @@
+from typing import IO, Optional
+
+from requests import Request, Session
+from yarl import URL
+
+from core.model_runtime.entities.common_entities import I18nObject
+from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
+from core.model_runtime.errors.invoke import (
+    InvokeAuthorizationError,
+    InvokeBadRequestError,
+    InvokeConnectionError,
+    InvokeError,
+    InvokeRateLimitError,
+    InvokeServerUnavailableError,
+)
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
+
+
+class LocalAISpeech2text(Speech2TextModel):
+    """
+    Model class for LocalAI speech-to-text model.
+    """
+
+    def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str:
+        """
+        Invoke speech-to-text model
+
+        :param model: model name
+        :param credentials: model credentials
+        :param file: audio file
+        :param user: unique user id
+        :return: text for given audio file
+        """
+
+        url = str(URL(credentials['server_url']) / "v1/audio/transcriptions")
+        data = {"model": model}
+        files = {"file": file}
+
+        session = Session()
+        request = Request("POST", url, data=data, files=files)
+        prepared_request = session.prepare_request(request)
+        response = session.send(prepared_request)
+
+        result = response.json()
+        if 'error' in result:
+            # surface the server-reported error instead of a generic message
+            raise InvokeServerUnavailableError(result['error'])
+
+        return result["text"]
+
+    def validate_credentials(self, model: str, credentials: dict) -> None:
+        """
+        Validate model credentials
+
+        :param model: model name
+        :param credentials: model credentials
+        :return:
+        """
+        try:
+            audio_file_path = self._get_demo_file_path()
+
+            with open(audio_file_path, 'rb') as audio_file:
+                self._invoke(model, credentials, audio_file)
+        except Exception as ex:
+            raise CredentialsValidateFailedError(str(ex))
+
+    @property
+    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+        return {
+            InvokeConnectionError: [
+                InvokeConnectionError
+            ],
+            InvokeServerUnavailableError: [
+                InvokeServerUnavailableError
+            ],
+            InvokeRateLimitError: [
+                InvokeRateLimitError
+            ],
+            InvokeAuthorizationError: [
+                InvokeAuthorizationError
+            ],
+            InvokeBadRequestError: [
+                InvokeBadRequestError
+            ],
+        }
+
+    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
+        """
+        used to define customizable model schema
+        """
+        entity = AIModelEntity(
+            model=model,
+            label=I18nObject(
+                en_US=model
+            ),
+            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+            model_type=ModelType.SPEECH2TEXT,
+            model_properties={},
+            parameter_rules=[]
+        )
+
+        return entity
\ No newline at end of file
diff --git a/api/core/model_runtime/model_providers/minimax/llm/abab5-chat.yaml b/api/core/model_runtime/model_providers/minimax/llm/abab5-chat.yaml
index e0d81e76cc..2c1f79e2b7 100644
--- a/api/core/model_runtime/model_providers/minimax/llm/abab5-chat.yaml
+++ b/api/core/model_runtime/model_providers/minimax/llm/abab5-chat.yaml
@@ -18,6 +18,15 @@ parameter_rules:
     default: 6144
     min: 1
     max: 6144
+  - name: mask_sensitive_info
+    type: boolean
+    default: true
+    label:
+      zh_Hans: 隐私保护
+      en_US: Mask Sensitive Info
+    help:
+      zh_Hans: 对输出中易涉及隐私问题的文本信息进行打码,目前包括但不限于邮箱、域名、链接、证件号、家庭住址等,默认true,即开启打码
+      en_US: Mask the sensitive info of the generated content, such as email/domain/link/address/phone/ID.
   - name: presence_penalty
     use_template: presence_penalty
   - name: frequency_penalty
diff --git a/api/core/model_runtime/model_providers/minimax/llm/abab5.5-chat.yaml b/api/core/model_runtime/model_providers/minimax/llm/abab5.5-chat.yaml
index c0ad1e2fdf..6d29be0d0e 100644
--- a/api/core/model_runtime/model_providers/minimax/llm/abab5.5-chat.yaml
+++ b/api/core/model_runtime/model_providers/minimax/llm/abab5.5-chat.yaml
@@ -26,6 +26,15 @@ parameter_rules:
     default: 6144
     min: 1
     max: 16384
+  - name: mask_sensitive_info
+    type: boolean
+    default: true
+    label:
+      zh_Hans: 隐私保护
+      en_US: Mask Sensitive Info
+    help:
+      zh_Hans: 对输出中易涉及隐私问题的文本信息进行打码,目前包括但不限于邮箱、域名、链接、证件号、家庭住址等,默认true,即开启打码
+      en_US: Mask the sensitive info of the generated content, such as email/domain/link/address/phone/ID.
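The mask_sensitive_info rule above is added to every abab card in this diff, and chat_completion_pro.py (below) forwards it only when the configured value is an actual boolean. A reduced sketch of that forwarding logic, simplified from the PR code:

def build_extra_kwargs(model_parameters: dict) -> dict:
    # Forward the flag only when it is a real bool, mirroring the type
    # check in chat_completion_pro.py.
    extra = {}
    value = model_parameters.get('mask_sensitive_info')
    if isinstance(value, bool):
        extra['mask_sensitive_info'] = value
    return extra

assert build_extra_kwargs({'mask_sensitive_info': False}) == {'mask_sensitive_info': False}
assert build_extra_kwargs({}) == {}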
   - name: presence_penalty
     use_template: presence_penalty
   - name: frequency_penalty
diff --git a/api/core/model_runtime/model_providers/minimax/llm/abab5.5s-chat.yaml b/api/core/model_runtime/model_providers/minimax/llm/abab5.5s-chat.yaml
index 1ef7046459..aa42bb5739 100644
--- a/api/core/model_runtime/model_providers/minimax/llm/abab5.5s-chat.yaml
+++ b/api/core/model_runtime/model_providers/minimax/llm/abab5.5s-chat.yaml
@@ -24,6 +24,15 @@ parameter_rules:
     default: 3072
     min: 1
     max: 8192
+  - name: mask_sensitive_info
+    type: boolean
+    default: true
+    label:
+      zh_Hans: 隐私保护
+      en_US: Mask Sensitive Info
+    help:
+      zh_Hans: 对输出中易涉及隐私问题的文本信息进行打码,目前包括但不限于邮箱、域名、链接、证件号、家庭住址等,默认true,即开启打码
+      en_US: Mask the sensitive info of the generated content, such as email/domain/link/address/phone/ID.
   - name: presence_penalty
     use_template: presence_penalty
   - name: frequency_penalty
diff --git a/api/core/model_runtime/model_providers/minimax/llm/abab6-chat.yaml b/api/core/model_runtime/model_providers/minimax/llm/abab6-chat.yaml
index 4c487c598e..9188b6b53f 100644
--- a/api/core/model_runtime/model_providers/minimax/llm/abab6-chat.yaml
+++ b/api/core/model_runtime/model_providers/minimax/llm/abab6-chat.yaml
@@ -26,6 +26,15 @@ parameter_rules:
     default: 2048
     min: 1
     max: 32768
+  - name: mask_sensitive_info
+    type: boolean
+    default: true
+    label:
+      zh_Hans: 隐私保护
+      en_US: Mask Sensitive Info
+    help:
+      zh_Hans: 对输出中易涉及隐私问题的文本信息进行打码,目前包括但不限于邮箱、域名、链接、证件号、家庭住址等,默认true,即开启打码
+      en_US: Mask the sensitive info of the generated content, such as email/domain/link/address/phone/ID.
   - name: presence_penalty
     use_template: presence_penalty
   - name: frequency_penalty
diff --git a/api/core/model_runtime/model_providers/minimax/llm/abab6.5-chat.yaml b/api/core/model_runtime/model_providers/minimax/llm/abab6.5-chat.yaml
index ead61fb7ca..5d717d5f8c 100644
--- a/api/core/model_runtime/model_providers/minimax/llm/abab6.5-chat.yaml
+++ b/api/core/model_runtime/model_providers/minimax/llm/abab6.5-chat.yaml
@@ -26,6 +26,15 @@ parameter_rules:
     default: 2048
     min: 1
     max: 8192
+  - name: mask_sensitive_info
+    type: boolean
+    default: true
+    label:
+      zh_Hans: 隐私保护
+      en_US: Mask Sensitive Info
+    help:
+      zh_Hans: 对输出中易涉及隐私问题的文本信息进行打码,目前包括但不限于邮箱、域名、链接、证件号、家庭住址等,默认true,即开启打码
+      en_US: Mask the sensitive info of the generated content, such as email/domain/link/address/phone/ID.
   - name: presence_penalty
     use_template: presence_penalty
   - name: frequency_penalty
diff --git a/api/core/model_runtime/model_providers/minimax/llm/abab6.5s-chat.yaml b/api/core/model_runtime/model_providers/minimax/llm/abab6.5s-chat.yaml
index 2cd98abb22..4631fe67e4 100644
--- a/api/core/model_runtime/model_providers/minimax/llm/abab6.5s-chat.yaml
+++ b/api/core/model_runtime/model_providers/minimax/llm/abab6.5s-chat.yaml
@@ -26,6 +26,15 @@ parameter_rules:
     default: 2048
     min: 1
     max: 245760
+  - name: mask_sensitive_info
+    type: boolean
+    default: true
+    label:
+      zh_Hans: 隐私保护
+      en_US: Mask Sensitive Info
+    help:
+      zh_Hans: 对输出中易涉及隐私问题的文本信息进行打码,目前包括但不限于邮箱、域名、链接、证件号、家庭住址等,默认true,即开启打码
+      en_US: Mask the sensitive info of the generated content, such as email/domain/link/address/phone/ID.
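The stream-handler rewrite in the chat_completion_pro.py hunk below distinguishes a final chunk (reply or usage present) from partial delta chunks. A minimal sketch of the server-sent-events framing it consumes, assuming Minimax's usual 'data: ' prefix:

import json

def iter_stream_chunks(lines):
    # Skip keep-alive blanks; parse each 'data: {...}' event as JSON.
    for line in lines:
        if not line:
            continue
        text = line.decode('utf-8') if isinstance(line, bytes) else line
        if text.startswith('data: '):
            yield json.loads(text[len('data: '):])

sample = [b'', b'data: {"reply": "", "choices": []}']
print(list(iter_stream_chunks(sample)))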
- name: presence_penalty use_template: presence_penalty - name: frequency_penalty diff --git a/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py b/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py index 81ea2e165e..55747057c9 100644 --- a/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py +++ b/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py @@ -20,16 +20,16 @@ class MinimaxChatCompletionPro: Minimax Chat Completion Pro API, supports function calling however, we do not have enough time and energy to implement it, but the parameters are reserved """ - def generate(self, model: str, api_key: str, group_id: str, + def generate(self, model: str, api_key: str, group_id: str, prompt_messages: list[MinimaxMessage], model_parameters: dict, tools: list[dict[str, Any]], stop: list[str] | None, stream: bool, user: str) \ - -> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]: + -> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]: """ generate chat completion """ if not api_key or not group_id: raise InvalidAPIKeyError('Invalid API key or group ID') - + url = f'https://api.minimax.chat/v1/text/chatcompletion_pro?GroupId={group_id}' extra_kwargs = {} @@ -42,8 +42,11 @@ class MinimaxChatCompletionPro: if 'top_p' in model_parameters and type(model_parameters['top_p']) == float: extra_kwargs['top_p'] = model_parameters['top_p'] + + if 'mask_sensitive_info' in model_parameters and type(model_parameters['mask_sensitive_info']) == bool: + extra_kwargs['mask_sensitive_info'] = model_parameters['mask_sensitive_info'] - if 'plugin_web_search' in model_parameters and model_parameters['plugin_web_search']: + if model_parameters.get('plugin_web_search'): extra_kwargs['plugins'] = [ 'plugin_web_search' ] @@ -61,7 +64,7 @@ class MinimaxChatCompletionPro: # check if there is a system message if len(prompt_messages) == 0: raise BadRequestError('At least one message is required') - + if prompt_messages[0].role == MinimaxMessage.Role.SYSTEM.value: if prompt_messages[0].content: bot_setting['content'] = prompt_messages[0].content @@ -70,7 +73,7 @@ class MinimaxChatCompletionPro: # check if there is a user message if len(prompt_messages) == 0: raise BadRequestError('At least one user message is required') - + messages = [message.to_dict() for message in prompt_messages] headers = { @@ -89,21 +92,21 @@ class MinimaxChatCompletionPro: if tools: body['functions'] = tools - body['function_call'] = { 'type': 'auto' } + body['function_call'] = {'type': 'auto'} try: response = post( url=url, data=dumps(body), headers=headers, stream=stream, timeout=(10, 300)) except Exception as e: raise InternalServerError(e) - + if response.status_code != 200: raise InternalServerError(response.text) - + if stream: return self._handle_stream_chat_generate_response(response) return self._handle_chat_generate_response(response) - + def _handle_error(self, code: int, msg: str): if code == 1000 or code == 1001 or code == 1013 or code == 1027: raise InternalServerError(msg) @@ -127,7 +130,7 @@ class MinimaxChatCompletionPro: code = response['base_resp']['status_code'] msg = response['base_resp']['status_msg'] self._handle_error(code, msg) - + message = MinimaxMessage( content=response['reply'], role=MinimaxMessage.Role.ASSISTANT.value @@ -144,7 +147,6 @@ class MinimaxChatCompletionPro: """ handle stream chat generate response """ - function_call_storage = None for line in response.iter_lines(): if not line: continue @@ -158,54 
+160,41 @@ class MinimaxChatCompletionPro: msg = data['base_resp']['status_msg'] self._handle_error(code, msg) - if data['reply'] or 'usage' in data and data['usage']: + # final chunk + if data['reply'] or data.get('usage'): total_tokens = data['usage']['total_tokens'] - message = MinimaxMessage( + minimax_message = MinimaxMessage( role=MinimaxMessage.Role.ASSISTANT.value, content='' ) - message.usage = { + minimax_message.usage = { 'prompt_tokens': 0, 'completion_tokens': total_tokens, 'total_tokens': total_tokens } - message.stop_reason = data['choices'][0]['finish_reason'] + minimax_message.stop_reason = data['choices'][0]['finish_reason'] - if function_call_storage: - function_call_message = MinimaxMessage(content='', role=MinimaxMessage.Role.ASSISTANT.value) - function_call_message.function_call = function_call_storage - yield function_call_message + choices = data.get('choices', []) + if len(choices) > 0: + for choice in choices: + message = choice['messages'][0] + # append function_call message + if 'function_call' in message: + function_call_message = MinimaxMessage(content='', role=MinimaxMessage.Role.ASSISTANT.value) + function_call_message.function_call = message['function_call'] + yield function_call_message - yield message + yield minimax_message return + # partial chunk choices = data.get('choices', []) if len(choices) == 0: continue for choice in choices: message = choice['messages'][0] - - if 'function_call' in message: - if not function_call_storage: - function_call_storage = message['function_call'] - if 'arguments' not in function_call_storage or not function_call_storage['arguments']: - function_call_storage['arguments'] = '' - continue - else: - function_call_storage['arguments'] += message['function_call']['arguments'] - continue - else: - if function_call_storage: - message['function_call'] = function_call_storage - function_call_storage = None - - minimax_message = MinimaxMessage(content='', role=MinimaxMessage.Role.ASSISTANT.value) - - if 'function_call' in message: - minimax_message.function_call = message['function_call'] - + # append text message if 'text' in message: - minimax_message.content = message['text'] - - yield minimax_message \ No newline at end of file + minimax_message = MinimaxMessage(content=message['text'], role=MinimaxMessage.Role.ASSISTANT.value) + yield minimax_message diff --git a/api/core/model_runtime/model_providers/minimax/llm/llm.py b/api/core/model_runtime/model_providers/minimax/llm/llm.py index cc88d15736..1fab20ebbc 100644 --- a/api/core/model_runtime/model_providers/minimax/llm/llm.py +++ b/api/core/model_runtime/model_providers/minimax/llm/llm.py @@ -34,6 +34,8 @@ from core.model_runtime.model_providers.minimax.llm.types import MinimaxMessage class MinimaxLargeLanguageModel(LargeLanguageModel): model_apis = { + 'abab6.5s-chat': MinimaxChatCompletionPro, + 'abab6.5-chat': MinimaxChatCompletionPro, 'abab6-chat': MinimaxChatCompletionPro, 'abab5.5s-chat': MinimaxChatCompletionPro, 'abab5.5-chat': MinimaxChatCompletionPro, diff --git a/api/core/model_runtime/model_providers/model_provider_factory.py b/api/core/model_runtime/model_providers/model_provider_factory.py index 44a1cf2e84..26c4199d16 100644 --- a/api/core/model_runtime/model_providers/model_provider_factory.py +++ b/api/core/model_runtime/model_providers/model_provider_factory.py @@ -4,13 +4,13 @@ from typing import Optional from pydantic import BaseModel +from core.helper.module_import_helper import load_single_subclass_from_source +from core.helper.position_helper import 
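Editor's aside: the rewritten stream handler above treats chunks in two phases: partial chunks carry choices[i].messages[0].text, while the final chunk carries reply/usage plus any function_call to surface. A simplified, hedged stand-in that walks already-decoded chunk dicts rather than a live HTTP response:

# Simplified model of the two-phase Minimax streaming protocol implemented
# above; field names mirror the payloads in the diff, the generator is a sketch.
from collections.abc import Iterator

def iter_minimax_chunks(chunks: list[dict]) -> Iterator[dict]:
    for data in chunks:
        if data.get('reply') or data.get('usage'):
            # final chunk: emit any function call, then the usage summary
            for choice in data.get('choices', []):
                message = choice['messages'][0]
                if 'function_call' in message:
                    yield {'function_call': message['function_call']}
            yield {'usage': data.get('usage', {})}
            return
        for choice in data.get('choices', []):  # partial chunk: emit text
            message = choice['messages'][0]
            if 'text' in message:
                yield {'text': message['text']}

demo = [
    {'choices': [{'messages': [{'text': 'Hel'}]}]},
    {'choices': [{'messages': [{'text': 'lo'}]}]},
    {'reply': 'Hello', 'usage': {'total_tokens': 7}, 'choices': [{'messages': [{'text': ''}]}]},
]
print(list(iter_minimax_chunks(demo)))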
get_position_map, sort_to_dict_by_position_map from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity from core.model_runtime.model_providers.__base.model_provider import ModelProvider from core.model_runtime.schema_validators.model_credential_schema_validator import ModelCredentialSchemaValidator from core.model_runtime.schema_validators.provider_credential_schema_validator import ProviderCredentialSchemaValidator -from core.utils.module_import_helper import load_single_subclass_from_source -from core.utils.position_helper import get_position_map, sort_to_dict_by_position_map logger = logging.getLogger(__name__) diff --git a/api/core/model_runtime/model_providers/nvidia/llm/_position.yaml b/api/core/model_runtime/model_providers/nvidia/llm/_position.yaml index fc69862722..2401f2a890 100644 --- a/api/core/model_runtime/model_providers/nvidia/llm/_position.yaml +++ b/api/core/model_runtime/model_providers/nvidia/llm/_position.yaml @@ -1,7 +1,11 @@ - google/gemma-7b - google/codegemma-7b +- google/recurrentgemma-2b - meta/llama2-70b - meta/llama3-8b-instruct - meta/llama3-70b-instruct +- mistralai/mistral-large - mistralai/mixtral-8x7b-instruct-v0.1 +- mistralai/mixtral-8x22b-instruct-v0.1 - fuyu-8b +- snowflake/arctic diff --git a/api/core/model_runtime/model_providers/nvidia/llm/arctic.yaml b/api/core/model_runtime/model_providers/nvidia/llm/arctic.yaml new file mode 100644 index 0000000000..7f53ae58e6 --- /dev/null +++ b/api/core/model_runtime/model_providers/nvidia/llm/arctic.yaml @@ -0,0 +1,36 @@ +model: snowflake/arctic +label: + zh_Hans: snowflake/arctic + en_US: snowflake/arctic +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 4000 +parameter_rules: + - name: temperature + use_template: temperature + min: 0 + max: 1 + default: 0.5 + - name: top_p + use_template: top_p + min: 0 + max: 1 + default: 1 + - name: max_tokens + use_template: max_tokens + min: 1 + max: 1024 + default: 1024 + - name: frequency_penalty + use_template: frequency_penalty + min: -2 + max: 2 + default: 0 + - name: presence_penalty + use_template: presence_penalty + min: -2 + max: 2 + default: 0 diff --git a/api/core/model_runtime/model_providers/nvidia/llm/llm.py b/api/core/model_runtime/model_providers/nvidia/llm/llm.py index 402ffb2cf2..047bbeda63 100644 --- a/api/core/model_runtime/model_providers/nvidia/llm/llm.py +++ b/api/core/model_runtime/model_providers/nvidia/llm/llm.py @@ -22,12 +22,16 @@ from core.model_runtime.utils import helper class NVIDIALargeLanguageModel(OAIAPICompatLargeLanguageModel): MODEL_SUFFIX_MAP = { 'fuyu-8b': 'vlm/adept/fuyu-8b', + 'mistralai/mistral-large': '', 'mistralai/mixtral-8x7b-instruct-v0.1': '', + 'mistralai/mixtral-8x22b-instruct-v0.1': '', 'google/gemma-7b': '', 'google/codegemma-7b': '', + 'snowflake/arctic':'', 'meta/llama2-70b': '', 'meta/llama3-8b-instruct': '', - 'meta/llama3-70b-instruct': '' + 'meta/llama3-70b-instruct': '', + 'google/recurrentgemma-2b': '' } diff --git a/api/core/model_runtime/model_providers/nvidia/llm/mistral-large.yaml b/api/core/model_runtime/model_providers/nvidia/llm/mistral-large.yaml new file mode 100644 index 0000000000..3e14d22141 --- /dev/null +++ b/api/core/model_runtime/model_providers/nvidia/llm/mistral-large.yaml @@ -0,0 +1,36 @@ +model: mistralai/mistral-large +label: + zh_Hans: mistralai/mistral-large + en_US: mistralai/mistral-large +model_type: llm +features: + - 
agent-thought +model_properties: + mode: chat + context_size: 32000 +parameter_rules: + - name: temperature + use_template: temperature + min: 0 + max: 1 + default: 0.5 + - name: top_p + use_template: top_p + min: 0 + max: 1 + default: 1 + - name: max_tokens + use_template: max_tokens + min: 1 + max: 1024 + default: 1024 + - name: frequency_penalty + use_template: frequency_penalty + min: -2 + max: 2 + default: 0 + - name: presence_penalty + use_template: presence_penalty + min: -2 + max: 2 + default: 0 diff --git a/api/core/model_runtime/model_providers/nvidia/llm/mixtral-8x22b-instruct-v0.1.yaml b/api/core/model_runtime/model_providers/nvidia/llm/mixtral-8x22b-instruct-v0.1.yaml new file mode 100644 index 0000000000..05500c0336 --- /dev/null +++ b/api/core/model_runtime/model_providers/nvidia/llm/mixtral-8x22b-instruct-v0.1.yaml @@ -0,0 +1,36 @@ +model: mistralai/mixtral-8x22b-instruct-v0.1 +label: + zh_Hans: mistralai/mixtral-8x22b-instruct-v0.1 + en_US: mistralai/mixtral-8x22b-instruct-v0.1 +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 64000 +parameter_rules: + - name: temperature + use_template: temperature + min: 0 + max: 1 + default: 0.5 + - name: top_p + use_template: top_p + min: 0 + max: 1 + default: 1 + - name: max_tokens + use_template: max_tokens + min: 1 + max: 1024 + default: 1024 + - name: frequency_penalty + use_template: frequency_penalty + min: -2 + max: 2 + default: 0 + - name: presence_penalty + use_template: presence_penalty + min: -2 + max: 2 + default: 0 diff --git a/api/core/model_runtime/model_providers/nvidia/llm/recurrentgemma-2b.yaml b/api/core/model_runtime/model_providers/nvidia/llm/recurrentgemma-2b.yaml new file mode 100644 index 0000000000..73fcce3930 --- /dev/null +++ b/api/core/model_runtime/model_providers/nvidia/llm/recurrentgemma-2b.yaml @@ -0,0 +1,37 @@ +model: google/recurrentgemma-2b +label: + zh_Hans: google/recurrentgemma-2b + en_US: google/recurrentgemma-2b +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 2048 +parameter_rules: + - name: temperature + use_template: temperature + min: 0 + max: 1 + default: 0.2 + - name: top_p + use_template: top_p + min: 0 + max: 1 + default: 0.7 + - name: max_tokens + use_template: max_tokens + min: 1 + max: 1024 + default: 1024 + - name: random_seed + type: int + help: + en_US: The seed to use for random sampling. If set, different calls will generate deterministic results. 
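Editor's aside: the MODEL_SUFFIX_MAP above distinguishes models served under a dedicated route (fuyu-8b) from those on the shared OpenAI-compatible endpoint. A sketch of the lookup; both base URLs are assumptions for illustration, not taken from this diff:

# Hypothetical sketch of suffix-based endpoint routing; the two base URLs are
# assumptions about the NVIDIA API catalog layout, not part of this patch.
MODEL_SUFFIX_MAP = {
    'fuyu-8b': 'vlm/adept/fuyu-8b',
    'mistralai/mistral-large': '',
    'snowflake/arctic': '',
}

def endpoint_for(model: str) -> str:
    suffix = MODEL_SUFFIX_MAP.get(model, '')
    if suffix:  # vision models get a dedicated route
        return f'https://ai.api.nvidia.com/v1/{suffix}'
    return 'https://integrate.api.nvidia.com/v1/chat/completions'

print(endpoint_for('fuyu-8b'))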
+ zh_Hans: 当开启随机数种子以后,你可以通过指定一个固定的种子来使得回答结果更加稳定 + label: + en_US: Seed + zh_Hans: 种子 + default: 0 + min: 0 + max: 2147483647 diff --git a/api/core/model_runtime/model_providers/nvidia_nim/__init__.py b/api/core/model_runtime/model_providers/nvidia_nim/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/nvidia_nim/_assets/icon_l_en.png b/api/core/model_runtime/model_providers/nvidia_nim/_assets/icon_l_en.png new file mode 100644 index 0000000000..5a7f42e617 Binary files /dev/null and b/api/core/model_runtime/model_providers/nvidia_nim/_assets/icon_l_en.png differ diff --git a/api/core/model_runtime/model_providers/nvidia_nim/_assets/icon_s_en.svg b/api/core/model_runtime/model_providers/nvidia_nim/_assets/icon_s_en.svg new file mode 100644 index 0000000000..9fc02f9164 --- /dev/null +++ b/api/core/model_runtime/model_providers/nvidia_nim/_assets/icon_s_en.svg @@ -0,0 +1,3 @@ + + + diff --git a/api/core/model_runtime/model_providers/nvidia_nim/llm/__init__.py b/api/core/model_runtime/model_providers/nvidia_nim/llm/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/nvidia_nim/llm/llm.py b/api/core/model_runtime/model_providers/nvidia_nim/llm/llm.py new file mode 100644 index 0000000000..f7b849fbe2 --- /dev/null +++ b/api/core/model_runtime/model_providers/nvidia_nim/llm/llm.py @@ -0,0 +1,12 @@ +import logging + +from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel + +logger = logging.getLogger(__name__) + + +class NVIDIANIMProvider(OAIAPICompatLargeLanguageModel): + """ + Model class for NVIDIA NIM large language model. + """ + pass diff --git a/api/core/model_runtime/model_providers/nvidia_nim/nvidia_nim.py b/api/core/model_runtime/model_providers/nvidia_nim/nvidia_nim.py new file mode 100644 index 0000000000..25ab3e8e20 --- /dev/null +++ b/api/core/model_runtime/model_providers/nvidia_nim/nvidia_nim.py @@ -0,0 +1,11 @@ +import logging + +from core.model_runtime.model_providers.__base.model_provider import ModelProvider + +logger = logging.getLogger(__name__) + + +class NVIDIANIMProvider(ModelProvider): + + def validate_provider_credentials(self, credentials: dict) -> None: + pass diff --git a/api/core/model_runtime/model_providers/nvidia_nim/nvidia_nim.yaml b/api/core/model_runtime/model_providers/nvidia_nim/nvidia_nim.yaml new file mode 100644 index 0000000000..0e892665d7 --- /dev/null +++ b/api/core/model_runtime/model_providers/nvidia_nim/nvidia_nim.yaml @@ -0,0 +1,79 @@ +provider: nvidia_nim +label: + en_US: NVIDIA NIM +description: + en_US: NVIDIA NIM, a set of easy-to-use inference microservices. + zh_Hans: NVIDIA NIM,一组易于使用的模型推理微服务。 +icon_small: + en_US: icon_s_en.svg +icon_large: + en_US: icon_l_en.png +background: "#EFFDFD" +help: + title: + en_US: Learn more about NVIDIA NIM + zh_Hans: 了解 NVIDIA NIM 更多信息 + url: + en_US: https://www.nvidia.com/en-us/ai/ +supported_model_types: + - llm +configurate_methods: + - customizable-model +model_credential_schema: + model: + label: + en_US: Model Name + zh_Hans: 模型名称 + placeholder: + en_US: Enter full model name + zh_Hans: 输入模型全称 + credential_form_schemas: + - variable: endpoint_url + label: + zh_Hans: API endpoint URL + en_US: API endpoint URL + type: text-input + required: true + placeholder: + zh_Hans: Base URL, e.g. http://192.168.1.100:8000/v1 + en_US: Base URL, e.g. 
http://192.168.1.100:8000/v1 + - variable: mode + show_on: + - variable: __model_type + value: llm + label: + en_US: Completion mode + type: select + required: false + default: chat + placeholder: + zh_Hans: 选择对话类型 + en_US: Select completion mode + options: + - value: completion + label: + en_US: Completion + zh_Hans: 补全 + - value: chat + label: + en_US: Chat + zh_Hans: 对话 + - variable: context_size + label: + zh_Hans: 模型上下文长度 + en_US: Model context size + required: true + type: text-input + default: '4096' + placeholder: + zh_Hans: 在此输入您的模型上下文长度 + en_US: Enter your Model context size + - variable: max_tokens_to_sample + label: + zh_Hans: 最大 token 上限 + en_US: Upper bound for max tokens + show_on: + - variable: __model_type + value: llm + default: '4096' + type: text-input diff --git a/api/core/model_runtime/model_providers/ollama/llm/llm.py b/api/core/model_runtime/model_providers/ollama/llm/llm.py index fcb94084a5..dd58a563ab 100644 --- a/api/core/model_runtime/model_providers/ollama/llm/llm.py +++ b/api/core/model_runtime/model_providers/ollama/llm/llm.py @@ -8,7 +8,12 @@ from urllib.parse import urljoin import requests -from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.llm_entities import ( + LLMMode, + LLMResult, + LLMResultChunk, + LLMResultChunkDelta, +) from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, ImagePromptMessageContent, @@ -40,7 +45,9 @@ from core.model_runtime.errors.invoke import ( InvokeServerUnavailableError, ) from core.model_runtime.errors.validate import CredentialsValidateFailedError -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.model_runtime.model_providers.__base.large_language_model import ( + LargeLanguageModel, +) logger = logging.getLogger(__name__) @@ -50,11 +57,17 @@ class OllamaLargeLanguageModel(LargeLanguageModel): Model class for Ollama large language model. 
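Editor's aside: because the NVIDIA NIM provider above simply reuses the OpenAI-compatible implementation, a configured endpoint can be smoke-tested with a single HTTP request. The URL and model name below are placeholders echoing the yaml example, not values from this patch:

# Hedged smoke test against an OpenAI-compatible NIM endpoint (placeholders).
import requests

resp = requests.post(
    'http://192.168.1.100:8000/v1/chat/completions',
    json={
        'model': 'meta/llama3-8b-instruct',
        'messages': [{'role': 'user', 'content': 'ping'}],
        'max_tokens': 5,
    },
    timeout=(10, 60),
)
resp.raise_for_status()
print(resp.json()['choices'][0]['message']['content'])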
""" - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -75,11 +88,16 @@ class OllamaLargeLanguageModel(LargeLanguageModel): model_parameters=model_parameters, stop=stop, stream=stream, - user=user + user=user, ) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -100,10 +118,12 @@ class OllamaLargeLanguageModel(LargeLanguageModel): if isinstance(first_prompt_message.content, str): text = first_prompt_message.content else: - text = '' + text = "" for message_content in first_prompt_message.content: if message_content.type == PromptMessageContentType.TEXT: - message_content = cast(TextPromptMessageContent, message_content) + message_content = cast( + TextPromptMessageContent, message_content + ) text = message_content.data break return self._get_num_tokens_by_gpt2(text) @@ -121,19 +141,28 @@ class OllamaLargeLanguageModel(LargeLanguageModel): model=model, credentials=credentials, prompt_messages=[UserPromptMessage(content="ping")], - model_parameters={ - 'num_predict': 5 - }, - stream=False + model_parameters={"num_predict": 5}, + stream=False, ) except InvokeError as ex: - raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {ex.description}') + raise CredentialsValidateFailedError( + f"An error occurred during credentials validation: {ex.description}" + ) except Exception as ex: - raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {str(ex)}') + raise CredentialsValidateFailedError( + f"An error occurred during credentials validation: {str(ex)}" + ) - def _generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke llm completion model @@ -146,76 +175,93 @@ class OllamaLargeLanguageModel(LargeLanguageModel): :param user: unique user id :return: full response or stream response chunk generator result """ - headers = { - 'Content-Type': 'application/json' - } + headers = {"Content-Type": "application/json"} - endpoint_url = credentials['base_url'] - if not endpoint_url.endswith('/'): - endpoint_url += '/' + endpoint_url = credentials["base_url"] + if not endpoint_url.endswith("/"): + endpoint_url += "/" # prepare the payload for a simple ping to the model - data = { - 'model': model, - 'stream': stream - } + data = 
{"model": model, "stream": stream} - if 'format' in model_parameters: - data['format'] = model_parameters['format'] - del model_parameters['format'] + if "format" in model_parameters: + data["format"] = model_parameters["format"] + del model_parameters["format"] - data['options'] = model_parameters or {} + if "keep_alive" in model_parameters: + data["keep_alive"] = model_parameters["keep_alive"] + del model_parameters["keep_alive"] + + data["options"] = model_parameters or {} if stop: - data['stop'] = "\n".join(stop) + data["stop"] = "\n".join(stop) - completion_type = LLMMode.value_of(credentials['mode']) + completion_type = LLMMode.value_of(credentials["mode"]) if completion_type is LLMMode.CHAT: - endpoint_url = urljoin(endpoint_url, 'api/chat') - data['messages'] = [self._convert_prompt_message_to_dict(m) for m in prompt_messages] + endpoint_url = urljoin(endpoint_url, "api/chat") + data["messages"] = [ + self._convert_prompt_message_to_dict(m) for m in prompt_messages + ] else: - endpoint_url = urljoin(endpoint_url, 'api/generate') + endpoint_url = urljoin(endpoint_url, "api/generate") first_prompt_message = prompt_messages[0] if isinstance(first_prompt_message, UserPromptMessage): first_prompt_message = cast(UserPromptMessage, first_prompt_message) if isinstance(first_prompt_message.content, str): - data['prompt'] = first_prompt_message.content + data["prompt"] = first_prompt_message.content else: - text = '' + text = "" images = [] for message_content in first_prompt_message.content: if message_content.type == PromptMessageContentType.TEXT: - message_content = cast(TextPromptMessageContent, message_content) + message_content = cast( + TextPromptMessageContent, message_content + ) text = message_content.data elif message_content.type == PromptMessageContentType.IMAGE: - message_content = cast(ImagePromptMessageContent, message_content) - image_data = re.sub(r'^data:image\/[a-zA-Z]+;base64,', '', message_content.data) + message_content = cast( + ImagePromptMessageContent, message_content + ) + image_data = re.sub( + r"^data:image\/[a-zA-Z]+;base64,", + "", + message_content.data, + ) images.append(image_data) - data['prompt'] = text - data['images'] = images + data["prompt"] = text + data["images"] = images # send a post request to validate the credentials response = requests.post( - endpoint_url, - headers=headers, - json=data, - timeout=(10, 300), - stream=stream + endpoint_url, headers=headers, json=data, timeout=(10, 300), stream=stream ) response.encoding = "utf-8" if response.status_code != 200: - raise InvokeError(f"API request failed with status code {response.status_code}: {response.text}") + raise InvokeError( + f"API request failed with status code {response.status_code}: {response.text}" + ) if stream: - return self._handle_generate_stream_response(model, credentials, completion_type, response, prompt_messages) + return self._handle_generate_stream_response( + model, credentials, completion_type, response, prompt_messages + ) - return self._handle_generate_response(model, credentials, completion_type, response, prompt_messages) + return self._handle_generate_response( + model, credentials, completion_type, response, prompt_messages + ) - def _handle_generate_response(self, model: str, credentials: dict, completion_type: LLMMode, - response: requests.Response, prompt_messages: list[PromptMessage]) -> LLMResult: + def _handle_generate_response( + self, + model: str, + credentials: dict, + completion_type: LLMMode, + response: requests.Response, + prompt_messages: 
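Editor's aside (not part of the patch): the payload assembly above pulls the transport-level keys "format" and "keep_alive" out of model_parameters before the remainder becomes Ollama's "options" object; keep_alive takes duration strings such as '10m' or '24h'. The same split as a self-contained sketch, with build_ollama_payload as a hypothetical helper:

# Sketch of the format/keep_alive vs. options split implemented above.
def build_ollama_payload(model: str, model_parameters: dict, stream: bool) -> dict:
    params = dict(model_parameters)  # avoid mutating the caller's dict
    data = {"model": model, "stream": stream}
    for top_level_key in ("format", "keep_alive"):
        if top_level_key in params:
            data[top_level_key] = params.pop(top_level_key)
    data["options"] = params  # everything else is an Ollama sampling option
    return data

print(build_ollama_payload("llama3", {"temperature": 0.8, "keep_alive": "10m"}, True))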
list[PromptMessage], + ) -> LLMResult: """ Handle llm completion response @@ -229,14 +275,14 @@ class OllamaLargeLanguageModel(LargeLanguageModel): response_json = response.json() if completion_type is LLMMode.CHAT: - message = response_json.get('message', {}) - response_content = message.get('content', '') + message = response_json.get("message", {}) + response_content = message.get("content", "") else: - response_content = response_json['response'] + response_content = response_json["response"] assistant_message = AssistantPromptMessage(content=response_content) - if 'prompt_eval_count' in response_json and 'eval_count' in response_json: + if "prompt_eval_count" in response_json and "eval_count" in response_json: # transform usage prompt_tokens = response_json["prompt_eval_count"] completion_tokens = response_json["eval_count"] @@ -246,7 +292,9 @@ class OllamaLargeLanguageModel(LargeLanguageModel): completion_tokens = self._get_num_tokens_by_gpt2(assistant_message.content) # transform usage - usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) + usage = self._calc_response_usage( + model, credentials, prompt_tokens, completion_tokens + ) # transform response result = LLMResult( @@ -258,8 +306,14 @@ class OllamaLargeLanguageModel(LargeLanguageModel): return result - def _handle_generate_stream_response(self, model: str, credentials: dict, completion_type: LLMMode, - response: requests.Response, prompt_messages: list[PromptMessage]) -> Generator: + def _handle_generate_stream_response( + self, + model: str, + credentials: dict, + completion_type: LLMMode, + response: requests.Response, + prompt_messages: list[PromptMessage], + ) -> Generator: """ Handle llm completion stream response @@ -270,17 +324,20 @@ class OllamaLargeLanguageModel(LargeLanguageModel): :param prompt_messages: prompt messages :return: llm response chunk generator result """ - full_text = '' + full_text = "" chunk_index = 0 - def create_final_llm_result_chunk(index: int, message: AssistantPromptMessage, finish_reason: str) \ - -> LLMResultChunk: + def create_final_llm_result_chunk( + index: int, message: AssistantPromptMessage, finish_reason: str + ) -> LLMResultChunk: # calculate num tokens prompt_tokens = self._get_num_tokens_by_gpt2(prompt_messages[0].content) completion_tokens = self._get_num_tokens_by_gpt2(full_text) # transform usage - usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) + usage = self._calc_response_usage( + model, credentials, prompt_tokens, completion_tokens + ) return LLMResultChunk( model=model, @@ -289,11 +346,11 @@ class OllamaLargeLanguageModel(LargeLanguageModel): index=index, message=message, finish_reason=finish_reason, - usage=usage - ) + usage=usage, + ), ) - for chunk in response.iter_lines(decode_unicode=True, delimiter='\n'): + for chunk in response.iter_lines(decode_unicode=True, delimiter="\n"): if not chunk: continue @@ -304,7 +361,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel): yield create_final_llm_result_chunk( index=chunk_index, message=AssistantPromptMessage(content=""), - finish_reason="Non-JSON encountered." 
+ finish_reason="Non-JSON encountered.", ) chunk_index += 1 @@ -314,55 +371,57 @@ class OllamaLargeLanguageModel(LargeLanguageModel): if not chunk_json: continue - if 'message' not in chunk_json: - text = '' + if "message" not in chunk_json: + text = "" else: - text = chunk_json.get('message').get('content', '') + text = chunk_json.get("message").get("content", "") else: if not chunk_json: continue # transform assistant message to prompt message - text = chunk_json['response'] + text = chunk_json["response"] - assistant_prompt_message = AssistantPromptMessage( - content=text - ) + assistant_prompt_message = AssistantPromptMessage(content=text) full_text += text - if chunk_json['done']: + if chunk_json["done"]: # calculate num tokens - if 'prompt_eval_count' in chunk_json and 'eval_count' in chunk_json: + if "prompt_eval_count" in chunk_json and "eval_count" in chunk_json: # transform usage prompt_tokens = chunk_json["prompt_eval_count"] completion_tokens = chunk_json["eval_count"] else: # calculate num tokens - prompt_tokens = self._get_num_tokens_by_gpt2(prompt_messages[0].content) + prompt_tokens = self._get_num_tokens_by_gpt2( + prompt_messages[0].content + ) completion_tokens = self._get_num_tokens_by_gpt2(full_text) # transform usage - usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) + usage = self._calc_response_usage( + model, credentials, prompt_tokens, completion_tokens + ) yield LLMResultChunk( - model=chunk_json['model'], + model=chunk_json["model"], prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=chunk_index, message=assistant_prompt_message, - finish_reason='stop', - usage=usage - ) + finish_reason="stop", + usage=usage, + ), ) else: yield LLMResultChunk( - model=chunk_json['model'], + model=chunk_json["model"], prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=chunk_index, message=assistant_prompt_message, - ) + ), ) chunk_index += 1 @@ -376,15 +435,21 @@ class OllamaLargeLanguageModel(LargeLanguageModel): if isinstance(message.content, str): message_dict = {"role": "user", "content": message.content} else: - text = '' + text = "" images = [] for message_content in message.content: if message_content.type == PromptMessageContentType.TEXT: - message_content = cast(TextPromptMessageContent, message_content) + message_content = cast( + TextPromptMessageContent, message_content + ) text = message_content.data elif message_content.type == PromptMessageContentType.IMAGE: - message_content = cast(ImagePromptMessageContent, message_content) - image_data = re.sub(r'^data:image\/[a-zA-Z]+;base64,', '', message_content.data) + message_content = cast( + ImagePromptMessageContent, message_content + ) + image_data = re.sub( + r"^data:image\/[a-zA-Z]+;base64,", "", message_content.data + ) images.append(image_data) message_dict = {"role": "user", "content": text, "images": images} @@ -414,7 +479,9 @@ class OllamaLargeLanguageModel(LargeLanguageModel): return num_tokens - def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: + def get_customizable_model_schema( + self, model: str, credentials: dict + ) -> AIModelEntity: """ Get customizable model schema. 
@@ -425,20 +492,19 @@ class OllamaLargeLanguageModel(LargeLanguageModel): """ extras = {} - if 'vision_support' in credentials and credentials['vision_support'] == 'true': - extras['features'] = [ModelFeature.VISION] + if "vision_support" in credentials and credentials["vision_support"] == "true": + extras["features"] = [ModelFeature.VISION] entity = AIModelEntity( model=model, - label=I18nObject( - zh_Hans=model, - en_US=model - ), + label=I18nObject(zh_Hans=model, en_US=model), model_type=ModelType.LLM, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties={ - ModelPropertyKey.MODE: credentials.get('mode'), - ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', 4096)), + ModelPropertyKey.MODE: credentials.get("mode"), + ModelPropertyKey.CONTEXT_SIZE: int( + credentials.get("context_size", 4096) + ), }, parameter_rules=[ ParameterRule( @@ -446,152 +512,195 @@ class OllamaLargeLanguageModel(LargeLanguageModel): use_template=DefaultParameterName.TEMPERATURE.value, label=I18nObject(en_US="Temperature"), type=ParameterType.FLOAT, - help=I18nObject(en_US="The temperature of the model. " - "Increasing the temperature will make the model answer " - "more creatively. (Default: 0.8)"), + help=I18nObject( + en_US="The temperature of the model. " + "Increasing the temperature will make the model answer " + "more creatively. (Default: 0.8)" + ), default=0.1, min=0, - max=2 + max=1, ), ParameterRule( name=DefaultParameterName.TOP_P.value, use_template=DefaultParameterName.TOP_P.value, label=I18nObject(en_US="Top P"), type=ParameterType.FLOAT, - help=I18nObject(en_US="Works together with top-k. A higher value (e.g., 0.95) will lead to " - "more diverse text, while a lower value (e.g., 0.5) will generate more " - "focused and conservative text. (Default: 0.9)"), + help=I18nObject( + en_US="Works together with top-k. A higher value (e.g., 0.95) will lead to " + "more diverse text, while a lower value (e.g., 0.5) will generate more " + "focused and conservative text. (Default: 0.9)" + ), default=0.9, min=0, - max=1 + max=1, ), ParameterRule( name="top_k", label=I18nObject(en_US="Top K"), type=ParameterType.INT, - help=I18nObject(en_US="Reduces the probability of generating nonsense. " - "A higher value (e.g. 100) will give more diverse answers, " - "while a lower value (e.g. 10) will be more conservative. (Default: 40)"), + help=I18nObject( + en_US="Reduces the probability of generating nonsense. " + "A higher value (e.g. 100) will give more diverse answers, " + "while a lower value (e.g. 10) will be more conservative. (Default: 40)" + ), min=1, - max=100 + max=100, ), ParameterRule( - name='repeat_penalty', + name="repeat_penalty", label=I18nObject(en_US="Repeat Penalty"), type=ParameterType.FLOAT, - help=I18nObject(en_US="Sets how strongly to penalize repetitions. " - "A higher value (e.g., 1.5) will penalize repetitions more strongly, " - "while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1)"), + help=I18nObject( + en_US="Sets how strongly to penalize repetitions. " + "A higher value (e.g., 1.5) will penalize repetitions more strongly, " + "while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1)" + ), min=-2, - max=2 + max=2, ), ParameterRule( - name='num_predict', - use_template='max_tokens', + name="num_predict", + use_template="max_tokens", label=I18nObject(en_US="Num Predict"), type=ParameterType.INT, - help=I18nObject(en_US="Maximum number of tokens to predict when generating text. 
" - "(Default: 128, -1 = infinite generation, -2 = fill context)"), - default=512 if int(credentials.get('max_tokens', 4096)) >= 768 else 128, + help=I18nObject( + en_US="Maximum number of tokens to predict when generating text. " + "(Default: 128, -1 = infinite generation, -2 = fill context)" + ), + default=( + 512 if int(credentials.get("max_tokens", 4096)) >= 768 else 128 + ), min=-2, - max=int(credentials.get('max_tokens', 4096)), + max=int(credentials.get("max_tokens", 4096)), ), ParameterRule( - name='mirostat', + name="mirostat", label=I18nObject(en_US="Mirostat sampling"), type=ParameterType.INT, - help=I18nObject(en_US="Enable Mirostat sampling for controlling perplexity. " - "(default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)"), + help=I18nObject( + en_US="Enable Mirostat sampling for controlling perplexity. " + "(default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)" + ), min=0, - max=2 + max=2, ), ParameterRule( - name='mirostat_eta', + name="mirostat_eta", label=I18nObject(en_US="Mirostat Eta"), type=ParameterType.FLOAT, - help=I18nObject(en_US="Influences how quickly the algorithm responds to feedback from " - "the generated text. A lower learning rate will result in slower adjustments, " - "while a higher learning rate will make the algorithm more responsive. " - "(Default: 0.1)"), - precision=1 + help=I18nObject( + en_US="Influences how quickly the algorithm responds to feedback from " + "the generated text. A lower learning rate will result in slower adjustments, " + "while a higher learning rate will make the algorithm more responsive. " + "(Default: 0.1)" + ), + precision=1, ), ParameterRule( - name='mirostat_tau', + name="mirostat_tau", label=I18nObject(en_US="Mirostat Tau"), type=ParameterType.FLOAT, - help=I18nObject(en_US="Controls the balance between coherence and diversity of the output. " - "A lower value will result in more focused and coherent text. (Default: 5.0)"), - precision=1 + help=I18nObject( + en_US="Controls the balance between coherence and diversity of the output. " + "A lower value will result in more focused and coherent text. (Default: 5.0)" + ), + precision=1, ), ParameterRule( - name='num_ctx', + name="num_ctx", label=I18nObject(en_US="Size of context window"), type=ParameterType.INT, - help=I18nObject(en_US="Sets the size of the context window used to generate the next token. " - "(Default: 2048)"), + help=I18nObject( + en_US="Sets the size of the context window used to generate the next token. " + "(Default: 2048)" + ), default=2048, - min=1 + min=1, ), ParameterRule( name='num_gpu', - label=I18nObject(en_US="Num GPU"), + label=I18nObject(en_US="GPU Layers"), type=ParameterType.INT, - help=I18nObject(en_US="The number of layers to send to the GPU(s). " - "On macOS it defaults to 1 to enable metal support, 0 to disable."), - min=0, - max=1 + help=I18nObject(en_US="The number of layers to offload to the GPU(s). " + "On macOS it defaults to 1 to enable metal support, 0 to disable." + "As long as a model fits into one gpu it stays in one. " + "It does not set the number of GPU(s). "), + min=-1, + default=1 ), ParameterRule( - name='num_thread', + name="num_thread", label=I18nObject(en_US="Num Thread"), type=ParameterType.INT, - help=I18nObject(en_US="Sets the number of threads to use during computation. " - "By default, Ollama will detect this for optimal performance. 
" - "It is recommended to set this value to the number of physical CPU cores " - "your system has (as opposed to the logical number of cores)."), + help=I18nObject( + en_US="Sets the number of threads to use during computation. " + "By default, Ollama will detect this for optimal performance. " + "It is recommended to set this value to the number of physical CPU cores " + "your system has (as opposed to the logical number of cores)." + ), min=1, ), ParameterRule( - name='repeat_last_n', + name="repeat_last_n", label=I18nObject(en_US="Repeat last N"), type=ParameterType.INT, - help=I18nObject(en_US="Sets how far back for the model to look back to prevent repetition. " - "(Default: 64, 0 = disabled, -1 = num_ctx)"), - min=-1 + help=I18nObject( + en_US="Sets how far back for the model to look back to prevent repetition. " + "(Default: 64, 0 = disabled, -1 = num_ctx)" + ), + min=-1, ), ParameterRule( - name='tfs_z', + name="tfs_z", label=I18nObject(en_US="TFS Z"), type=ParameterType.FLOAT, - help=I18nObject(en_US="Tail free sampling is used to reduce the impact of less probable tokens " - "from the output. A higher value (e.g., 2.0) will reduce the impact more, " - "while a value of 1.0 disables this setting. (default: 1)"), - precision=1 + help=I18nObject( + en_US="Tail free sampling is used to reduce the impact of less probable tokens " + "from the output. A higher value (e.g., 2.0) will reduce the impact more, " + "while a value of 1.0 disables this setting. (default: 1)" + ), + precision=1, ), ParameterRule( - name='seed', + name="seed", label=I18nObject(en_US="Seed"), type=ParameterType.INT, - help=I18nObject(en_US="Sets the random number seed to use for generation. Setting this to " - "a specific number will make the model generate the same text for " - "the same prompt. (Default: 0)"), + help=I18nObject( + en_US="Sets the random number seed to use for generation. Setting this to " + "a specific number will make the model generate the same text for " + "the same prompt. (Default: 0)" + ), ), ParameterRule( - name='format', + name="keep_alive", + label=I18nObject(en_US="Keep Alive"), + type=ParameterType.STRING, + help=I18nObject( + en_US="Sets how long the model is kept in memory after generating a response. " + "This must be a duration string with a unit (e.g., '10m' for 10 minutes or '24h' for 24 hours). " + "A negative number keeps the model loaded indefinitely, and '0' unloads the model immediately after generating a response. " + "Valid time units are 's','m','h'. (Default: 5m)" + ), + ), + ParameterRule( + name="format", label=I18nObject(en_US="Format"), type=ParameterType.STRING, - help=I18nObject(en_US="the format to return a response in." - " Currently the only accepted value is json."), - options=['json'], - ) + help=I18nObject( + en_US="the format to return a response in." + " Currently the only accepted value is json." 
+ ), + options=["json"], + ), ], pricing=PriceConfig( - input=Decimal(credentials.get('input_price', 0)), - output=Decimal(credentials.get('output_price', 0)), - unit=Decimal(credentials.get('unit', 0)), - currency=credentials.get('currency', "USD") + input=Decimal(credentials.get("input_price", 0)), + output=Decimal(credentials.get("output_price", 0)), + unit=Decimal(credentials.get("unit", 0)), + currency=credentials.get("currency", "USD"), ), - **extras + **extras, ) return entity @@ -619,10 +728,10 @@ class OllamaLargeLanguageModel(LargeLanguageModel): ], InvokeServerUnavailableError: [ requests.exceptions.ConnectionError, # Engine Overloaded - requests.exceptions.HTTPError # Server Error + requests.exceptions.HTTPError, # Server Error ], InvokeConnectionError: [ requests.exceptions.ConnectTimeout, # Timeout - requests.exceptions.ReadTimeout # Timeout - ] + requests.exceptions.ReadTimeout, # Timeout + ], } diff --git a/api/core/model_runtime/model_providers/openai/_common.py b/api/core/model_runtime/model_providers/openai/_common.py index 436461c11e..5772f325e1 100644 --- a/api/core/model_runtime/model_providers/openai/_common.py +++ b/api/core/model_runtime/model_providers/openai/_common.py @@ -25,7 +25,7 @@ class _CommonOpenAI: "max_retries": 1, } - if 'openai_api_base' in credentials and credentials['openai_api_base']: + if credentials.get('openai_api_base'): credentials['openai_api_base'] = credentials['openai_api_base'].rstrip('/') credentials_kwargs['base_url'] = credentials['openai_api_base'] + '/v1' diff --git a/api/core/model_runtime/model_providers/openai/llm/_position.yaml b/api/core/model_runtime/model_providers/openai/llm/_position.yaml index 3808d670c3..566055e3f7 100644 --- a/api/core/model_runtime/model_providers/openai/llm/_position.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/_position.yaml @@ -1,4 +1,6 @@ - gpt-4 +- gpt-4o +- gpt-4o-2024-05-13 - gpt-4-turbo - gpt-4-turbo-2024-04-09 - gpt-4-turbo-preview diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-4o-2024-05-13.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-4o-2024-05-13.yaml new file mode 100644 index 0000000000..f0d835cba2 --- /dev/null +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-4o-2024-05-13.yaml @@ -0,0 +1,44 @@ +model: gpt-4o-2024-05-13 +label: + zh_Hans: gpt-4o-2024-05-13 + en_US: gpt-4o-2024-05-13 +model_type: llm +features: + - multi-tool-call + - agent-thought + - stream-tool-call + - vision +model_properties: + mode: chat + context_size: 128000 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: presence_penalty + use_template: presence_penalty + - name: frequency_penalty + use_template: frequency_penalty + - name: max_tokens + use_template: max_tokens + default: 512 + min: 1 + max: 4096 + - name: response_format + label: + zh_Hans: 回复格式 + en_US: response_format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object +pricing: + input: '5.00' + output: '15.00' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-4o.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-4o.yaml new file mode 100644 index 0000000000..4f141f772f --- /dev/null +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-4o.yaml @@ -0,0 +1,44 @@ +model: gpt-4o +label: + zh_Hans: gpt-4o + en_US: gpt-4o +model_type: llm +features: + - 
multi-tool-call + - agent-thought + - stream-tool-call + - vision +model_properties: + mode: chat + context_size: 128000 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: presence_penalty + use_template: presence_penalty + - name: frequency_penalty + use_template: frequency_penalty + - name: max_tokens + use_template: max_tokens + default: 512 + min: 1 + max: 4096 + - name: response_format + label: + zh_Hans: 回复格式 + en_US: response_format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object +pricing: + input: '5.00' + output: '15.00' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/openai/llm/llm.py b/api/core/model_runtime/model_providers/openai/llm/llm.py index b7db39376c..69afabadb3 100644 --- a/api/core/model_runtime/model_providers/openai/llm/llm.py +++ b/api/core/model_runtime/model_providers/openai/llm/llm.py @@ -378,6 +378,11 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): if user: extra_model_kwargs['user'] = user + if stream: + extra_model_kwargs['stream_options'] = { + "include_usage": True + } + # text completion model response = client.completions.create( prompt=prompt_messages[0].content, @@ -446,8 +451,24 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): :return: llm response chunk generator result """ full_text = '' + prompt_tokens = 0 + completion_tokens = 0 + + final_chunk = LLMResultChunk( + model=model, + prompt_messages=prompt_messages, + delta=LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage(content=''), + ) + ) + for chunk in response: if len(chunk.choices) == 0: + if chunk.usage: + # calculate num tokens + prompt_tokens = chunk.usage.prompt_tokens + completion_tokens = chunk.usage.completion_tokens continue delta = chunk.choices[0] @@ -464,20 +485,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): full_text += text if delta.finish_reason is not None: - # calculate num tokens - if chunk.usage: - # transform usage - prompt_tokens = chunk.usage.prompt_tokens - completion_tokens = chunk.usage.completion_tokens - else: - # calculate num tokens - prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content) - completion_tokens = self._num_tokens_from_string(model, full_text) - - # transform usage - usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) - - yield LLMResultChunk( + final_chunk = LLMResultChunk( model=chunk.model, prompt_messages=prompt_messages, system_fingerprint=chunk.system_fingerprint, @@ -485,7 +493,6 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): index=delta.index, message=assistant_prompt_message, finish_reason=delta.finish_reason, - usage=usage ) ) else: @@ -499,6 +506,19 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): ) ) + if not prompt_tokens: + prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content) + + if not completion_tokens: + completion_tokens = self._num_tokens_from_string(model, full_text) + + # transform usage + usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) + + final_chunk.delta.usage = usage + + yield final_chunk + def _chat_generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, stop: 
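Editor's aside: the stream_options={'include_usage': True} change above relies on OpenAI's streaming behavior where the final event arrives with an empty choices list and a populated usage field, so token counts no longer have to be re-estimated locally. A minimal consumer sketch, assuming OPENAI_API_KEY is set in the environment:

# Hedged sketch of consuming a usage-reporting OpenAI stream.
from openai import OpenAI

client = OpenAI()  # assumes OPENAI_API_KEY is set
stream = client.chat.completions.create(
    model='gpt-4o',
    messages=[{'role': 'user', 'content': 'ping'}],
    stream=True,
    stream_options={'include_usage': True},
)
usage = None
for chunk in stream:
    if chunk.choices:
        print(chunk.choices[0].delta.content or '', end='')
    elif chunk.usage:
        usage = chunk.usage  # only present on the final chunk
print('\n', usage)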
Optional[list[str]] = None, @@ -531,6 +551,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): model_parameters["response_format"] = response_format + extra_model_kwargs = {} if tools: @@ -547,6 +568,11 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): if user: extra_model_kwargs['user'] = user + if stream: + extra_model_kwargs['stream_options'] = { + 'include_usage': True + } + # clear illegal prompt messages prompt_messages = self._clear_illegal_prompt_messages(model, prompt_messages) @@ -630,8 +656,24 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): """ full_assistant_content = '' delta_assistant_message_function_call_storage: ChoiceDeltaFunctionCall = None + prompt_tokens = 0 + completion_tokens = 0 + final_tool_calls = [] + final_chunk = LLMResultChunk( + model=model, + prompt_messages=prompt_messages, + delta=LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage(content=''), + ) + ) + for chunk in response: if len(chunk.choices) == 0: + if chunk.usage: + # calculate num tokens + prompt_tokens = chunk.usage.prompt_tokens + completion_tokens = chunk.usage.completion_tokens continue delta = chunk.choices[0] @@ -667,6 +709,8 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): # tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls) function_call = self._extract_response_function_call(assistant_message_function_call) tool_calls = [function_call] if function_call else [] + if tool_calls: + final_tool_calls.extend(tool_calls) # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( @@ -677,19 +721,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): full_assistant_content += delta.delta.content if delta.delta.content else '' if has_finish_reason: - # calculate num tokens - prompt_tokens = self._num_tokens_from_messages(model, prompt_messages, tools) - - full_assistant_prompt_message = AssistantPromptMessage( - content=full_assistant_content, - tool_calls=tool_calls - ) - completion_tokens = self._num_tokens_from_messages(model, [full_assistant_prompt_message]) - - # transform usage - usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) - - yield LLMResultChunk( + final_chunk = LLMResultChunk( model=chunk.model, prompt_messages=prompt_messages, system_fingerprint=chunk.system_fingerprint, @@ -697,7 +729,6 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): index=delta.index, message=assistant_prompt_message, finish_reason=delta.finish_reason, - usage=usage ) ) else: @@ -711,6 +742,22 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): ) ) + if not prompt_tokens: + prompt_tokens = self._num_tokens_from_messages(model, prompt_messages, tools) + + if not completion_tokens: + full_assistant_prompt_message = AssistantPromptMessage( + content=full_assistant_content, + tool_calls=final_tool_calls + ) + completion_tokens = self._num_tokens_from_messages(model, [full_assistant_prompt_message]) + + # transform usage + usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) + final_chunk.delta.usage = usage + + yield final_chunk + def _extract_response_tool_calls(self, response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]) \ -> list[AssistantPromptMessage.ToolCall]: diff --git a/api/core/model_runtime/model_providers/openai/openai.yaml b/api/core/model_runtime/model_providers/openai/openai.yaml index 
3af99e107e..b4dc8fd4f2 100644 --- a/api/core/model_runtime/model_providers/openai/openai.yaml +++ b/api/core/model_runtime/model_providers/openai/openai.yaml @@ -85,5 +85,5 @@ provider_credential_schema: type: text-input required: false placeholder: - zh_Hans: 在此输入您的 API Base - en_US: Enter your API Base + zh_Hans: 在此输入您的 API Base, 如:https://api.openai.com + en_US: Enter your API Base, e.g. https://api.openai.com diff --git a/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py b/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py index 43258d1e5e..1c3f084207 100644 --- a/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py +++ b/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py @@ -180,7 +180,7 @@ class OpenLLMGenerate: completion_usage += len(token_ids) message = OpenLLMGenerateMessage(content=text, role=OpenLLMGenerateMessage.Role.ASSISTANT.value) - if 'finish_reason' in choice and choice['finish_reason']: + if choice.get('finish_reason'): finish_reason = choice['finish_reason'] prompt_token_ids = data['prompt_token_ids'] message.stop_reason = finish_reason diff --git a/api/core/model_runtime/model_providers/vertex_ai/__init__.py b/api/core/model_runtime/model_providers/vertex_ai/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/vertex_ai/_assets/icon_l_en.png b/api/core/model_runtime/model_providers/vertex_ai/_assets/icon_l_en.png new file mode 100644 index 0000000000..9f8f05231a Binary files /dev/null and b/api/core/model_runtime/model_providers/vertex_ai/_assets/icon_l_en.png differ diff --git a/api/core/model_runtime/model_providers/vertex_ai/_assets/icon_s_en.svg b/api/core/model_runtime/model_providers/vertex_ai/_assets/icon_s_en.svg new file mode 100644 index 0000000000..efc3589c07 --- /dev/null +++ b/api/core/model_runtime/model_providers/vertex_ai/_assets/icon_s_en.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/api/core/model_runtime/model_providers/vertex_ai/_common.py b/api/core/model_runtime/model_providers/vertex_ai/_common.py new file mode 100644 index 0000000000..8f7c859e38 --- /dev/null +++ b/api/core/model_runtime/model_providers/vertex_ai/_common.py @@ -0,0 +1,15 @@ +from core.model_runtime.errors.invoke import InvokeError + + +class _CommonVertexAi: + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + """ + Map model invoke error to unified error + The key is the error type thrown to the caller + The value is the error type thrown by the model, + which needs to be converted into a unified error type for the caller. 
+ + :return: Invoke error mapping + """ + pass diff --git a/api/core/model_runtime/model_providers/vertex_ai/llm/__init__.py b/api/core/model_runtime/model_providers/vertex_ai/llm/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/vertex_ai/llm/anthropic.claude-3-haiku.yaml b/api/core/model_runtime/model_providers/vertex_ai/llm/anthropic.claude-3-haiku.yaml new file mode 100644 index 0000000000..5613348695 --- /dev/null +++ b/api/core/model_runtime/model_providers/vertex_ai/llm/anthropic.claude-3-haiku.yaml @@ -0,0 +1,56 @@ +model: claude-3-haiku@20240307 +label: + en_US: Claude 3 Haiku +model_type: llm +features: + - agent-thought + - vision +model_properties: + mode: chat + context_size: 200000 +parameter_rules: + - name: max_tokens + use_template: max_tokens + required: true + type: int + default: 4096 + min: 1 + max: 4096 + help: + zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。 + en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter. + # docs: https://docs.anthropic.com/claude/docs/system-prompts + - name: temperature + use_template: temperature + required: false + type: float + default: 1 + min: 0.0 + max: 1.0 + help: + zh_Hans: 生成内容的随机性。 + en_US: The amount of randomness injected into the response. + - name: top_p + required: false + type: float + default: 0.999 + min: 0.000 + max: 1.000 + help: + zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。 + en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both. + - name: top_k + required: false + type: int + default: 0 + min: 0 + # tip docs from aws has error, max value is 500 + max: 500 + help: + zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。 + en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses. +pricing: + input: '0.00025' + output: '0.00125' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/vertex_ai/llm/anthropic.claude-3-opus.yaml b/api/core/model_runtime/model_providers/vertex_ai/llm/anthropic.claude-3-opus.yaml new file mode 100644 index 0000000000..ab084636b5 --- /dev/null +++ b/api/core/model_runtime/model_providers/vertex_ai/llm/anthropic.claude-3-opus.yaml @@ -0,0 +1,56 @@ +model: claude-3-opus@20240229 +label: + en_US: Claude 3 Opus +model_type: llm +features: + - agent-thought + - vision +model_properties: + mode: chat + context_size: 200000 +parameter_rules: + - name: max_tokens + use_template: max_tokens + required: true + type: int + default: 4096 + min: 1 + max: 4096 + help: + zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。 + en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter. 
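Editor's aside: the claude-3-* entries declared in these yaml files are served through the AnthropicVertex client imported by the new llm.py further below. A minimal hedged call; project and region are placeholders:

# Hedged sketch of invoking Claude on Vertex AI (placeholder project/region).
from anthropic import AnthropicVertex

client = AnthropicVertex(region='us-central1', project_id='my-gcp-project')
message = client.messages.create(
    model='claude-3-haiku@20240307',
    max_tokens=64,
    messages=[{'role': 'user', 'content': 'ping'}],
)
print(message.content[0].text)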
+ # docs: https://docs.anthropic.com/claude/docs/system-prompts + - name: temperature + use_template: temperature + required: false + type: float + default: 1 + min: 0.0 + max: 1.0 + help: + zh_Hans: 生成内容的随机性。 + en_US: The amount of randomness injected into the response. + - name: top_p + required: false + type: float + default: 0.999 + min: 0.000 + max: 1.000 + help: + zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。 + en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both. + - name: top_k + required: false + type: int + default: 0 + min: 0 + # tip docs from aws has error, max value is 500 + max: 500 + help: + zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。 + en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses. +pricing: + input: '0.015' + output: '0.075' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/vertex_ai/llm/anthropic.claude-3-sonnet.yaml b/api/core/model_runtime/model_providers/vertex_ai/llm/anthropic.claude-3-sonnet.yaml new file mode 100644 index 0000000000..0be0113ffd --- /dev/null +++ b/api/core/model_runtime/model_providers/vertex_ai/llm/anthropic.claude-3-sonnet.yaml @@ -0,0 +1,55 @@ +model: claude-3-sonnet@20240229 +label: + en_US: Claude 3 Sonnet +model_type: llm +features: + - agent-thought + - vision +model_properties: + mode: chat + context_size: 200000 +parameter_rules: + - name: max_tokens + use_template: max_tokens + required: true + type: int + default: 4096 + min: 1 + max: 4096 + help: + zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。 + en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter. + - name: temperature + use_template: temperature + required: false + type: float + default: 1 + min: 0.0 + max: 1.0 + help: + zh_Hans: 生成内容的随机性。 + en_US: The amount of randomness injected into the response. + - name: top_p + required: false + type: float + default: 0.999 + min: 0.000 + max: 1.000 + help: + zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。 + en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both. + - name: top_k + required: false + type: int + default: 0 + min: 0 + # tip docs from aws has error, max value is 500 + max: 500 + help: + zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。 + en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses. 
+pricing: + input: '0.003' + output: '0.015' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.0-pro-vision.yaml b/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.0-pro-vision.yaml new file mode 100644 index 0000000000..da3bc8a64a --- /dev/null +++ b/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.0-pro-vision.yaml @@ -0,0 +1,38 @@ +model: gemini-1.0-pro-vision-001 +label: + en_US: Gemini 1.0 Pro Vision +model_type: llm +features: + - vision + - tool-call + - stream-tool-call +model_properties: + mode: chat + context_size: 16384 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + en_US: Top k + type: int + help: + en_US: Only sample from the top K options for each subsequent token. + required: false + - name: presence_penalty + use_template: presence_penalty + - name: frequency_penalty + use_template: frequency_penalty + - name: max_output_tokens + use_template: max_tokens + required: true + default: 2048 + min: 1 + max: 2048 +pricing: + input: '0.00' + output: '0.00' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.0-pro.yaml b/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.0-pro.yaml new file mode 100644 index 0000000000..029fab718c --- /dev/null +++ b/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.0-pro.yaml @@ -0,0 +1,38 @@ +model: gemini-1.0-pro-002 +label: + en_US: Gemini 1.0 Pro +model_type: llm +features: + - agent-thought + - tool-call + - stream-tool-call +model_properties: + mode: chat + context_size: 32760 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + en_US: Top k + type: int + help: + en_US: Only sample from the top K options for each subsequent token. + required: false + - name: presence_penalty + use_template: presence_penalty + - name: frequency_penalty + use_template: frequency_penalty + - name: max_output_tokens + use_template: max_tokens + required: true + default: 8192 + min: 1 + max: 8192 +pricing: + input: '0.00' + output: '0.00' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.5-flash.yaml b/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.5-flash.yaml new file mode 100644 index 0000000000..72b8410aa1 --- /dev/null +++ b/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.5-flash.yaml @@ -0,0 +1,38 @@ +model: gemini-1.5-flash-preview-0514 +label: + en_US: Gemini 1.5 Flash +model_type: llm +features: + - vision + - tool-call + - stream-tool-call +model_properties: + mode: chat + context_size: 1048576 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + en_US: Top k + type: int + help: + en_US: Only sample from the top K options for each subsequent token. 
+    required: false
+  - name: presence_penalty
+    use_template: presence_penalty
+  - name: frequency_penalty
+    use_template: frequency_penalty
+  - name: max_output_tokens
+    use_template: max_tokens
+    required: true
+    default: 8192
+    min: 1
+    max: 8192
+pricing:
+  input: '0.00'
+  output: '0.00'
+  unit: '0.000001'
+  currency: USD
diff --git a/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.5-pro.yaml b/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.5-pro.yaml
new file mode 100644
index 0000000000..141f61aad6
--- /dev/null
+++ b/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.5-pro.yaml
@@ -0,0 +1,39 @@
+model: gemini-1.5-pro-preview-0514
+label:
+  en_US: Gemini 1.5 Pro
+model_type: llm
+features:
+  - agent-thought
+  - vision
+  - tool-call
+  - stream-tool-call
+model_properties:
+  mode: chat
+  context_size: 1048576
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+  - name: top_p
+    use_template: top_p
+  - name: top_k
+    label:
+      en_US: Top k
+    type: int
+    help:
+      en_US: Only sample from the top K options for each subsequent token.
+    required: false
+  - name: presence_penalty
+    use_template: presence_penalty
+  - name: frequency_penalty
+    use_template: frequency_penalty
+  - name: max_output_tokens
+    use_template: max_tokens
+    required: true
+    default: 8192
+    min: 1
+    max: 8192
+pricing:
+  input: '0.00'
+  output: '0.00'
+  unit: '0.000001'
+  currency: USD
diff --git a/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py b/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py
new file mode 100644
index 0000000000..0d6dd8d982
--- /dev/null
+++ b/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py
@@ -0,0 +1,728 @@
+import base64
+import json
+import logging
+import mimetypes
+import time
+from collections.abc import Generator
+from typing import Optional, Union, cast
+
+import google.api_core.exceptions as exceptions
+import requests
+import vertexai.generative_models as glm
+from anthropic import AnthropicVertex, Stream
+from anthropic.types import (
+    ContentBlockDeltaEvent,
+    Message,
+    MessageDeltaEvent,
+    MessageStartEvent,
+    MessageStopEvent,
+    MessageStreamEvent,
+)
+from google.cloud import aiplatform
+from google.oauth2 import service_account
+from vertexai.generative_models import HarmBlockThreshold, HarmCategory
+
+from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
+from core.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    ImagePromptMessageContent,
+    PromptMessage,
+    PromptMessageContentType,
+    PromptMessageTool,
+    SystemPromptMessage,
+    TextPromptMessageContent,
+    ToolPromptMessage,
+    UserPromptMessage,
+)
+from core.model_runtime.entities.model_entities import PriceType
+from core.model_runtime.errors.invoke import (
+    InvokeAuthorizationError,
+    InvokeBadRequestError,
+    InvokeConnectionError,
+    InvokeError,
+    InvokeRateLimitError,
+    InvokeServerUnavailableError,
+)
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+
+logger = logging.getLogger(__name__)
+
+GEMINI_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
+The structure of the {{block}} object can be found in the instructions; use {"answer": "$your_answer"} as the default structure
+if you are not sure about the structure.
+ + +{{instructions}} + +""" + + +class VertexAiLargeLanguageModel(LargeLanguageModel): + + def _invoke(self, model: str, credentials: dict, + prompt_messages: list[PromptMessage], model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, + stream: bool = True, user: Optional[str] = None) \ + -> Union[LLMResult, Generator]: + """ + Invoke large language model + + :param model: model name + :param credentials: model credentials + :param prompt_messages: prompt messages + :param model_parameters: model parameters + :param tools: tools for tool calling + :param stop: stop words + :param stream: is stream response + :param user: unique user id + :return: full response or stream response chunk generator result + """ + # invoke anthropic models via anthropic official SDK + if "claude" in model: + return self._generate_anthropic(model, credentials, prompt_messages, model_parameters, stop, stream, user) + # invoke Gemini model + return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) + + def _generate_anthropic(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, + stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: + """ + Invoke Anthropic large language model + + :param model: model name + :param credentials: model credentials + :param prompt_messages: prompt messages + :param model_parameters: model parameters + :param stop: stop words + :param stream: is stream response + :return: full response or stream response chunk generator result + """ + # use Anthropic official SDK references + # - https://github.com/anthropics/anthropic-sdk-python + project_id = credentials["vertex_project_id"] + + if 'opus' in model: + location = 'us-east5' + else: + location = 'us-central1' + + client = AnthropicVertex( + region=location, + project_id=project_id + ) + + extra_model_kwargs = {} + if stop: + extra_model_kwargs['stop_sequences'] = stop + + system, prompt_message_dicts = self._convert_claude_prompt_messages(prompt_messages) + + if system: + extra_model_kwargs['system'] = system + + response = client.messages.create( + model=model, + messages=prompt_message_dicts, + stream=stream, + **model_parameters, + **extra_model_kwargs + ) + + if stream: + return self._handle_claude_stream_response(model, credentials, response, prompt_messages) + + return self._handle_claude_response(model, credentials, response, prompt_messages) + + def _handle_claude_response(self, model: str, credentials: dict, response: Message, + prompt_messages: list[PromptMessage]) -> LLMResult: + """ + Handle llm chat response + + :param model: model name + :param credentials: credentials + :param response: response + :param prompt_messages: prompt messages + :return: full response chunk generator result + """ + + # transform assistant message to prompt message + assistant_prompt_message = AssistantPromptMessage( + content=response.content[0].text + ) + + # calculate num tokens + if response.usage: + # transform usage + prompt_tokens = response.usage.input_tokens + completion_tokens = response.usage.output_tokens + else: + # calculate num tokens + prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) + completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message]) + + # transform usage + usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) + + # transform 
response + response = LLMResult( + model=response.model, + prompt_messages=prompt_messages, + message=assistant_prompt_message, + usage=usage + ) + + return response + + def _handle_claude_stream_response(self, model: str, credentials: dict, response: Stream[MessageStreamEvent], + prompt_messages: list[PromptMessage], ) -> Generator: + """ + Handle llm chat stream response + + :param model: model name + :param credentials: credentials + :param response: response + :param prompt_messages: prompt messages + :return: full response or stream response chunk generator result + """ + + try: + full_assistant_content = '' + return_model = None + input_tokens = 0 + output_tokens = 0 + finish_reason = None + index = 0 + + for chunk in response: + if isinstance(chunk, MessageStartEvent): + return_model = chunk.message.model + input_tokens = chunk.message.usage.input_tokens + elif isinstance(chunk, MessageDeltaEvent): + output_tokens = chunk.usage.output_tokens + finish_reason = chunk.delta.stop_reason + elif isinstance(chunk, MessageStopEvent): + usage = self._calc_response_usage(model, credentials, input_tokens, output_tokens) + yield LLMResultChunk( + model=return_model, + prompt_messages=prompt_messages, + delta=LLMResultChunkDelta( + index=index + 1, + message=AssistantPromptMessage( + content='' + ), + finish_reason=finish_reason, + usage=usage + ) + ) + elif isinstance(chunk, ContentBlockDeltaEvent): + chunk_text = chunk.delta.text if chunk.delta.text else '' + full_assistant_content += chunk_text + assistant_prompt_message = AssistantPromptMessage( + content=chunk_text if chunk_text else '', + ) + index = chunk.index + yield LLMResultChunk( + model=model, + prompt_messages=prompt_messages, + delta=LLMResultChunkDelta( + index=index, + message=assistant_prompt_message, + ) + ) + except Exception as ex: + raise InvokeError(str(ex)) + + def _calc_claude_response_usage(self, model: str, credentials: dict, prompt_tokens: int, completion_tokens: int) -> LLMUsage: + """ + Calculate response usage + + :param model: model name + :param credentials: model credentials + :param prompt_tokens: prompt tokens + :param completion_tokens: completion tokens + :return: usage + """ + # get prompt price info + prompt_price_info = self.get_price( + model=model, + credentials=credentials, + price_type=PriceType.INPUT, + tokens=prompt_tokens, + ) + + # get completion price info + completion_price_info = self.get_price( + model=model, + credentials=credentials, + price_type=PriceType.OUTPUT, + tokens=completion_tokens + ) + + # transform usage + usage = LLMUsage( + prompt_tokens=prompt_tokens, + prompt_unit_price=prompt_price_info.unit_price, + prompt_price_unit=prompt_price_info.unit, + prompt_price=prompt_price_info.total_amount, + completion_tokens=completion_tokens, + completion_unit_price=completion_price_info.unit_price, + completion_price_unit=completion_price_info.unit, + completion_price=completion_price_info.total_amount, + total_tokens=prompt_tokens + completion_tokens, + total_price=prompt_price_info.total_amount + completion_price_info.total_amount, + currency=prompt_price_info.currency, + latency=time.perf_counter() - self.started_at + ) + + return usage + + def _convert_claude_prompt_messages(self, prompt_messages: list[PromptMessage]) -> tuple[str, list[dict]]: + """ + Convert prompt messages to dict list and system + """ + + system = "" + first_loop = True + for message in prompt_messages: + if isinstance(message, SystemPromptMessage): + message.content=message.content.strip() + if first_loop: + 
system = message.content
+                    first_loop = False
+                else:
+                    system += "\n"
+                    system += message.content
+
+        prompt_message_dicts = []
+        for message in prompt_messages:
+            if not isinstance(message, SystemPromptMessage):
+                prompt_message_dicts.append(self._convert_claude_prompt_message_to_dict(message))
+
+        return system, prompt_message_dicts
+
+    def _convert_claude_prompt_message_to_dict(self, message: PromptMessage) -> dict:
+        """
+        Convert PromptMessage to dict
+        """
+        if isinstance(message, UserPromptMessage):
+            message = cast(UserPromptMessage, message)
+            if isinstance(message.content, str):
+                message_dict = {"role": "user", "content": message.content}
+            else:
+                sub_messages = []
+                for message_content in message.content:
+                    if message_content.type == PromptMessageContentType.TEXT:
+                        message_content = cast(TextPromptMessageContent, message_content)
+                        sub_message_dict = {
+                            "type": "text",
+                            "text": message_content.data
+                        }
+                        sub_messages.append(sub_message_dict)
+                    elif message_content.type == PromptMessageContentType.IMAGE:
+                        message_content = cast(ImagePromptMessageContent, message_content)
+                        if not message_content.data.startswith("data:"):
+                            # fetch image data from url
+                            try:
+                                image_content = requests.get(message_content.data).content
+                                mime_type, _ = mimetypes.guess_type(message_content.data)
+                                base64_data = base64.b64encode(image_content).decode('utf-8')
+                            except Exception as ex:
+                                raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
+                        else:
+                            data_split = message_content.data.split(";base64,")
+                            mime_type = data_split[0].replace("data:", "")
+                            base64_data = data_split[1]
+
+                        if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]:
+                            raise ValueError(f"Unsupported image type {mime_type}, "
+                                             f"only support image/jpeg, image/png, image/gif, and image/webp")
+
+                        sub_message_dict = {
+                            "type": "image",
+                            "source": {
+                                "type": "base64",
+                                "media_type": mime_type,
+                                "data": base64_data
+                            }
+                        }
+                        sub_messages.append(sub_message_dict)
+
+                message_dict = {"role": "user", "content": sub_messages}
+        elif isinstance(message, AssistantPromptMessage):
+            message = cast(AssistantPromptMessage, message)
+            message_dict = {"role": "assistant", "content": message.content}
+        elif isinstance(message, SystemPromptMessage):
+            message = cast(SystemPromptMessage, message)
+            message_dict = {"role": "system", "content": message.content}
+        else:
+            raise ValueError(f"Got unknown type {message}")
+
+        return message_dict
+
+    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+                       tools: Optional[list[PromptMessageTool]] = None) -> int:
+        """
+        Get number of tokens for given prompt messages
+
+        :param model: model name
+        :param credentials: model credentials
+        :param prompt_messages: prompt messages
+        :param tools: tools for tool calling
+        :return: number of tokens
+        """
+        prompt = self._convert_messages_to_prompt(prompt_messages)
+
+        return self._get_num_tokens_by_gpt2(prompt)
+
+    def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str:
+        """
+        Format a list of messages into a full prompt for the Google model
+
+        :param messages: List of PromptMessage to combine.
+        :return: Combined string with necessary human_prompt and ai_prompt tags.
+ """ + messages = messages.copy() # don't mutate the original list + + text = "".join( + self._convert_one_message_to_text(message) + for message in messages + ) + + return text.rstrip() + + def _convert_tools_to_glm_tool(self, tools: list[PromptMessageTool]) -> glm.Tool: + """ + Convert tool messages to glm tools + + :param tools: tool messages + :return: glm tools + """ + return glm.Tool( + function_declarations=[ + glm.FunctionDeclaration( + name=tool.name, + parameters=glm.Schema( + type=glm.Type.OBJECT, + properties={ + key: { + 'type_': value.get('type', 'string').upper(), + 'description': value.get('description', ''), + 'enum': value.get('enum', []) + } for key, value in tool.parameters.get('properties', {}).items() + }, + required=tool.parameters.get('required', []) + ), + ) for tool in tools + ] + ) + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials + + :param model: model name + :param credentials: model credentials + :return: + """ + + try: + ping_message = SystemPromptMessage(content="ping") + self._generate(model, credentials, [ping_message], {"max_tokens_to_sample": 5}) + + except Exception as ex: + raise CredentialsValidateFailedError(str(ex)) + + + def _generate(self, model: str, credentials: dict, + prompt_messages: list[PromptMessage], model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, + stream: bool = True, user: Optional[str] = None + ) -> Union[LLMResult, Generator]: + """ + Invoke large language model + + :param model: model name + :param credentials: credentials kwargs + :param prompt_messages: prompt messages + :param model_parameters: model parameters + :param stop: stop words + :param stream: is stream response + :param user: unique user id + :return: full response or stream response chunk generator result + """ + config_kwargs = model_parameters.copy() + config_kwargs['max_output_tokens'] = config_kwargs.pop('max_tokens_to_sample', None) + + if stop: + config_kwargs["stop_sequences"] = stop + + service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"])) + project_id = credentials["vertex_project_id"] + location = credentials["vertex_location"] + if service_account_info: + service_accountSA = service_account.Credentials.from_service_account_info(service_account_info) + aiplatform.init(credentials=service_accountSA, project=project_id, location=location) + else: + aiplatform.init(project=project_id, location=location) + + history = [] + system_instruction = GEMINI_BLOCK_MODE_PROMPT + # hack for gemini-pro-vision, which currently does not support multi-turn chat + if model == "gemini-1.0-pro-vision-001": + last_msg = prompt_messages[-1] + content = self._format_message_to_glm_content(last_msg) + history.append(content) + else: + for msg in prompt_messages: + if isinstance(msg, SystemPromptMessage): + system_instruction = msg.content + else: + content = self._format_message_to_glm_content(msg) + if history and history[-1].role == content.role: + history[-1].parts.extend(content.parts) + else: + history.append(content) + + safety_settings={ + HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE, + HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE, + HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE, + HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE, + } + + google_model = glm.GenerativeModel( + model_name=model, + 
system_instruction=system_instruction + ) + + response = google_model.generate_content( + contents=history, + generation_config=glm.GenerationConfig( + **config_kwargs + ), + stream=stream, + safety_settings=safety_settings, + tools=self._convert_tools_to_glm_tool(tools) if tools else None + ) + + if stream: + return self._handle_generate_stream_response(model, credentials, response, prompt_messages) + + return self._handle_generate_response(model, credentials, response, prompt_messages) + + def _handle_generate_response(self, model: str, credentials: dict, response: glm.GenerationResponse, + prompt_messages: list[PromptMessage]) -> LLMResult: + """ + Handle llm response + + :param model: model name + :param credentials: credentials + :param response: response + :param prompt_messages: prompt messages + :return: llm response + """ + # transform assistant message to prompt message + assistant_prompt_message = AssistantPromptMessage( + content=response.candidates[0].content.parts[0].text + ) + + # calculate num tokens + prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) + completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message]) + + # transform usage + usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) + + # transform response + result = LLMResult( + model=model, + prompt_messages=prompt_messages, + message=assistant_prompt_message, + usage=usage, + ) + + return result + + def _handle_generate_stream_response(self, model: str, credentials: dict, response: glm.GenerationResponse, + prompt_messages: list[PromptMessage]) -> Generator: + """ + Handle llm stream response + + :param model: model name + :param credentials: credentials + :param response: response + :param prompt_messages: prompt messages + :return: llm response chunk generator result + """ + index = -1 + for chunk in response: + for part in chunk.candidates[0].content.parts: + assistant_prompt_message = AssistantPromptMessage( + content='' + ) + + if part.text: + assistant_prompt_message.content += part.text + + if part.function_call: + assistant_prompt_message.tool_calls = [ + AssistantPromptMessage.ToolCall( + id=part.function_call.name, + type='function', + function=AssistantPromptMessage.ToolCall.ToolCallFunction( + name=part.function_call.name, + arguments=json.dumps({ + key: value + for key, value in part.function_call.args.items() + }) + ) + ) + ] + + index += 1 + + if not hasattr(chunk, 'finish_reason') or not chunk.finish_reason: + # transform assistant message to prompt message + yield LLMResultChunk( + model=model, + prompt_messages=prompt_messages, + delta=LLMResultChunkDelta( + index=index, + message=assistant_prompt_message + ) + ) + else: + + # calculate num tokens + prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) + completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message]) + + # transform usage + usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) + + yield LLMResultChunk( + model=model, + prompt_messages=prompt_messages, + delta=LLMResultChunkDelta( + index=index, + message=assistant_prompt_message, + finish_reason=chunk.candidates[0].finish_reason, + usage=usage + ) + ) + + def _convert_one_message_to_text(self, message: PromptMessage) -> str: + """ + Convert a single message to a string. + + :param message: PromptMessage to convert. + :return: String representation of the message. 
+ """ + human_prompt = "\n\nuser:" + ai_prompt = "\n\nmodel:" + + content = message.content + if isinstance(content, list): + content = "".join( + c.data for c in content if c.type != PromptMessageContentType.IMAGE + ) + + if isinstance(message, UserPromptMessage): + message_text = f"{human_prompt} {content}" + elif isinstance(message, AssistantPromptMessage): + message_text = f"{ai_prompt} {content}" + elif isinstance(message, SystemPromptMessage): + message_text = f"{human_prompt} {content}" + elif isinstance(message, ToolPromptMessage): + message_text = f"{human_prompt} {content}" + else: + raise ValueError(f"Got unknown type {message}") + + return message_text + + def _format_message_to_glm_content(self, message: PromptMessage) -> glm.Content: + """ + Format a single message into glm.Content for Google API + + :param message: one PromptMessage + :return: glm Content representation of message + """ + if isinstance(message, UserPromptMessage): + glm_content = glm.Content(role="user", parts=[]) + + if (isinstance(message.content, str)): + glm_content = glm.Content(role="user", parts=[glm.Part.from_text(message.content)]) + else: + parts = [] + for c in message.content: + if c.type == PromptMessageContentType.TEXT: + parts.append(glm.Part.from_text(c.data)) + else: + metadata, data = c.data.split(',', 1) + mime_type = metadata.split(';', 1)[0].split(':')[1] + parts.append(glm.Part.from_data(mime_type=mime_type, data=data)) + glm_content = glm.Content(role="user", parts=parts) + return glm_content + elif isinstance(message, AssistantPromptMessage): + if message.content: + glm_content = glm.Content(role="model", parts=[glm.Part.from_text(message.content)]) + if message.tool_calls: + glm_content = glm.Content(role="model", parts=[glm.Part.from_function_response(glm.FunctionCall( + name=message.tool_calls[0].function.name, + args=json.loads(message.tool_calls[0].function.arguments), + ))]) + return glm_content + elif isinstance(message, ToolPromptMessage): + glm_content = glm.Content(role="function", parts=[glm.Part(function_response=glm.FunctionResponse( + name=message.name, + response={ + "response": message.content + } + ))]) + return glm_content + else: + raise ValueError(f"Got unknown type {message}") + + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + """ + Map model invoke error to unified error + The key is the ermd = gml.GenerativeModel(model)ror type thrown to the caller + The value is the md = gml.GenerativeModel(model)error type thrown by the model, + which needs to be converted into a unified error type for the caller. 
+
+        :return: Invoke error mapping
+        """
+        return {
+            InvokeConnectionError: [
+                exceptions.RetryError
+            ],
+            InvokeServerUnavailableError: [
+                exceptions.ServiceUnavailable,
+                exceptions.InternalServerError,
+                exceptions.BadGateway,
+                exceptions.GatewayTimeout,
+                exceptions.DeadlineExceeded
+            ],
+            InvokeRateLimitError: [
+                exceptions.ResourceExhausted,
+                exceptions.TooManyRequests
+            ],
+            InvokeAuthorizationError: [
+                exceptions.Unauthenticated,
+                exceptions.PermissionDenied,
+                exceptions.Forbidden
+            ],
+            InvokeBadRequestError: [
+                exceptions.BadRequest,
+                exceptions.InvalidArgument,
+                exceptions.FailedPrecondition,
+                exceptions.OutOfRange,
+                exceptions.NotFound,
+                exceptions.MethodNotAllowed,
+                exceptions.Conflict,
+                exceptions.AlreadyExists,
+                exceptions.Aborted,
+                exceptions.LengthRequired,
+                exceptions.PreconditionFailed,
+                exceptions.RequestRangeNotSatisfiable,
+                exceptions.Cancelled,
+            ]
+        }
diff --git a/api/core/model_runtime/model_providers/vertex_ai/text_embedding/__init__.py b/api/core/model_runtime/model_providers/vertex_ai/text_embedding/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text-embedding-004.yaml b/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text-embedding-004.yaml
new file mode 100644
index 0000000000..32db6faf89
--- /dev/null
+++ b/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text-embedding-004.yaml
@@ -0,0 +1,8 @@
+model: text-embedding-004
+model_type: text-embedding
+model_properties:
+  context_size: 2048
+pricing:
+  input: '0.00013'
+  unit: '0.001'
+  currency: USD
diff --git a/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text-multilingual-embedding-002.yaml b/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text-multilingual-embedding-002.yaml
new file mode 100644
index 0000000000..2ec0eea9f2
--- /dev/null
+++ b/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text-multilingual-embedding-002.yaml
@@ -0,0 +1,8 @@
+model: text-multilingual-embedding-002
+model_type: text-embedding
+model_properties:
+  context_size: 2048
+pricing:
+  input: '0.00013'
+  unit: '0.001'
+  currency: USD
diff --git a/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py
new file mode 100644
index 0000000000..2404ba5894
--- /dev/null
+++ b/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py
@@ -0,0 +1,197 @@
+import base64
+import json
+import time
+from decimal import Decimal
+from typing import Optional
+
+import tiktoken
+from google.cloud import aiplatform
+from google.oauth2 import service_account
+from vertexai.language_models import TextEmbeddingModel as VertexTextEmbeddingModel
+
+from core.model_runtime.entities.common_entities import I18nObject
+from core.model_runtime.entities.model_entities import (
+    AIModelEntity,
+    FetchFrom,
+    ModelPropertyKey,
+    ModelType,
+    PriceConfig,
+    PriceType,
+)
+from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
+from core.model_runtime.model_providers.vertex_ai._common import _CommonVertexAi
+
+
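As context for the `vertex_service_account_key` credential that both this file and llm.py decode with `json.loads(base64.b64decode(...))`: the value is simply the Google service-account JSON, base64-encoded. A minimal sketch of producing it, not part of the patch; the function name and file path are illustrative placeholders:

```python
import base64
import json

def encode_vertex_service_account_key(path: str) -> str:
    # Round-trip through json to fail fast on a malformed key file, then
    # emit the base64 string the provider expects for vertex_service_account_key.
    with open(path, encoding="utf-8") as f:
        key_json = json.dumps(json.load(f))
    return base64.b64encode(key_json.encode("utf-8")).decode("utf-8")

# Hypothetical usage:
# credentials["vertex_service_account_key"] = encode_vertex_service_account_key("sa-key.json")
```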
+class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel):
+    """
+    Model class for Vertex AI text embedding model.
+    """
+
+    def _invoke(self, model: str, credentials: dict,
+                texts: list[str], user: Optional[str] = None) \
+            -> TextEmbeddingResult:
+        """
+        Invoke text embedding model
+
+        :param model: model name
+        :param credentials: model credentials
+        :param texts: texts to embed
+        :param user: unique user id
+        :return: embeddings result
+        """
+        service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"]))
+        project_id = credentials["vertex_project_id"]
+        location = credentials["vertex_location"]
+        if service_account_info:
+            service_accountSA = service_account.Credentials.from_service_account_info(service_account_info)
+            aiplatform.init(credentials=service_accountSA, project=project_id, location=location)
+        else:
+            aiplatform.init(project=project_id, location=location)
+
+        client = VertexTextEmbeddingModel.from_pretrained(model)
+
+        embeddings_batch, embedding_used_tokens = self._embedding_invoke(
+            client=client,
+            texts=texts
+        )
+
+        # calc usage
+        usage = self._calc_response_usage(
+            model=model,
+            credentials=credentials,
+            tokens=embedding_used_tokens
+        )
+
+        return TextEmbeddingResult(
+            embeddings=embeddings_batch,
+            usage=usage,
+            model=model
+        )
+
+    def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
+        """
+        Get number of tokens for given prompt messages
+
+        :param model: model name
+        :param credentials: model credentials
+        :param texts: texts to embed
+        :return: number of tokens
+        """
+        if len(texts) == 0:
+            return 0
+
+        try:
+            enc = tiktoken.encoding_for_model(model)
+        except KeyError:
+            enc = tiktoken.get_encoding("cl100k_base")
+
+        total_num_tokens = 0
+        for text in texts:
+            # calculate the number of tokens in the encoded text
+            tokenized_text = enc.encode(text)
+            total_num_tokens += len(tokenized_text)
+
+        return total_num_tokens
+
+    def validate_credentials(self, model: str, credentials: dict) -> None:
+        """
+        Validate model credentials
+
+        :param model: model name
+        :param credentials: model credentials
+        :return:
+        """
+        try:
+            service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"]))
+            project_id = credentials["vertex_project_id"]
+            location = credentials["vertex_location"]
+            if service_account_info:
+                service_accountSA = service_account.Credentials.from_service_account_info(service_account_info)
+                aiplatform.init(credentials=service_accountSA, project=project_id, location=location)
+            else:
+                aiplatform.init(project=project_id, location=location)
+
+            client = VertexTextEmbeddingModel.from_pretrained(model)
+
+            # call embedding model (`_embedding_invoke` takes no `model` argument)
+            self._embedding_invoke(
+                client=client,
+                texts=['ping']
+            )
+        except Exception as ex:
+            raise CredentialsValidateFailedError(str(ex))
+
+    def _embedding_invoke(self, client: VertexTextEmbeddingModel, texts: list[str]) -> tuple[list[list[float]], int]:
+        """
+        Invoke embedding model
+
+        :param client: model client
+        :param texts: texts to embed
+        :return: embeddings and used tokens
+        """
+        response = client.get_embeddings(texts)
+
+        embeddings = []
+        token_usage = 0
+
+        for i in range(len(response)):
+            embeddings.append(response[i].values)
+            token_usage += int(response[i].statistics.token_count)
+
+        return embeddings, token_usage
+
+    def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
+        """
+        Calculate response usage
+
+        :param model: model name
+        :param credentials: model credentials
+        :param tokens: input tokens
+        :return: usage
+        """
+ # get input price info + input_price_info = self.get_price( + model=model, + credentials=credentials, + price_type=PriceType.INPUT, + tokens=tokens + ) + + # transform usage + usage = EmbeddingUsage( + tokens=tokens, + total_tokens=tokens, + unit_price=input_price_info.unit_price, + price_unit=input_price_info.unit, + total_price=input_price_info.total_amount, + currency=input_price_info.currency, + latency=time.perf_counter() - self.started_at + ) + + return usage + + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: + """ + generate custom model entities from credentials + """ + entity = AIModelEntity( + model=model, + label=I18nObject(en_US=model), + model_type=ModelType.TEXT_EMBEDDING, + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_properties={ + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size')), + ModelPropertyKey.MAX_CHUNKS: 1, + }, + parameter_rules=[], + pricing=PriceConfig( + input=Decimal(credentials.get('input_price', 0)), + unit=Decimal(credentials.get('unit', 0)), + currency=credentials.get('currency', "USD") + ) + ) + + return entity diff --git a/api/core/model_runtime/model_providers/vertex_ai/vertex_ai.py b/api/core/model_runtime/model_providers/vertex_ai/vertex_ai.py new file mode 100644 index 0000000000..3cbfb088d1 --- /dev/null +++ b/api/core/model_runtime/model_providers/vertex_ai/vertex_ai.py @@ -0,0 +1,31 @@ +import logging + +from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.model_provider import ModelProvider + +logger = logging.getLogger(__name__) + + +class VertexAiProvider(ModelProvider): + def validate_provider_credentials(self, credentials: dict) -> None: + """ + Validate provider credentials + + if validate failed, raise exception + + :param credentials: provider credentials, credentials form defined in `provider_credential_schema`. + """ + try: + model_instance = self.get_model_instance(ModelType.LLM) + + # Use `gemini-1.0-pro-002` model for validate, + model_instance.validate_credentials( + model='gemini-1.0-pro-002', + credentials=credentials + ) + except CredentialsValidateFailedError as ex: + raise ex + except Exception as ex: + logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + raise ex diff --git a/api/core/model_runtime/model_providers/vertex_ai/vertex_ai.yaml b/api/core/model_runtime/model_providers/vertex_ai/vertex_ai.yaml new file mode 100644 index 0000000000..27a4d03fe2 --- /dev/null +++ b/api/core/model_runtime/model_providers/vertex_ai/vertex_ai.yaml @@ -0,0 +1,43 @@ +provider: vertex_ai +label: + en_US: Vertex AI | Google Cloud Platform +description: + en_US: Vertex AI in Google Cloud Platform. 
+icon_small:
+  en_US: icon_s_en.svg
+icon_large:
+  en_US: icon_l_en.png
+background: "#FCFDFF"
+help:
+  title:
+    en_US: Get your Access Details from Google
+  url:
+    en_US: https://cloud.google.com/vertex-ai/
+supported_model_types:
+  - llm
+  - text-embedding
+configurate_methods:
+  - predefined-model
+provider_credential_schema:
+  credential_form_schemas:
+    - variable: vertex_project_id
+      label:
+        en_US: Project ID
+      type: text-input
+      required: true
+      placeholder:
+        en_US: Enter your Google Cloud Project ID
+    - variable: vertex_location
+      label:
+        en_US: Location
+      type: text-input
+      required: true
+      placeholder:
+        en_US: Enter your Google Cloud Location
+    - variable: vertex_service_account_key
+      label:
+        en_US: Service Account Key (Leave blank if you use Application Default Credentials)
+      type: secret-input
+      required: false
+      placeholder:
+        en_US: Enter your Google Cloud Service Account Key in base64 format
diff --git a/api/core/model_runtime/model_providers/volcengine_maas/__init__.py b/api/core/model_runtime/model_providers/volcengine_maas/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/model_runtime/model_providers/volcengine_maas/_assets/icon_l_en.svg b/api/core/model_runtime/model_providers/volcengine_maas/_assets/icon_l_en.svg
new file mode 100644
index 0000000000..616e90916b
--- /dev/null
+++ b/api/core/model_runtime/model_providers/volcengine_maas/_assets/icon_l_en.svg
@@ -0,0 +1,23 @@
[23 added lines of SVG markup; tag content stripped during extraction]
\ No newline at end of file
diff --git a/api/core/model_runtime/model_providers/volcengine_maas/_assets/icon_l_zh.svg b/api/core/model_runtime/model_providers/volcengine_maas/_assets/icon_l_zh.svg
new file mode 100644
index 0000000000..24b92195bd
--- /dev/null
+++ b/api/core/model_runtime/model_providers/volcengine_maas/_assets/icon_l_zh.svg
@@ -0,0 +1,39 @@
[39 added lines of SVG markup; tag content stripped during extraction]
\ No newline at end of file
diff --git a/api/core/model_runtime/model_providers/volcengine_maas/_assets/icon_s_en.svg b/api/core/model_runtime/model_providers/volcengine_maas/_assets/icon_s_en.svg
new file mode 100644
index 0000000000..e6454a89b7
--- /dev/null
+++ b/api/core/model_runtime/model_providers/volcengine_maas/_assets/icon_s_en.svg
@@ -0,0 +1,8 @@
[8 added lines of SVG markup; tag content stripped during extraction]
\ No newline at end of file
diff --git a/api/core/model_runtime/model_providers/volcengine_maas/client.py b/api/core/model_runtime/model_providers/volcengine_maas/client.py
new file mode 100644
index 0000000000..c7bf4fde8c
--- /dev/null
+++ b/api/core/model_runtime/model_providers/volcengine_maas/client.py
@@ -0,0 +1,108 @@
+import re
+from collections.abc import Callable, Generator
+from typing import cast
+
+from core.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    ImagePromptMessageContent,
+    PromptMessage,
+    PromptMessageContentType,
+    SystemPromptMessage,
+    UserPromptMessage,
+)
+from core.model_runtime.model_providers.volcengine_maas.errors import wrap_error
+from core.model_runtime.model_providers.volcengine_maas.volc_sdk import ChatRole, MaasException, MaasService
+
+
+class MaaSClient(MaasService):
+    def __init__(self, host: str, region: str):
+        self.endpoint_id = None
+        super().__init__(host, region)
+
+    def set_endpoint_id(self, endpoint_id: str):
+        self.endpoint_id = endpoint_id
+
+    @classmethod
+    def from_credential(cls, credentials: dict) -> 'MaaSClient':
+        host = credentials['api_endpoint_host']
+        region = credentials['volc_region']
+        ak = credentials['volc_access_key_id']
sk = credentials['volc_secret_access_key'] + endpoint_id = credentials['endpoint_id'] + + client = cls(host, region) + client.set_endpoint_id(endpoint_id) + client.set_ak(ak) + client.set_sk(sk) + return client + + def chat(self, params: dict, messages: list[PromptMessage], stream=False) -> Generator | dict: + req = { + 'parameters': params, + 'messages': [self.convert_prompt_message_to_maas_message(prompt) for prompt in messages] + } + if not stream: + return super().chat( + self.endpoint_id, + req, + ) + return super().stream_chat( + self.endpoint_id, + req, + ) + + def embeddings(self, texts: list[str]) -> dict: + req = { + 'input': texts + } + return super().embeddings(self.endpoint_id, req) + + @staticmethod + def convert_prompt_message_to_maas_message(message: PromptMessage) -> dict: + if isinstance(message, UserPromptMessage): + message = cast(UserPromptMessage, message) + if isinstance(message.content, str): + message_dict = {"role": ChatRole.USER, + "content": message.content} + else: + content = [] + for message_content in message.content: + if message_content.type == PromptMessageContentType.TEXT: + raise ValueError( + 'Content object type only support image_url') + elif message_content.type == PromptMessageContentType.IMAGE: + message_content = cast( + ImagePromptMessageContent, message_content) + image_data = re.sub( + r'^data:image\/[a-zA-Z]+;base64,', '', message_content.data) + content.append({ + 'type': 'image_url', + 'image_url': { + 'url': '', + 'image_bytes': image_data, + 'detail': message_content.detail, + } + }) + + message_dict = {'role': ChatRole.USER, 'content': content} + elif isinstance(message, AssistantPromptMessage): + message = cast(AssistantPromptMessage, message) + message_dict = {'role': ChatRole.ASSISTANT, + 'content': message.content} + elif isinstance(message, SystemPromptMessage): + message = cast(SystemPromptMessage, message) + message_dict = {'role': ChatRole.SYSTEM, + 'content': message.content} + else: + raise ValueError(f"Got unknown PromptMessage type {message}") + + return message_dict + + @staticmethod + def wrap_exception(fn: Callable[[], dict | Generator]) -> dict | Generator: + try: + resp = fn() + except MaasException as e: + raise wrap_error(e) + + return resp diff --git a/api/core/model_runtime/model_providers/volcengine_maas/errors.py b/api/core/model_runtime/model_providers/volcengine_maas/errors.py new file mode 100644 index 0000000000..63397a456e --- /dev/null +++ b/api/core/model_runtime/model_providers/volcengine_maas/errors.py @@ -0,0 +1,156 @@ +from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasException + + +class ClientSDKRequestError(MaasException): + pass + + +class SignatureDoesNotMatch(MaasException): + pass + + +class RequestTimeout(MaasException): + pass + + +class ServiceConnectionTimeout(MaasException): + pass + + +class MissingAuthenticationHeader(MaasException): + pass + + +class AuthenticationHeaderIsInvalid(MaasException): + pass + + +class InternalServiceError(MaasException): + pass + + +class MissingParameter(MaasException): + pass + + +class InvalidParameter(MaasException): + pass + + +class AuthenticationExpire(MaasException): + pass + + +class EndpointIsInvalid(MaasException): + pass + + +class EndpointIsNotEnable(MaasException): + pass + + +class ModelNotSupportStreamMode(MaasException): + pass + + +class ReqTextExistRisk(MaasException): + pass + + +class RespTextExistRisk(MaasException): + pass + + +class EndpointRateLimitExceeded(MaasException): + pass + + +class 
ServiceConnectionRefused(MaasException): + pass + + +class ServiceConnectionClosed(MaasException): + pass + + +class UnauthorizedUserForEndpoint(MaasException): + pass + + +class InvalidEndpointWithNoURL(MaasException): + pass + + +class EndpointAccountRpmRateLimitExceeded(MaasException): + pass + + +class EndpointAccountTpmRateLimitExceeded(MaasException): + pass + + +class ServiceResourceWaitQueueFull(MaasException): + pass + + +class EndpointIsPending(MaasException): + pass + + +class ServiceNotOpen(MaasException): + pass + + +AuthErrors = { + 'SignatureDoesNotMatch': SignatureDoesNotMatch, + 'MissingAuthenticationHeader': MissingAuthenticationHeader, + 'AuthenticationHeaderIsInvalid': AuthenticationHeaderIsInvalid, + 'AuthenticationExpire': AuthenticationExpire, + 'UnauthorizedUserForEndpoint': UnauthorizedUserForEndpoint, +} + +BadRequestErrors = { + 'MissingParameter': MissingParameter, + 'InvalidParameter': InvalidParameter, + 'EndpointIsInvalid': EndpointIsInvalid, + 'EndpointIsNotEnable': EndpointIsNotEnable, + 'ModelNotSupportStreamMode': ModelNotSupportStreamMode, + 'ReqTextExistRisk': ReqTextExistRisk, + 'RespTextExistRisk': RespTextExistRisk, + 'InvalidEndpointWithNoURL': InvalidEndpointWithNoURL, + 'ServiceNotOpen': ServiceNotOpen, +} + +RateLimitErrors = { + 'EndpointRateLimitExceeded': EndpointRateLimitExceeded, + 'EndpointAccountRpmRateLimitExceeded': EndpointAccountRpmRateLimitExceeded, + 'EndpointAccountTpmRateLimitExceeded': EndpointAccountTpmRateLimitExceeded, +} + +ServerUnavailableErrors = { + 'InternalServiceError': InternalServiceError, + 'EndpointIsPending': EndpointIsPending, + 'ServiceResourceWaitQueueFull': ServiceResourceWaitQueueFull, +} + +ConnectionErrors = { + 'ClientSDKRequestError': ClientSDKRequestError, + 'RequestTimeout': RequestTimeout, + 'ServiceConnectionTimeout': ServiceConnectionTimeout, + 'ServiceConnectionRefused': ServiceConnectionRefused, + 'ServiceConnectionClosed': ServiceConnectionClosed, +} + +ErrorCodeMap = { + **AuthErrors, + **BadRequestErrors, + **RateLimitErrors, + **ServerUnavailableErrors, + **ConnectionErrors, +} + + +def wrap_error(e: MaasException) -> Exception: + if ErrorCodeMap.get(e.code): + return ErrorCodeMap.get(e.code)(e.code_n, e.code, e.message, e.req_id) + return e diff --git a/api/core/model_runtime/model_providers/volcengine_maas/llm/__init__.py b/api/core/model_runtime/model_providers/volcengine_maas/llm/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py b/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py new file mode 100644 index 0000000000..7a36d019e2 --- /dev/null +++ b/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py @@ -0,0 +1,284 @@ +import logging +from collections.abc import Generator + +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + PromptMessageTool, + UserPromptMessage, +) +from core.model_runtime.entities.model_entities import ( + AIModelEntity, + FetchFrom, + ModelPropertyKey, + ModelType, + ParameterRule, + ParameterType, +) +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from 
core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.model_runtime.model_providers.volcengine_maas.client import MaaSClient +from core.model_runtime.model_providers.volcengine_maas.errors import ( + AuthErrors, + BadRequestErrors, + ConnectionErrors, + RateLimitErrors, + ServerUnavailableErrors, +) +from core.model_runtime.model_providers.volcengine_maas.llm.models import ModelConfigs +from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasException + +logger = logging.getLogger(__name__) + + +class VolcengineMaaSLargeLanguageModel(LargeLanguageModel): + def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], + model_parameters: dict, tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ + -> LLMResult | Generator: + return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate credentials + """ + # ping + client = MaaSClient.from_credential(credentials) + try: + client.chat( + { + 'max_new_tokens': 16, + 'temperature': 0.7, + 'top_p': 0.9, + 'top_k': 15, + }, + [UserPromptMessage(content='ping\nAnswer: ')], + ) + except MaasException as e: + raise CredentialsValidateFailedError(e.message) + + def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool] | None = None) -> int: + if len(prompt_messages) == 0: + return 0 + return self._num_tokens_from_messages(prompt_messages) + + def _num_tokens_from_messages(self, messages: list[PromptMessage]) -> int: + """ + Calculate num tokens. 
+ + :param messages: messages + """ + num_tokens = 0 + messages_dict = [ + MaaSClient.convert_prompt_message_to_maas_message(m) for m in messages] + for message in messages_dict: + for key, value in message.items(): + num_tokens += self._get_num_tokens_by_gpt2(str(key)) + num_tokens += self._get_num_tokens_by_gpt2(str(value)) + + return num_tokens + + def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], + model_parameters: dict, tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ + -> LLMResult | Generator: + + client = MaaSClient.from_credential(credentials) + + req_params = ModelConfigs.get( + credentials['base_model_name'], {}).get('req_params', {}).copy() + if credentials.get('context_size'): + req_params['max_prompt_tokens'] = credentials.get('context_size') + if credentials.get('max_tokens'): + req_params['max_new_tokens'] = credentials.get('max_tokens') + if model_parameters.get('max_tokens'): + req_params['max_new_tokens'] = model_parameters.get('max_tokens') + if model_parameters.get('temperature'): + req_params['temperature'] = model_parameters.get('temperature') + if model_parameters.get('top_p'): + req_params['top_p'] = model_parameters.get('top_p') + if model_parameters.get('top_k'): + req_params['top_k'] = model_parameters.get('top_k') + if model_parameters.get('presence_penalty'): + req_params['presence_penalty'] = model_parameters.get( + 'presence_penalty') + if model_parameters.get('frequency_penalty'): + req_params['frequency_penalty'] = model_parameters.get( + 'frequency_penalty') + if stop: + req_params['stop'] = stop + + resp = MaaSClient.wrap_exception( + lambda: client.chat(req_params, prompt_messages, stream)) + if not stream: + return self._handle_chat_response(model, credentials, prompt_messages, resp) + return self._handle_stream_chat_response(model, credentials, prompt_messages, resp) + + def _handle_stream_chat_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], resp: Generator) -> Generator: + for index, r in enumerate(resp): + choices = r['choices'] + if not choices: + continue + choice = choices[0] + message = choice['message'] + usage = None + if r.get('usage'): + usage = self._calc_usage(model, credentials, r['usage']) + yield LLMResultChunk( + model=model, + prompt_messages=prompt_messages, + delta=LLMResultChunkDelta( + index=index, + message=AssistantPromptMessage( + content=message['content'] if message['content'] else '', + tool_calls=[] + ), + usage=usage, + finish_reason=choice.get('finish_reason'), + ), + ) + + def _handle_chat_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], resp: dict) -> LLMResult: + choices = resp['choices'] + if not choices: + return + choice = choices[0] + message = choice['message'] + + return LLMResult( + model=model, + prompt_messages=prompt_messages, + message=AssistantPromptMessage( + content=message['content'] if message['content'] else '', + tool_calls=[], + ), + usage=self._calc_usage(model, credentials, resp['usage']), + ) + + def _calc_usage(self, model: str, credentials: dict, usage: dict) -> LLMUsage: + return self._calc_response_usage(model=model, credentials=credentials, + prompt_tokens=usage['prompt_tokens'], + completion_tokens=usage['completion_tokens'] + ) + + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: + """ + used to define customizable model schema + """ + max_tokens = 
ModelConfigs.get( + credentials['base_model_name'], {}).get('req_params', {}).get('max_new_tokens') + if credentials.get('max_tokens'): + max_tokens = int(credentials.get('max_tokens')) + rules = [ + ParameterRule( + name='temperature', + type=ParameterType.FLOAT, + use_template='temperature', + label=I18nObject( + zh_Hans='温度', + en_US='Temperature' + ) + ), + ParameterRule( + name='top_p', + type=ParameterType.FLOAT, + use_template='top_p', + label=I18nObject( + zh_Hans='Top P', + en_US='Top P' + ) + ), + ParameterRule( + name='top_k', + type=ParameterType.INT, + min=1, + default=1, + label=I18nObject( + zh_Hans='Top K', + en_US='Top K' + ) + ), + ParameterRule( + name='presence_penalty', + type=ParameterType.FLOAT, + use_template='presence_penalty', + label={ + 'en_US': 'Presence Penalty', + 'zh_Hans': '存在惩罚', + }, + min=-2.0, + max=2.0, + ), + ParameterRule( + name='frequency_penalty', + type=ParameterType.FLOAT, + use_template='frequency_penalty', + label={ + 'en_US': 'Frequency Penalty', + 'zh_Hans': '频率惩罚', + }, + min=-2.0, + max=2.0, + ), + ParameterRule( + name='max_tokens', + type=ParameterType.INT, + use_template='max_tokens', + min=1, + max=max_tokens, + default=512, + label=I18nObject( + zh_Hans='最大生成长度', + en_US='Max Tokens' + ) + ), + ] + + model_properties = ModelConfigs.get( + credentials['base_model_name'], {}).get('model_properties', {}).copy() + if credentials.get('mode'): + model_properties[ModelPropertyKey.MODE] = credentials.get('mode') + if credentials.get('context_size'): + model_properties[ModelPropertyKey.CONTEXT_SIZE] = int( + credentials.get('context_size', 4096)) + entity = AIModelEntity( + model=model, + label=I18nObject( + en_US=model + ), + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_type=ModelType.LLM, + model_properties=model_properties, + parameter_rules=rules + ) + + return entity + + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + """ + Map model invoke error to unified error + The key is the error type thrown to the caller + The value is the error type thrown by the model, + which needs to be converted into a unified error type for the caller. 
+ + :return: Invoke error mapping + """ + return { + InvokeConnectionError: ConnectionErrors.values(), + InvokeServerUnavailableError: ServerUnavailableErrors.values(), + InvokeRateLimitError: RateLimitErrors.values(), + InvokeAuthorizationError: AuthErrors.values(), + InvokeBadRequestError: BadRequestErrors.values(), + } diff --git a/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py b/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py new file mode 100644 index 0000000000..2e8ff314fc --- /dev/null +++ b/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py @@ -0,0 +1,72 @@ +ModelConfigs = { + 'Doubao-pro-4k': { + 'req_params': { + 'max_prompt_tokens': 4096, + 'max_new_tokens': 4096, + }, + 'model_properties': { + 'context_size': 4096, + 'mode': 'chat', + } + }, + 'Doubao-lite-4k': { + 'req_params': { + 'max_prompt_tokens': 4096, + 'max_new_tokens': 4096, + }, + 'model_properties': { + 'context_size': 4096, + 'mode': 'chat', + } + }, + 'Doubao-pro-32k': { + 'req_params': { + 'max_prompt_tokens': 32768, + 'max_new_tokens': 32768, + }, + 'model_properties': { + 'context_size': 32768, + 'mode': 'chat', + } + }, + 'Doubao-lite-32k': { + 'req_params': { + 'max_prompt_tokens': 32768, + 'max_new_tokens': 32768, + }, + 'model_properties': { + 'context_size': 32768, + 'mode': 'chat', + } + }, + 'Doubao-pro-128k': { + 'req_params': { + 'max_prompt_tokens': 131072, + 'max_new_tokens': 131072, + }, + 'model_properties': { + 'context_size': 131072, + 'mode': 'chat', + } + }, + 'Doubao-lite-128k': { + 'req_params': { + 'max_prompt_tokens': 131072, + 'max_new_tokens': 131072, + }, + 'model_properties': { + 'context_size': 131072, + 'mode': 'chat', + } + }, + 'Skylark2-pro-4k': { + 'req_params': { + 'max_prompt_tokens': 4096, + 'max_new_tokens': 4000, + }, + 'model_properties': { + 'context_size': 4096, + 'mode': 'chat', + } + }, +} diff --git a/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/__init__.py b/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/models.py b/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/models.py new file mode 100644 index 0000000000..569f89e975 --- /dev/null +++ b/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/models.py @@ -0,0 +1,9 @@ +ModelConfigs = { + 'Doubao-embedding': { + 'req_params': {}, + 'model_properties': { + 'context_size': 4096, + 'max_chunks': 1, + } + }, +} diff --git a/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py new file mode 100644 index 0000000000..10b01c0d0d --- /dev/null +++ b/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py @@ -0,0 +1,170 @@ +import time +from decimal import Decimal +from typing import Optional + +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.model_entities import ( + AIModelEntity, + FetchFrom, + ModelPropertyKey, + ModelType, + PriceConfig, + PriceType, +) +from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + 
InvokeServerUnavailableError, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel +from core.model_runtime.model_providers.volcengine_maas.client import MaaSClient +from core.model_runtime.model_providers.volcengine_maas.errors import ( + AuthErrors, + BadRequestErrors, + ConnectionErrors, + RateLimitErrors, + ServerUnavailableErrors, +) +from core.model_runtime.model_providers.volcengine_maas.text_embedding.models import ModelConfigs +from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasException + + +class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel): + """ + Model class for VolcengineMaaS text embedding model. + """ + + def _invoke(self, model: str, credentials: dict, + texts: list[str], user: Optional[str] = None) \ + -> TextEmbeddingResult: + """ + Invoke text embedding model + + :param model: model name + :param credentials: model credentials + :param texts: texts to embed + :param user: unique user id + :return: embeddings result + """ + client = MaaSClient.from_credential(credentials) + resp = MaaSClient.wrap_exception(lambda: client.embeddings(texts)) + + usage = self._calc_response_usage( + model=model, credentials=credentials, tokens=resp['usage']['total_tokens']) + + result = TextEmbeddingResult( + model=model, + embeddings=[v['embedding'] for v in resp['data']], + usage=usage + ) + + return result + + def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: + """ + Get number of tokens for given prompt messages + + :param model: model name + :param credentials: model credentials + :param texts: texts to embed + :return: + """ + num_tokens = 0 + for text in texts: + # use GPT2Tokenizer to get num tokens + num_tokens += self._get_num_tokens_by_gpt2(text) + return num_tokens + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials + + :param model: model name + :param credentials: model credentials + :return: + """ + try: + self._invoke(model=model, credentials=credentials, texts=['ping']) + except MaasException as e: + raise CredentialsValidateFailedError(e.message) + + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + """ + Map model invoke error to unified error + The key is the error type thrown to the caller + The value is the error type thrown by the model, + which needs to be converted into a unified error type for the caller. 
+ + :return: Invoke error mapping + """ + return { + InvokeConnectionError: ConnectionErrors.values(), + InvokeServerUnavailableError: ServerUnavailableErrors.values(), + InvokeRateLimitError: RateLimitErrors.values(), + InvokeAuthorizationError: AuthErrors.values(), + InvokeBadRequestError: BadRequestErrors.values(), + } + + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: + """ + generate custom model entities from credentials + """ + model_properties = ModelConfigs.get( + credentials['base_model_name'], {}).get('model_properties', {}).copy() + if credentials.get('context_size'): + model_properties[ModelPropertyKey.CONTEXT_SIZE] = int( + credentials.get('context_size', 4096)) + if credentials.get('max_chunks'): + model_properties[ModelPropertyKey.MAX_CHUNKS] = int( + credentials.get('max_chunks', 4096)) + entity = AIModelEntity( + model=model, + label=I18nObject(en_US=model), + model_type=ModelType.TEXT_EMBEDDING, + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_properties=model_properties, + parameter_rules=[], + pricing=PriceConfig( + input=Decimal(credentials.get('input_price', 0)), + unit=Decimal(credentials.get('unit', 0)), + currency=credentials.get('currency', "USD") + ) + ) + + return entity + + def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: + """ + Calculate response usage + + :param model: model name + :param credentials: model credentials + :param tokens: input tokens + :return: usage + """ + # get input price info + input_price_info = self.get_price( + model=model, + credentials=credentials, + price_type=PriceType.INPUT, + tokens=tokens + ) + + # transform usage + usage = EmbeddingUsage( + tokens=tokens, + total_tokens=tokens, + unit_price=input_price_info.unit_price, + price_unit=input_price_info.unit, + total_price=input_price_info.total_amount, + currency=input_price_info.currency, + latency=time.perf_counter() - self.started_at + ) + + return usage diff --git a/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/__init__.py b/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/__init__.py new file mode 100644 index 0000000000..64f342f16e --- /dev/null +++ b/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/__init__.py @@ -0,0 +1,4 @@ +from .common import ChatRole +from .maas import MaasException, MaasService + +__all__ = ['MaasService', 'ChatRole', 'MaasException'] diff --git a/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/__init__.py b/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/__init__.py new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/__init__.py @@ -0,0 +1 @@ + diff --git a/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/auth.py b/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/auth.py new file mode 100644 index 0000000000..48110f16d7 --- /dev/null +++ b/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/auth.py @@ -0,0 +1,144 @@ +# coding : utf-8 +import datetime + +import pytz + +from .util import Util + + +class MetaData: + def __init__(self): + self.algorithm = '' + self.credential_scope = '' + self.signed_headers = '' + self.date = '' + self.region = '' + self.service = '' + + def set_date(self, date): + self.date = date + + def set_service(self, service): + self.service = service + + def set_region(self, region): + self.region = region + + 
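Note: the `Signer` class below implements a Volcengine flavor of AWS SigV4 request signing — it derives a scoped signing key through a chained HMAC-SHA256 scheme and signs a canonical request string with it. A minimal, self-contained sketch of that derivation (the `sk`/date/region/service values are placeholders; the real code fills them from `Credentials` and `MetaData`):

```python
# Sketch of the key-derivation chain used by Signer.get_signing_secret_key_v4 /
# Signer.sign below; not the vendored code itself.
import hashlib
import hmac


def hmac_sha256(key: bytes, content: str) -> bytes:
    return hmac.new(key, content.encode('utf-8'), hashlib.sha256).digest()


def derive_signing_key(sk: str, date: str, region: str, service: str) -> bytes:
    # Each HMAC is keyed with the previous digest and the chain ends with the
    # literal string 'request', mirroring the date/region/service/request scope.
    k_date = hmac_sha256(sk.encode('utf-8'), date)
    k_region = hmac_sha256(k_date, region)
    k_service = hmac_sha256(k_region, service)
    return hmac_sha256(k_service, 'request')


signing_key = derive_signing_key('my-secret-key', '20240607', 'cn-beijing', 'ml_maas')
string_to_sign = 'HMAC-SHA256\n20240607T000000Z\n20240607/cn-beijing/ml_maas/request\n<hashed-canonical-request>'
print(hmac_sha256(signing_key, string_to_sign).hex())
```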
def set_algorithm(self, algorithm): + self.algorithm = algorithm + + def set_credential_scope(self, credential_scope): + self.credential_scope = credential_scope + + def set_signed_headers(self, signed_headers): + self.signed_headers = signed_headers + + +class SignResult: + def __init__(self): + self.xdate = '' + self.xCredential = '' + self.xAlgorithm = '' + self.xSignedHeaders = '' + self.xSignedQueries = '' + self.xSignature = '' + self.xContextSha256 = '' + self.xSecurityToken = '' + + self.authorization = '' + + def __str__(self): + return '\n'.join(['{}:{}'.format(*item) for item in self.__dict__.items()]) + + +class Credentials: + def __init__(self, ak, sk, service, region, session_token=''): + self.ak = ak + self.sk = sk + self.service = service + self.region = region + self.session_token = session_token + + def set_ak(self, ak): + self.ak = ak + + def set_sk(self, sk): + self.sk = sk + + def set_session_token(self, session_token): + self.session_token = session_token + + +class Signer: + @staticmethod + def sign(request, credentials): + if request.path == '': + request.path = '/' + if request.method != 'GET' and not ('Content-Type' in request.headers): + request.headers['Content-Type'] = 'application/x-www-form-urlencoded; charset=utf-8' + + format_date = Signer.get_current_format_date() + request.headers['X-Date'] = format_date + if credentials.session_token != '': + request.headers['X-Security-Token'] = credentials.session_token + + md = MetaData() + md.set_algorithm('HMAC-SHA256') + md.set_service(credentials.service) + md.set_region(credentials.region) + md.set_date(format_date[:8]) + + hashed_canon_req = Signer.hashed_canonical_request_v4(request, md) + md.set_credential_scope('/'.join([md.date, md.region, md.service, 'request'])) + + signing_str = '\n'.join([md.algorithm, format_date, md.credential_scope, hashed_canon_req]) + signing_key = Signer.get_signing_secret_key_v4(credentials.sk, md.date, md.region, md.service) + sign = Util.to_hex(Util.hmac_sha256(signing_key, signing_str)) + request.headers['Authorization'] = Signer.build_auth_header_v4(sign, md, credentials) + return + + @staticmethod + def hashed_canonical_request_v4(request, meta): + body_hash = Util.sha256(request.body) + request.headers['X-Content-Sha256'] = body_hash + + signed_headers = dict() + for key in request.headers: + if key in ['Content-Type', 'Content-Md5', 'Host'] or key.startswith('X-'): + signed_headers[key.lower()] = request.headers[key] + + if 'host' in signed_headers: + v = signed_headers['host'] + if v.find(':') != -1: + split = v.split(':') + port = split[1] + if str(port) == '80' or str(port) == '443': + signed_headers['host'] = split[0] + + signed_str = '' + for key in sorted(signed_headers.keys()): + signed_str += key + ':' + signed_headers[key] + '\n' + + meta.set_signed_headers(';'.join(sorted(signed_headers.keys()))) + + canonical_request = '\n'.join( + [request.method, Util.norm_uri(request.path), Util.norm_query(request.query), signed_str, + meta.signed_headers, body_hash]) + + return Util.sha256(canonical_request) + + @staticmethod + def get_signing_secret_key_v4(sk, date, region, service): + date = Util.hmac_sha256(bytes(sk, encoding='utf-8'), date) + region = Util.hmac_sha256(date, region) + service = Util.hmac_sha256(region, service) + return Util.hmac_sha256(service, 'request') + + @staticmethod + def build_auth_header_v4(signature, meta, credentials): + credential = credentials.ak + '/' + meta.credential_scope + return meta.algorithm + ' Credential=' + credential + ', 
SignedHeaders=' + meta.signed_headers + ', Signature=' + signature + + @staticmethod + def get_current_format_date(): + return datetime.datetime.now(tz=pytz.timezone('UTC')).strftime("%Y%m%dT%H%M%SZ") diff --git a/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/service.py b/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/service.py new file mode 100644 index 0000000000..03734ec54f --- /dev/null +++ b/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/service.py @@ -0,0 +1,207 @@ +import json +from collections import OrderedDict +from urllib.parse import urlencode + +import requests + +from .auth import Signer + +VERSION = 'v1.0.137' + + +class Service: + def __init__(self, service_info, api_info): + self.service_info = service_info + self.api_info = api_info + self.session = requests.session() + + def set_ak(self, ak): + self.service_info.credentials.set_ak(ak) + + def set_sk(self, sk): + self.service_info.credentials.set_sk(sk) + + def set_session_token(self, session_token): + self.service_info.credentials.set_session_token(session_token) + + def set_host(self, host): + self.service_info.host = host + + def set_scheme(self, scheme): + self.service_info.scheme = scheme + + def get(self, api, params, doseq=0): + if not (api in self.api_info): + raise Exception("no such api") + api_info = self.api_info[api] + + r = self.prepare_request(api_info, params, doseq) + + Signer.sign(r, self.service_info.credentials) + + url = r.build(doseq) + resp = self.session.get(url, headers=r.headers, + timeout=(self.service_info.connection_timeout, self.service_info.socket_timeout)) + if resp.status_code == 200: + return resp.text + else: + raise Exception(resp.text) + + def post(self, api, params, form): + if not (api in self.api_info): + raise Exception("no such api") + api_info = self.api_info[api] + r = self.prepare_request(api_info, params) + r.headers['Content-Type'] = 'application/x-www-form-urlencoded' + r.form = self.merge(api_info.form, form) + r.body = urlencode(r.form, True) + Signer.sign(r, self.service_info.credentials) + + url = r.build() + + resp = self.session.post(url, headers=r.headers, data=r.form, + timeout=(self.service_info.connection_timeout, self.service_info.socket_timeout)) + if resp.status_code == 200: + return resp.text + else: + raise Exception(resp.text) + + def json(self, api, params, body): + if not (api in self.api_info): + raise Exception("no such api") + api_info = self.api_info[api] + r = self.prepare_request(api_info, params) + r.headers['Content-Type'] = 'application/json' + r.body = body + + Signer.sign(r, self.service_info.credentials) + + url = r.build() + resp = self.session.post(url, headers=r.headers, data=r.body, + timeout=(self.service_info.connection_timeout, self.service_info.socket_timeout)) + if resp.status_code == 200: + return json.dumps(resp.json()) + else: + raise Exception(resp.text.encode("utf-8")) + + def put(self, url, file_path, headers): + with open(file_path, 'rb') as f: + resp = self.session.put(url, headers=headers, data=f) + if resp.status_code == 200: + return True, resp.text.encode("utf-8") + else: + return False, resp.text.encode("utf-8") + + def put_data(self, url, data, headers): + resp = self.session.put(url, headers=headers, data=data) + if resp.status_code == 200: + return True, resp.text.encode("utf-8") + else: + return False, resp.text.encode("utf-8") + + def prepare_request(self, api_info, params, doseq=0): + for key in params: + if type(params[key]) == int or 
type(params[key]) == float or type(params[key]) == bool: + params[key] = str(params[key]) + elif type(params[key]) == list: + if not doseq: + params[key] = ','.join(params[key]) + + connection_timeout = self.service_info.connection_timeout + socket_timeout = self.service_info.socket_timeout + + r = Request() + r.set_schema(self.service_info.scheme) + r.set_method(api_info.method) + r.set_connection_timeout(connection_timeout) + r.set_socket_timeout(socket_timeout) + + headers = self.merge(api_info.header, self.service_info.header) + headers['Host'] = self.service_info.host + headers['User-Agent'] = 'volc-sdk-python/' + VERSION + r.set_headers(headers) + + query = self.merge(api_info.query, params) + r.set_query(query) + + r.set_host(self.service_info.host) + r.set_path(api_info.path) + + return r + + @staticmethod + def merge(param1, param2): + od = OrderedDict() + for key in param1: + od[key] = param1[key] + + for key in param2: + od[key] = param2[key] + + return od + + +class Request: + def __init__(self): + self.schema = '' + self.method = '' + self.host = '' + self.path = '' + self.headers = OrderedDict() + self.query = OrderedDict() + self.body = '' + self.form = dict() + self.connection_timeout = 0 + self.socket_timeout = 0 + + def set_schema(self, schema): + self.schema = schema + + def set_method(self, method): + self.method = method + + def set_host(self, host): + self.host = host + + def set_path(self, path): + self.path = path + + def set_headers(self, headers): + self.headers = headers + + def set_query(self, query): + self.query = query + + def set_body(self, body): + self.body = body + + def set_connection_timeout(self, connection_timeout): + self.connection_timeout = connection_timeout + + def set_socket_timeout(self, socket_timeout): + self.socket_timeout = socket_timeout + + def build(self, doseq=0): + return self.schema + '://' + self.host + self.path + '?' 
+ urlencode(self.query, doseq) + + +class ServiceInfo: + def __init__(self, host, header, credentials, connection_timeout, socket_timeout, scheme='http'): + self.host = host + self.header = header + self.credentials = credentials + self.connection_timeout = connection_timeout + self.socket_timeout = socket_timeout + self.scheme = scheme + + +class ApiInfo: + def __init__(self, method, path, query, form, header): + self.method = method + self.path = path + self.query = query + self.form = form + self.header = header + + def __str__(self): + return 'method: ' + self.method + ', path: ' + self.path diff --git a/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/util.py b/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/util.py new file mode 100644 index 0000000000..7eb5fdfa91 --- /dev/null +++ b/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/util.py @@ -0,0 +1,43 @@ +import hashlib +import hmac +from functools import reduce +from urllib.parse import quote + + +class Util: + @staticmethod + def norm_uri(path): + return quote(path).replace('%2F', '/').replace('+', '%20') + + @staticmethod + def norm_query(params): + query = '' + for key in sorted(params.keys()): + if type(params[key]) == list: + for k in params[key]: + query = query + quote(key, safe='-_.~') + '=' + quote(k, safe='-_.~') + '&' + else: + query = query + quote(key, safe='-_.~') + '=' + quote(params[key], safe='-_.~') + '&' + query = query[:-1] + return query.replace('+', '%20') + + @staticmethod + def hmac_sha256(key, content): + return hmac.new(key, bytes(content, encoding='utf-8'), hashlib.sha256).digest() + + @staticmethod + def sha256(content): + if isinstance(content, str) is True: + return hashlib.sha256(content.encode('utf-8')).hexdigest() + else: + return hashlib.sha256(content).hexdigest() + + @staticmethod + def to_hex(content): + lst = [] + for ch in content: + hv = hex(ch).replace('0x', '') + if len(hv) == 1: + hv = '0' + hv + lst.append(hv) + return reduce(lambda x, y: x + y, lst) diff --git a/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/common.py b/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/common.py new file mode 100644 index 0000000000..8b14d026d9 --- /dev/null +++ b/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/common.py @@ -0,0 +1,79 @@ +import json +import random +from datetime import datetime + + +class ChatRole: + USER = "user" + ASSISTANT = "assistant" + SYSTEM = "system" + FUNCTION = "function" + + +class _Dict(dict): + __setattr__ = dict.__setitem__ + __getattr__ = dict.__getitem__ + + def __missing__(self, key): + return None + + +def dict_to_object(dict_obj): + # 支持嵌套类型 + if isinstance(dict_obj, list): + insts = [] + for i in dict_obj: + insts.append(dict_to_object(i)) + return insts + + if isinstance(dict_obj, dict): + inst = _Dict() + for k, v in dict_obj.items(): + inst[k] = dict_to_object(v) + return inst + + return dict_obj + + +def json_to_object(json_str, req_id=None): + obj = dict_to_object(json.loads(json_str)) + if obj and isinstance(obj, dict) and req_id: + obj["req_id"] = req_id + return obj + + +def gen_req_id(): + return datetime.now().strftime("%Y%m%d%H%M%S") + format( + random.randint(0, 2 ** 64 - 1), "020X" + ) + + +class SSEDecoder: + def __init__(self, source): + self.source = source + + def _read(self): + data = b'' + for chunk in self.source: + for line in chunk.splitlines(True): + data += line + if data.endswith((b'\r\r', b'\n\n', b'\r\n\r\n')): + yield data + data = 
b'' + if data: + yield data + + def next(self): + for chunk in self._read(): + for line in chunk.splitlines(): + # skip comment + if line.startswith(b':'): + continue + + if b':' in line: + field, value = line.split(b':', 1) + else: + field, value = line, b'' + + if field == b'data' and len(value) > 0: + yield value diff --git a/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/maas.py b/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/maas.py new file mode 100644 index 0000000000..3cbe9d9f09 --- /dev/null +++ b/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/maas.py @@ -0,0 +1,213 @@ +import copy +import json +from collections.abc import Iterator + +from .base.auth import Credentials, Signer +from .base.service import ApiInfo, Service, ServiceInfo +from .common import SSEDecoder, dict_to_object, gen_req_id, json_to_object + + +class MaasService(Service): + def __init__(self, host, region, connection_timeout=60, socket_timeout=60): + service_info = self.get_service_info( + host, region, connection_timeout, socket_timeout + ) + self._apikey = None + api_info = self.get_api_info() + super().__init__(service_info, api_info) + + def set_apikey(self, apikey): + self._apikey = apikey + + @staticmethod + def get_service_info(host, region, connection_timeout, socket_timeout): + service_info = ServiceInfo( + host, + {"Accept": "application/json"}, + Credentials("", "", "ml_maas", region), + connection_timeout, + socket_timeout, + "https", + ) + return service_info + + @staticmethod + def get_api_info(): + api_info = { + "chat": ApiInfo("POST", "/api/v2/endpoint/{endpoint_id}/chat", {}, {}, {}), + "embeddings": ApiInfo( + "POST", "/api/v2/endpoint/{endpoint_id}/embeddings", {}, {}, {} + ), + } + return api_info + + def chat(self, endpoint_id, req): + req["stream"] = False + return self._request(endpoint_id, "chat", req) + + def stream_chat(self, endpoint_id, req): + req_id = gen_req_id() + self._validate("chat", req_id) + apikey = self._apikey + + try: + req["stream"] = True + res = self._call( + endpoint_id, "chat", req_id, {}, json.dumps(req).encode("utf-8"), apikey, stream=True + ) + + decoder = SSEDecoder(res) + + def iter_fn(): + for data in decoder.next(): + if data == b"[DONE]": + return + + try: + res = json_to_object( + str(data, encoding="utf-8"), req_id=req_id) + except Exception: + raise + + if res.error is not None and res.error.code_n != 0: + raise MaasException( + res.error.code_n, + res.error.code, + res.error.message, + req_id, + ) + yield res + + return iter_fn() + except MaasException: + raise + except Exception as e: + raise new_client_sdk_request_error(str(e)) + + def embeddings(self, endpoint_id, req): + return self._request(endpoint_id, "embeddings", req) + + def _request(self, endpoint_id, api, req, params={}): + req_id = gen_req_id() + + self._validate(api, req_id) + + apikey = self._apikey + + try: + res = self._call(endpoint_id, api, req_id, params, + json.dumps(req).encode("utf-8"), apikey) + resp = dict_to_object(res.json()) + if resp and isinstance(resp, dict): + resp["req_id"] = req_id + return resp + + except MaasException as e: + raise e + except Exception as e: + raise new_client_sdk_request_error(str(e), req_id) + + def _validate(self, api, req_id): + credentials_exist = ( + self.service_info.credentials is not None and + self.service_info.credentials.sk is not None and + self.service_info.credentials.ak is not None + ) + + if not self._apikey and not credentials_exist: + raise new_client_sdk_request_error("no valid 
credential", req_id) + + if not (api in self.api_info): + raise new_client_sdk_request_error("no such api", req_id) + + def _call(self, endpoint_id, api, req_id, params, body, apikey=None, stream=False): + api_info = copy.deepcopy(self.api_info[api]) + api_info.path = api_info.path.format(endpoint_id=endpoint_id) + + r = self.prepare_request(api_info, params) + r.headers["x-tt-logid"] = req_id + r.headers["Content-Type"] = "application/json" + r.body = body + + if apikey is None: + Signer.sign(r, self.service_info.credentials) + elif apikey is not None: + r.headers["Authorization"] = "Bearer " + apikey + + url = r.build() + res = self.session.post( + url, + headers=r.headers, + data=r.body, + timeout=( + self.service_info.connection_timeout, + self.service_info.socket_timeout, + ), + stream=stream, + ) + + if res.status_code != 200: + raw = res.text.encode() + res.close() + try: + resp = json_to_object( + str(raw, encoding="utf-8"), req_id=req_id) + except Exception: + raise new_client_sdk_request_error(raw, req_id) + + if resp.error: + raise MaasException( + resp.error.code_n, resp.error.code, resp.error.message, req_id + ) + else: + raise new_client_sdk_request_error(resp, req_id) + + return res + + +class MaasException(Exception): + def __init__(self, code_n, code, message, req_id): + self.code_n = code_n + self.code = code + self.message = message + self.req_id = req_id + + def __str__(self): + return ("Detailed exception information is listed below.\n" + + "req_id: {}\n" + + "code_n: {}\n" + + "code: {}\n" + + "message: {}").format(self.req_id, self.code_n, self.code, self.message) + + +def new_client_sdk_request_error(raw, req_id=""): + return MaasException(1709701, "ClientSDKRequestError", "MaaS SDK request error: {}".format(raw), req_id) + + +class BinaryResponseContent: + def __init__(self, response, request_id) -> None: + self.response = response + self.request_id = request_id + + def stream_to_file( + self, + file: str + ) -> None: + is_first = True + error_bytes = b'' + with open(file, mode="wb") as f: + for data in self.response: + if len(error_bytes) > 0 or (is_first and "\"error\":" in str(data)): + error_bytes += data + else: + f.write(data) + + if len(error_bytes) > 0: + resp = json_to_object( + str(error_bytes, encoding="utf-8"), req_id=self.request_id) + raise MaasException( + resp.error.code_n, resp.error.code, resp.error.message, self.request_id + ) + + def iter_bytes(self) -> Iterator[bytes]: + yield from self.response diff --git a/api/core/model_runtime/model_providers/volcengine_maas/volcengine_maas.py b/api/core/model_runtime/model_providers/volcengine_maas/volcengine_maas.py new file mode 100644 index 0000000000..10f9be2d08 --- /dev/null +++ b/api/core/model_runtime/model_providers/volcengine_maas/volcengine_maas.py @@ -0,0 +1,10 @@ +import logging + +from core.model_runtime.model_providers.__base.model_provider import ModelProvider + +logger = logging.getLogger(__name__) + + +class VolcengineMaaSProvider(ModelProvider): + def validate_provider_credentials(self, credentials: dict) -> None: + pass diff --git a/api/core/model_runtime/model_providers/volcengine_maas/volcengine_maas.yaml b/api/core/model_runtime/model_providers/volcengine_maas/volcengine_maas.yaml new file mode 100644 index 0000000000..d7bcbd43f8 --- /dev/null +++ b/api/core/model_runtime/model_providers/volcengine_maas/volcengine_maas.yaml @@ -0,0 +1,188 @@ +provider: volcengine_maas +label: + en_US: Volcengine +description: + en_US: Volcengine MaaS models. 
+icon_small: + en_US: icon_s_en.svg +icon_large: + en_US: icon_l_en.svg + zh_Hans: icon_l_zh.svg +background: "#F9FAFB" +help: + title: + en_US: Get your Access Key and Secret Access Key from Volcengine Console + url: + en_US: https://console.volcengine.com/iam/keymanage/ +supported_model_types: + - llm + - text-embedding +configurate_methods: + - customizable-model +model_credential_schema: + model: + label: + en_US: Model Name + zh_Hans: 模型名称 + placeholder: + en_US: Enter your Model Name + zh_Hans: 输入模型名称 + credential_form_schemas: + - variable: volc_access_key_id + required: true + label: + en_US: Access Key + zh_Hans: Access Key + type: secret-input + placeholder: + en_US: Enter your Access Key + zh_Hans: 输入您的 Access Key + - variable: volc_secret_access_key + required: true + label: + en_US: Secret Access Key + zh_Hans: Secret Access Key + type: secret-input + placeholder: + en_US: Enter your Secret Access Key + zh_Hans: 输入您的 Secret Access Key + - variable: volc_region + required: true + label: + en_US: Volcengine Region + zh_Hans: 火山引擎地区 + type: text-input + default: cn-beijing + placeholder: + en_US: Enter Volcengine Region + zh_Hans: 输入火山引擎地域 + - variable: api_endpoint_host + required: true + label: + en_US: API Endpoint Host + zh_Hans: API Endpoint Host + type: text-input + default: maas-api.ml-platform-cn-beijing.volces.com + placeholder: + en_US: Enter your API Endpoint Host + zh_Hans: 输入 API Endpoint Host + - variable: endpoint_id + required: true + label: + en_US: Endpoint ID + zh_Hans: Endpoint ID + type: text-input + placeholder: + en_US: Enter your Endpoint ID + zh_Hans: 输入您的 Endpoint ID + - variable: base_model_name + label: + en_US: Base Model + zh_Hans: 基础模型 + type: select + required: true + options: + - label: + en_US: Doubao-pro-4k + value: Doubao-pro-4k + show_on: + - variable: __model_type + value: llm + - label: + en_US: Doubao-lite-4k + value: Doubao-lite-4k + show_on: + - variable: __model_type + value: llm + - label: + en_US: Doubao-pro-32k + value: Doubao-pro-32k + show_on: + - variable: __model_type + value: llm + - label: + en_US: Doubao-lite-32k + value: Doubao-lite-32k + show_on: + - variable: __model_type + value: llm + - label: + en_US: Doubao-pro-128k + value: Doubao-pro-128k + show_on: + - variable: __model_type + value: llm + - label: + en_US: Doubao-lite-128k + value: Doubao-lite-128k + show_on: + - variable: __model_type + value: llm + - label: + en_US: Skylark2-pro-4k + value: Skylark2-pro-4k + show_on: + - variable: __model_type + value: llm + - label: + en_US: Doubao-embedding + value: Doubao-embedding + show_on: + - variable: __model_type + value: text-embedding + - label: + en_US: Custom + zh_Hans: 自定义 + value: Custom + - variable: mode + required: true + show_on: + - variable: __model_type + value: llm + - variable: base_model_name + value: Custom + label: + zh_Hans: 模型类型 + en_US: Completion Mode + type: select + default: chat + placeholder: + zh_Hans: 选择对话类型 + en_US: Select Completion Mode + options: + - value: completion + label: + en_US: Completion + zh_Hans: 补全 + - value: chat + label: + en_US: Chat + zh_Hans: 对话 + - variable: context_size + required: true + show_on: + - variable: base_model_name + value: Custom + label: + zh_Hans: 模型上下文长度 + en_US: Model Context Size + type: text-input + default: '4096' + placeholder: + zh_Hans: 输入您的模型上下文长度 + en_US: Enter your Model Context Size + - variable: max_tokens + required: true + show_on: + - variable: __model_type + value: llm + - variable: base_model_name + value: Custom + label: + zh_Hans: 最大 
token 上限 + en_US: Upper Bound for Max Tokens + default: '4096' + type: text-input + placeholder: + zh_Hans: 输入您的模型最大 token 上限 + en_US: Enter your model Upper Bound for Max Tokens diff --git a/api/core/model_runtime/model_providers/wenxin/llm/ernie-3.5-128k.yaml b/api/core/model_runtime/model_providers/wenxin/llm/ernie-3.5-128k.yaml new file mode 100644 index 0000000000..b1b1ba1f69 --- /dev/null +++ b/api/core/model_runtime/model_providers/wenxin/llm/ernie-3.5-128k.yaml @@ -0,0 +1,37 @@ +model: ernie-3.5-128k +label: + en_US: Ernie-3.5-128K +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 128000 +parameter_rules: + - name: temperature + use_template: temperature + min: 0.1 + max: 1.0 + default: 0.8 + - name: top_p + use_template: top_p + - name: max_tokens + use_template: max_tokens + default: 4096 + min: 2 + max: 4096 + - name: presence_penalty + use_template: presence_penalty + - name: frequency_penalty + use_template: frequency_penalty + - name: response_format + use_template: response_format + - name: disable_search + label: + zh_Hans: 禁用搜索 + en_US: Disable Search + type: boolean + help: + zh_Hans: 禁用模型自行进行外部搜索。 + en_US: Disable the model to perform external search. + required: false diff --git a/api/core/model_runtime/model_providers/wenxin/llm/ernie-3.5-4k-0205.yaml b/api/core/model_runtime/model_providers/wenxin/llm/ernie-3.5-4k-0205.yaml index 9487342a1d..1e8cf96440 100644 --- a/api/core/model_runtime/model_providers/wenxin/llm/ernie-3.5-4k-0205.yaml +++ b/api/core/model_runtime/model_providers/wenxin/llm/ernie-3.5-4k-0205.yaml @@ -35,3 +35,4 @@ parameter_rules: zh_Hans: 禁用模型自行进行外部搜索。 en_US: Disable the model to perform external search. required: false +deprecated: true diff --git a/api/core/model_runtime/model_providers/wenxin/llm/ernie-3.5-8k-1222.yaml b/api/core/model_runtime/model_providers/wenxin/llm/ernie-3.5-8k-1222.yaml index 5dfcd5825b..c43588cfe1 100644 --- a/api/core/model_runtime/model_providers/wenxin/llm/ernie-3.5-8k-1222.yaml +++ b/api/core/model_runtime/model_providers/wenxin/llm/ernie-3.5-8k-1222.yaml @@ -35,3 +35,4 @@ parameter_rules: zh_Hans: 禁用模型自行进行外部搜索。 en_US: Disable the model to perform external search. 
required: false +deprecated: true diff --git a/api/core/model_runtime/model_providers/wenxin/llm/ernie-character-8k-0321.yaml b/api/core/model_runtime/model_providers/wenxin/llm/ernie-character-8k-0321.yaml new file mode 100644 index 0000000000..52e1dc832d --- /dev/null +++ b/api/core/model_runtime/model_providers/wenxin/llm/ernie-character-8k-0321.yaml @@ -0,0 +1,30 @@ +model: ernie-character-8k-0321 +label: + en_US: ERNIE-Character-8K-0321 +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 8192 +parameter_rules: + - name: temperature + use_template: temperature + min: 0.1 + max: 1.0 + default: 0.95 + - name: top_p + use_template: top_p + min: 0 + max: 1.0 + default: 0.7 + - name: max_tokens + use_template: max_tokens + default: 1024 + min: 2 + max: 1024 + - name: presence_penalty + use_template: presence_penalty + default: 1.0 + min: 1.0 + max: 2.0 diff --git a/api/core/model_runtime/model_providers/wenxin/llm/ernie-lite-8k-0308.yaml b/api/core/model_runtime/model_providers/wenxin/llm/ernie-lite-8k-0308.yaml index 3f09f10d1a..78325c1d64 100644 --- a/api/core/model_runtime/model_providers/wenxin/llm/ernie-lite-8k-0308.yaml +++ b/api/core/model_runtime/model_providers/wenxin/llm/ernie-lite-8k-0308.yaml @@ -22,7 +22,7 @@ parameter_rules: use_template: max_tokens default: 1024 min: 2 - max: 1024 + max: 2048 - name: presence_penalty use_template: presence_penalty default: 1.0 diff --git a/api/core/model_runtime/model_providers/wenxin/llm/ernie-speed-128k.yaml b/api/core/model_runtime/model_providers/wenxin/llm/ernie-speed-128k.yaml index 3b8885c862..331639624c 100644 --- a/api/core/model_runtime/model_providers/wenxin/llm/ernie-speed-128k.yaml +++ b/api/core/model_runtime/model_providers/wenxin/llm/ernie-speed-128k.yaml @@ -20,9 +20,9 @@ parameter_rules: default: 0.7 - name: max_tokens use_template: max_tokens - default: 1024 + default: 4096 min: 2 - max: 1024 + max: 4096 - name: presence_penalty use_template: presence_penalty default: 1.0 diff --git a/api/core/model_runtime/model_providers/wenxin/llm/ernie-speed-8k.yaml b/api/core/model_runtime/model_providers/wenxin/llm/ernie-speed-8k.yaml index 25d32c9f8a..304c6d1f7e 100644 --- a/api/core/model_runtime/model_providers/wenxin/llm/ernie-speed-8k.yaml +++ b/api/core/model_runtime/model_providers/wenxin/llm/ernie-speed-8k.yaml @@ -22,7 +22,7 @@ parameter_rules: use_template: max_tokens default: 1024 min: 2 - max: 1024 + max: 2048 - name: presence_penalty use_template: presence_penalty default: 1.0 diff --git a/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py b/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py index 091337c33d..4646ba384a 100644 --- a/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py +++ b/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py @@ -129,12 +129,14 @@ class ErnieBotModel: 'ernie-3.5-8k-0205': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-0205', 'ernie-3.5-8k-1222': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-1222', 'ernie-3.5-4k-0205': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205', + 'ernie-3.5-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-128k', 'ernie-4.0-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro', 'ernie-speed-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed', 'ernie-speed-128k': 
'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-speed-128k', 'ernie-speed-appbuilder': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ai_apaas', 'ernie-lite-8k-0922': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant', 'ernie-lite-8k-0308': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-lite-8k', + 'ernie-character-8k-0321': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k', } function_calling_supports = [ @@ -143,7 +145,9 @@ class ErnieBotModel: 'ernie-3.5-8k', 'ernie-3.5-8k-0205', 'ernie-3.5-8k-1222', - 'ernie-3.5-4k-0205' + 'ernie-3.5-4k-0205', + 'ernie-3.5-128k', + 'ernie-4.0-8k' ] api_key: str = '' diff --git a/api/core/model_runtime/model_providers/xinference/llm/llm.py b/api/core/model_runtime/model_providers/xinference/llm/llm.py index 602d0b749f..cc3ce17975 100644 --- a/api/core/model_runtime/model_providers/xinference/llm/llm.py +++ b/api/core/model_runtime/model_providers/xinference/llm/llm.py @@ -28,7 +28,10 @@ from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, + ImagePromptMessageContent, PromptMessage, + PromptMessageContent, + PromptMessageContentType, PromptMessageTool, SystemPromptMessage, ToolPromptMessage, @@ -61,8 +64,8 @@ from core.model_runtime.utils import helper class XinferenceAILargeLanguageModel(LargeLanguageModel): - def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: list[PromptMessageTool] | None = None, + def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], + model_parameters: dict, tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ -> LLMResult | Generator: """ @@ -99,7 +102,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): try: if "/" in credentials['model_uid'] or "?" 
in credentials['model_uid'] or "#" in credentials['model_uid']: raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #") - + extra_param = XinferenceHelper.get_xinference_extra_parameter( server_url=credentials['server_url'], model_uid=credentials['model_uid'] @@ -111,10 +114,13 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): credentials['completion_type'] = 'completion' else: raise ValueError(f'xinference model ability {extra_param.model_ability} is not supported, check if you have the right model type') - + if extra_param.support_function_call: credentials['support_function_call'] = True + if extra_param.support_vision: + credentials['support_vision'] = True + if extra_param.context_length: credentials['context_length'] = extra_param.context_length @@ -135,7 +141,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): """ return self._num_tokens_from_messages(prompt_messages, tools) - def _num_tokens_from_messages(self, messages: list[PromptMessage], tools: list[PromptMessageTool], + def _num_tokens_from_messages(self, messages: list[PromptMessage], tools: list[PromptMessageTool], is_completion_model: bool = False) -> int: def tokens(text: str): return self._get_num_tokens_by_gpt2(text) @@ -155,7 +161,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): text = '' for item in value: if isinstance(item, dict) and item['type'] == 'text': - text += item.text + text += item['text'] value = text @@ -191,7 +197,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): num_tokens += self._num_tokens_for_tools(tools) return num_tokens - + def _num_tokens_for_tools(self, tools: list[PromptMessageTool]) -> int: """ Calculate num tokens for tool calling @@ -234,7 +240,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): num_tokens += tokens(required_field) return num_tokens - + def _convert_prompt_message_to_text(self, message: list[PromptMessage]) -> str: """ convert prompt message to text @@ -260,7 +266,26 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): if isinstance(message.content, str): message_dict = {"role": "user", "content": message.content} else: - raise ValueError("User message content must be str") + sub_messages = [] + for message_content in message.content: + if message_content.type == PromptMessageContentType.TEXT: + message_content = cast(PromptMessageContent, message_content) + sub_message_dict = { + "type": "text", + "text": message_content.data + } + sub_messages.append(sub_message_dict) + elif message_content.type == PromptMessageContentType.IMAGE: + message_content = cast(ImagePromptMessageContent, message_content) + sub_message_dict = { + "type": "image_url", + "image_url": { + "url": message_content.data, + "detail": message_content.detail.value + } + } + sub_messages.append(sub_message_dict) + message_dict = {"role": "user", "content": sub_messages} elif isinstance(message, AssistantPromptMessage): message = cast(AssistantPromptMessage, message) message_dict = {"role": "assistant", "content": message.content} @@ -277,7 +302,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): message_dict = {"tool_call_id": message.tool_call_id, "role": "tool", "content": message.content} else: raise ValueError(f"Unknown message type {type(message)}") - + return message_dict def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: @@ -338,8 +363,18 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): completion_type = LLMMode.COMPLETION.value else: raise 
ValueError(f'xinference model ability {extra_args.model_ability} is not supported') - + + + features = [] + support_function_call = credentials.get('support_function_call', False) + if support_function_call: + features.append(ModelFeature.TOOL_CALL) + + support_vision = credentials.get('support_vision', False) + if support_vision: + features.append(ModelFeature.VISION) + context_length = credentials.get('context_length', 2048) entity = AIModelEntity( @@ -349,10 +384,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): ), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.LLM, - features=[ - ModelFeature.TOOL_CALL - ] if support_function_call else [], - model_properties={ + features=features, + model_properties={ ModelPropertyKey.MODE: completion_type, ModelPropertyKey.CONTEXT_SIZE: context_length }, @@ -360,22 +393,22 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): ) return entity - - def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], + + def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, extra_model_kwargs: XinferenceModelExtraParameter, - tools: list[PromptMessageTool] | None = None, + tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ -> LLMResult | Generator: """ generate text from LLM see `core.model_runtime.model_providers.__base.large_language_model.LargeLanguageModel._generate` - + extra_model_kwargs can be got by `XinferenceHelper.get_xinference_extra_parameter` """ if 'server_url' not in credentials: raise CredentialsValidateFailedError('server_url is required in credentials') - + if credentials['server_url'].endswith('/'): credentials['server_url'] = credentials['server_url'][:-1] @@ -408,11 +441,11 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): 'function': helper.dump_model(tool) } for tool in tools ] - + vision = credentials.get('support_vision', False) if isinstance(xinference_model, RESTfulChatModelHandle | RESTfulChatglmCppChatModelHandle): resp = client.chat.completions.create( model=credentials['model_uid'], - messages=[self._convert_prompt_message_to_dict(message) for message in prompt_messages], + messages=[self._convert_prompt_message_to_dict(message) for message in prompt_messages], stream=stream, user=user, **generate_config, @@ -497,7 +530,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): """ if len(resp.choices) == 0: raise InvokeServerUnavailableError("Empty response") - + assistant_message = resp.choices[0].message # convert tool call to assistant message tool call @@ -527,7 +560,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): ) return response - + def _handle_chat_stream_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], tools: list[PromptMessageTool], resp: Iterator[ChatCompletionChunk]) -> Generator: @@ -544,7 +577,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ''): continue - + # check if there is a tool call in the response function_call = None tool_calls = [] @@ -573,9 +606,9 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools) completion_tokens = self._num_tokens_from_messages(messages=[temp_assistant_prompt_message], tools=[]) - usage = self._calc_response_usage(model=model, 
credentials=credentials, + usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens) - + yield LLMResultChunk( model=model, prompt_messages=prompt_messages, @@ -608,7 +641,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): """ if len(resp.choices) == 0: raise InvokeServerUnavailableError("Empty response") - + assistant_message = resp.choices[0].text # transform assistant message to prompt message @@ -670,9 +703,9 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): completion_tokens = self._num_tokens_from_messages( messages=[temp_assistant_prompt_message], tools=[], is_completion_model=True ) - usage = self._calc_response_usage(model=model, credentials=credentials, + usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens) - + yield LLMResultChunk( model=model, prompt_messages=prompt_messages, diff --git a/api/core/model_runtime/model_providers/xinference/xinference_helper.py b/api/core/model_runtime/model_providers/xinference/xinference_helper.py index 66dab65804..9a3fc9b193 100644 --- a/api/core/model_runtime/model_providers/xinference/xinference_helper.py +++ b/api/core/model_runtime/model_providers/xinference/xinference_helper.py @@ -14,13 +14,15 @@ class XinferenceModelExtraParameter: max_tokens: int = 512 context_length: int = 2048 support_function_call: bool = False + support_vision: bool = False - def __init__(self, model_format: str, model_handle_type: str, model_ability: list[str], - support_function_call: bool, max_tokens: int, context_length: int) -> None: + def __init__(self, model_format: str, model_handle_type: str, model_ability: list[str], + support_function_call: bool, support_vision: bool, max_tokens: int, context_length: int) -> None: self.model_format = model_format self.model_handle_type = model_handle_type self.model_ability = model_ability self.support_function_call = support_function_call + self.support_vision = support_vision self.max_tokens = max_tokens self.context_length = context_length @@ -71,7 +73,7 @@ class XinferenceHelper: raise RuntimeError(f'get xinference model extra parameter failed, url: {url}, error: {e}') if response.status_code != 200: raise RuntimeError(f'get xinference model extra parameter failed, status code: {response.status_code}, response: {response.text}') - + response_json = response.json() model_format = response_json.get('model_format', 'ggmlv3') @@ -87,17 +89,19 @@ class XinferenceHelper: model_handle_type = 'chat' else: raise NotImplementedError(f'xinference model handle type {model_handle_type} is not supported') - + support_function_call = 'tools' in model_ability + support_vision = 'vision' in model_ability max_tokens = response_json.get('max_tokens', 512) context_length = response_json.get('context_length', 2048) - + return XinferenceModelExtraParameter( model_format=model_format, model_handle_type=model_handle_type, model_ability=model_ability, support_function_call=support_function_call, + support_vision=support_vision, max_tokens=max_tokens, context_length=context_length ) \ No newline at end of file diff --git a/api/core/model_runtime/model_providers/yi/llm/_position.yaml b/api/core/model_runtime/model_providers/yi/llm/_position.yaml index 12838d670f..e876893b41 100644 --- a/api/core/model_runtime/model_providers/yi/llm/_position.yaml +++ b/api/core/model_runtime/model_providers/yi/llm/_position.yaml @@ -1,3 +1,9 @@ - yi-34b-chat-0205 - 
yi-34b-chat-200k - yi-vl-plus +- yi-large +- yi-medium +- yi-vision +- yi-medium-200k +- yi-spark +- yi-large-turbo diff --git a/api/core/model_runtime/model_providers/yi/llm/yi-large-turbo.yaml b/api/core/model_runtime/model_providers/yi/llm/yi-large-turbo.yaml new file mode 100644 index 0000000000..1d00eca2ca --- /dev/null +++ b/api/core/model_runtime/model_providers/yi/llm/yi-large-turbo.yaml @@ -0,0 +1,43 @@ +model: yi-large-turbo +label: + zh_Hans: yi-large-turbo + en_US: yi-large-turbo +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 16384 +parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 0.3 + min: 0.0 + max: 2.0 + help: + zh_Hans: 控制生成结果的多样性和随机性。数值越小,越严谨;数值越大,越发散。 + en_US: Control the diversity and randomness of generated results. The smaller the value, the more rigorous it is; the larger the value, the more divergent it is. + - name: max_tokens + use_template: max_tokens + type: int + default: 1024 + min: 1 + max: 16384 + help: + zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。 + en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. + - name: top_p + use_template: top_p + type: float + default: 0.9 + min: 0.01 + max: 1.00 + help: + zh_Hans: 控制生成结果的随机性。数值越小,随机性越弱;数值越大,随机性越强。一般而言,top_p 和 temperature 两个参数选择一个进行调整即可。 + en_US: Control the randomness of generated results. The smaller the value, the weaker the randomness; the larger the value, the stronger the randomness. Generally speaking, you can adjust one of the two parameters top_p and temperature. +pricing: + input: '12' + output: '12' + unit: '0.000001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/yi/llm/yi-large.yaml b/api/core/model_runtime/model_providers/yi/llm/yi-large.yaml new file mode 100644 index 0000000000..347f511280 --- /dev/null +++ b/api/core/model_runtime/model_providers/yi/llm/yi-large.yaml @@ -0,0 +1,43 @@ +model: yi-large +label: + zh_Hans: yi-large + en_US: yi-large +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 16384 +parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 0.3 + min: 0.0 + max: 2.0 + help: + zh_Hans: 控制生成结果的多样性和随机性。数值越小,越严谨;数值越大,越发散。 + en_US: Control the diversity and randomness of generated results. The smaller the value, the more rigorous it is; the larger the value, the more divergent it is. + - name: max_tokens + use_template: max_tokens + type: int + default: 1024 + min: 1 + max: 16384 + help: + zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。 + en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. + - name: top_p + use_template: top_p + type: float + default: 0.9 + min: 0.01 + max: 1.00 + help: + zh_Hans: 控制生成结果的随机性。数值越小,随机性越弱;数值越大,随机性越强。一般而言,top_p 和 temperature 两个参数选择一个进行调整即可。 + en_US: Control the randomness of generated results. The smaller the value, the weaker the randomness; the larger the value, the stronger the randomness. Generally speaking, you can adjust one of the two parameters top_p and temperature. 
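The `pricing` block that follows (and its siblings in the other Yi manifests) quotes per-token prices scaled by `unit`; with `unit: '0.000001'` the `input`/`output` figures are effectively per-million-token prices in RMB. A sketch of the arithmetic, assuming the usual `tokens × unit × price` convention (the authoritative formula lives in the model runtime's pricing code, not in this diff):

```python
# Illustrative only: how a unit-scaled price quote turns into a charge.
from decimal import Decimal

tokens = 1500               # prompt tokens consumed
unit = Decimal('0.000001')  # from the pricing block below
input_price = Decimal('20') # yi-large input price, RMB per (token * unit)

total = Decimal(tokens) * unit * input_price
print(total)  # 0.030000 RMB
```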
+pricing: + input: '20' + output: '20' + unit: '0.000001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/yi/llm/yi-medium-200k.yaml b/api/core/model_runtime/model_providers/yi/llm/yi-medium-200k.yaml new file mode 100644 index 0000000000..e8ddbcba97 --- /dev/null +++ b/api/core/model_runtime/model_providers/yi/llm/yi-medium-200k.yaml @@ -0,0 +1,43 @@ +model: yi-medium-200k +label: + zh_Hans: yi-medium-200k + en_US: yi-medium-200k +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 204800 +parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 0.3 + min: 0.0 + max: 2.0 + help: + zh_Hans: 控制生成结果的多样性和随机性。数值越小,越严谨;数值越大,越发散。 + en_US: Control the diversity and randomness of generated results. The smaller the value, the more rigorous it is; the larger the value, the more divergent it is. + - name: max_tokens + use_template: max_tokens + type: int + default: 1024 + min: 1 + max: 204800 + help: + zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。 + en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. + - name: top_p + use_template: top_p + type: float + default: 0.9 + min: 0.01 + max: 1.00 + help: + zh_Hans: 控制生成结果的随机性。数值越小,随机性越弱;数值越大,随机性越强。一般而言,top_p 和 temperature 两个参数选择一个进行调整即可。 + en_US: Control the randomness of generated results. The smaller the value, the weaker the randomness; the larger the value, the stronger the randomness. Generally speaking, you can adjust one of the two parameters top_p and temperature. +pricing: + input: '12' + output: '12' + unit: '0.000001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/yi/llm/yi-medium.yaml b/api/core/model_runtime/model_providers/yi/llm/yi-medium.yaml new file mode 100644 index 0000000000..4f0244d1f5 --- /dev/null +++ b/api/core/model_runtime/model_providers/yi/llm/yi-medium.yaml @@ -0,0 +1,43 @@ +model: yi-medium +label: + zh_Hans: yi-medium + en_US: yi-medium +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 16384 +parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 0.3 + min: 0.0 + max: 2.0 + help: + zh_Hans: 控制生成结果的多样性和随机性。数值越小,越严谨;数值越大,越发散。 + en_US: Control the diversity and randomness of generated results. The smaller the value, the more rigorous it is; the larger the value, the more divergent it is. + - name: max_tokens + use_template: max_tokens + type: int + default: 1024 + min: 1 + max: 16384 + help: + zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。 + en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. + - name: top_p + use_template: top_p + type: float + default: 0.9 + min: 0.01 + max: 1.00 + help: + zh_Hans: 控制生成结果的随机性。数值越小,随机性越弱;数值越大,随机性越强。一般而言,top_p 和 temperature 两个参数选择一个进行调整即可。 + en_US: Control the randomness of generated results. The smaller the value, the weaker the randomness; the larger the value, the stronger the randomness. Generally speaking, you can adjust one of the two parameters top_p and temperature. 
+pricing: + input: '2.5' + output: '2.5' + unit: '0.000001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/yi/llm/yi-spark.yaml b/api/core/model_runtime/model_providers/yi/llm/yi-spark.yaml new file mode 100644 index 0000000000..e28e9fd8c0 --- /dev/null +++ b/api/core/model_runtime/model_providers/yi/llm/yi-spark.yaml @@ -0,0 +1,43 @@ +model: yi-spark +label: + zh_Hans: yi-spark + en_US: yi-spark +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 16384 +parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 0.3 + min: 0.0 + max: 2.0 + help: + zh_Hans: 控制生成结果的多样性和随机性。数值越小,越严谨;数值越大,越发散。 + en_US: Control the diversity and randomness of generated results. The smaller the value, the more rigorous it is; the larger the value, the more divergent it is. + - name: max_tokens + use_template: max_tokens + type: int + default: 1024 + min: 1 + max: 16384 + help: + zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。 + en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. + - name: top_p + use_template: top_p + type: float + default: 0.9 + min: 0.01 + max: 1.00 + help: + zh_Hans: 控制生成结果的随机性。数值越小,随机性越弱;数值越大,随机性越强。一般而言,top_p 和 temperature 两个参数选择一个进行调整即可。 + en_US: Control the randomness of generated results. The smaller the value, the weaker the randomness; the larger the value, the stronger the randomness. Generally speaking, you can adjust one of the two parameters top_p and temperature. +pricing: + input: '1' + output: '1' + unit: '0.000001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/yi/llm/yi-vision.yaml b/api/core/model_runtime/model_providers/yi/llm/yi-vision.yaml new file mode 100644 index 0000000000..bce34f5836 --- /dev/null +++ b/api/core/model_runtime/model_providers/yi/llm/yi-vision.yaml @@ -0,0 +1,44 @@ +model: yi-vision +label: + zh_Hans: yi-vision + en_US: yi-vision +model_type: llm +features: + - agent-thought + - vision +model_properties: + mode: chat + context_size: 4096 +parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 0.3 + min: 0.0 + max: 2.0 + help: + zh_Hans: 控制生成结果的多样性和随机性。数值越小,越严谨;数值越大,越发散。 + en_US: Control the diversity and randomness of generated results. The smaller the value, the more rigorous it is; the larger the value, the more divergent it is. + - name: max_tokens + use_template: max_tokens + type: int + default: 1024 + min: 1 + max: 4096 + help: + zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。 + en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. + - name: top_p + use_template: top_p + type: float + default: 0.9 + min: 0.01 + max: 1.00 + help: + zh_Hans: 控制生成结果的随机性。数值越小,随机性越弱;数值越大,随机性越强。一般而言,top_p 和 temperature 两个参数选择一个进行调整即可。 + en_US: Control the randomness of generated results. The smaller the value, the weaker the randomness; the larger the value, the stronger the randomness. Generally speaking, you can adjust one of the two parameters top_p and temperature. 
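`yi-vision` declares the `vision` feature; vision-capable models receive multimodal user content, which the xinference converter earlier in this diff fans out into OpenAI-style `text` / `image_url` sub-messages. A sketch of such a message using the entity classes this diff already imports from `core.model_runtime.entities.message_entities` (the `DETAIL.LOW` enum member is an assumption; the converter only relies on `.detail.value`):

```python
# Sketch of a multimodal user message; DETAIL.LOW is an assumed enum member.
from core.model_runtime.entities.message_entities import (
    ImagePromptMessageContent,
    TextPromptMessageContent,
    UserPromptMessage,
)

message = UserPromptMessage(content=[
    TextPromptMessageContent(data='Describe this image.'),
    ImagePromptMessageContent(
        data='data:image/png;base64,...',             # base64 data URL or remote URL
        detail=ImagePromptMessageContent.DETAIL.LOW,  # assumed member; read as .detail.value
    ),
])
```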
+pricing: + input: '6' + output: '6' + unit: '0.000001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/yi/yi.yaml b/api/core/model_runtime/model_providers/yi/yi.yaml index a8c0d857b6..de741afb10 100644 --- a/api/core/model_runtime/model_providers/yi/yi.yaml +++ b/api/core/model_runtime/model_providers/yi/yi.yaml @@ -33,7 +33,7 @@ provider_credential_schema: - variable: endpoint_url label: zh_Hans: 自定义 API endpoint 地址 - en_US: CUstom API endpoint URL + en_US: Custom API endpoint URL type: text-input required: false placeholder: diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py index 29b516ac02..22420fea2c 100644 --- a/api/core/prompt/advanced_prompt_transform.py +++ b/api/core/prompt/advanced_prompt_transform.py @@ -2,6 +2,7 @@ from typing import Optional, Union from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.file.file_obj import FileVar +from core.helper.code_executor.jinja2.jinja2_formatter import Jinja2Formatter from core.memory.token_buffer_memory import TokenBufferMemory from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, @@ -80,29 +81,35 @@ class AdvancedPromptTransform(PromptTransform): prompt_messages = [] - prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl) - prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} + if prompt_template.edition_type == 'basic' or not prompt_template.edition_type: + prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl) + prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} - prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs) + prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs) - if memory and memory_config: - role_prefix = memory_config.role_prefix - prompt_inputs = self._set_histories_variable( - memory=memory, - memory_config=memory_config, - raw_prompt=raw_prompt, - role_prefix=role_prefix, - prompt_template=prompt_template, - prompt_inputs=prompt_inputs, - model_config=model_config + if memory and memory_config: + role_prefix = memory_config.role_prefix + prompt_inputs = self._set_histories_variable( + memory=memory, + memory_config=memory_config, + raw_prompt=raw_prompt, + role_prefix=role_prefix, + prompt_template=prompt_template, + prompt_inputs=prompt_inputs, + model_config=model_config + ) + + if query: + prompt_inputs = self._set_query_variable(query, prompt_template, prompt_inputs) + + prompt = prompt_template.format( + prompt_inputs ) + else: + prompt = raw_prompt + prompt_inputs = inputs - if query: - prompt_inputs = self._set_query_variable(query, prompt_template, prompt_inputs) - - prompt = prompt_template.format( - prompt_inputs - ) + prompt = Jinja2Formatter.format(prompt, prompt_inputs) if files: prompt_message_contents = [TextPromptMessageContent(data=prompt)] @@ -135,14 +142,22 @@ class AdvancedPromptTransform(PromptTransform): for prompt_item in raw_prompt_list: raw_prompt = prompt_item.text - prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl) - prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} + if prompt_item.edition_type == 'basic' or not prompt_item.edition_type: + prompt_template = PromptTemplateParser(template=raw_prompt, 
with_variable_tmpl=self.with_variable_tmpl)
+                prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}

-            prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs)
+                prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs)

-            prompt = prompt_template.format(
-                prompt_inputs
-            )
+                prompt = prompt_template.format(
+                    prompt_inputs
+                )
+            elif prompt_item.edition_type == 'jinja2':
+                prompt = raw_prompt
+                prompt_inputs = inputs
+
+                prompt = Jinja2Formatter.format(prompt, prompt_inputs)
+            else:
+                raise ValueError(f'Invalid edition type: {prompt_item.edition_type}')

             if prompt_item.role == PromptMessageRole.USER:
                 prompt_messages.append(UserPromptMessage(content=prompt))
diff --git a/api/core/prompt/agent_history_prompt_transform.py b/api/core/prompt/agent_history_prompt_transform.py
new file mode 100644
index 0000000000..af0075ea91
--- /dev/null
+++ b/api/core/prompt/agent_history_prompt_transform.py
@@ -0,0 +1,82 @@
+from typing import Optional, cast
+
+from core.app.entities.app_invoke_entities import (
+    ModelConfigWithCredentialsEntity,
+)
+from core.memory.token_buffer_memory import TokenBufferMemory
+from core.model_runtime.entities.message_entities import (
+    PromptMessage,
+    SystemPromptMessage,
+    UserPromptMessage,
+)
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from core.prompt.prompt_transform import PromptTransform
+
+
+class AgentHistoryPromptTransform(PromptTransform):
+    """
+    History Prompt Transform for Agent App
+    """
+    def __init__(self,
+                 model_config: ModelConfigWithCredentialsEntity,
+                 prompt_messages: list[PromptMessage],
+                 history_messages: list[PromptMessage],
+                 memory: Optional[TokenBufferMemory] = None,
+                 ):
+        self.model_config = model_config
+        self.prompt_messages = prompt_messages
+        self.history_messages = history_messages
+        self.memory = memory
+
+    def get_prompt(self) -> list[PromptMessage]:
+        prompt_messages = []
+        num_system = 0
+        for prompt_message in self.history_messages:
+            if isinstance(prompt_message, SystemPromptMessage):
+                prompt_messages.append(prompt_message)
+                num_system += 1
+
+        if not self.memory:
+            return prompt_messages
+
+        max_token_limit = self._calculate_rest_token(self.prompt_messages, self.model_config)
+
+        model_type_instance = self.model_config.provider_model_bundle.model_type_instance
+        model_type_instance = cast(LargeLanguageModel, model_type_instance)
+
+        curr_message_tokens = model_type_instance.get_num_tokens(
+            self.memory.model_instance.model,
+            self.memory.model_instance.credentials,
+            self.history_messages
+        )
+        if curr_message_tokens <= max_token_limit:
+            return self.history_messages
+
+        # number of prompts appended for the current conversation turn
+        num_prompt = 0
+        # walk the history in reverse (newest message first)
+        for prompt_message in self.history_messages[::-1]:
+            if isinstance(prompt_message, SystemPromptMessage):
+                continue
+            prompt_messages.append(prompt_message)
+            num_prompt += 1
+            # a conversation turn starts with a UserPromptMessage
+            if isinstance(prompt_message, UserPromptMessage):
+                curr_message_tokens = model_type_instance.get_num_tokens(
+                    self.memory.model_instance.model,
+                    self.memory.model_instance.credentials,
+                    prompt_messages
+                )
+                # if the token count overflows the limit, drop the whole current turn and stop
+                if curr_message_tokens > max_token_limit:
+                    prompt_messages = prompt_messages[:-num_prompt]
+                    break
+                num_prompt = 0
+        # restore ascending (chronological) order before returning
+        message_prompts =
prompt_messages[num_system:] + message_prompts.reverse() + + # merge system and message prompt + prompt_messages = prompt_messages[:num_system] + prompt_messages.extend(message_prompts) + return prompt_messages diff --git a/api/core/prompt/entities/advanced_prompt_entities.py b/api/core/prompt/entities/advanced_prompt_entities.py index 2be00bdf0e..23a8602bea 100644 --- a/api/core/prompt/entities/advanced_prompt_entities.py +++ b/api/core/prompt/entities/advanced_prompt_entities.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Literal, Optional from pydantic import BaseModel @@ -11,6 +11,7 @@ class ChatModelMessage(BaseModel): """ text: str role: PromptMessageRole + edition_type: Optional[Literal['basic', 'jinja2']] class CompletionModelPromptTemplate(BaseModel): @@ -18,6 +19,7 @@ class CompletionModelPromptTemplate(BaseModel): Completion Model Prompt Template. """ text: str + edition_type: Optional[Literal['basic', 'jinja2']] class MemoryConfig(BaseModel): diff --git a/api/core/prompt/prompt_transform.py b/api/core/prompt/prompt_transform.py index 9bf2ae090f..d8e2d2f76d 100644 --- a/api/core/prompt/prompt_transform.py +++ b/api/core/prompt/prompt_transform.py @@ -1,10 +1,10 @@ -from typing import Optional, cast +from typing import Optional from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.memory.token_buffer_memory import TokenBufferMemory +from core.model_manager import ModelInstance from core.model_runtime.entities.message_entities import PromptMessage from core.model_runtime.entities.model_entities import ModelPropertyKey -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.prompt.entities.advanced_prompt_entities import MemoryConfig @@ -25,12 +25,12 @@ class PromptTransform: model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) if model_context_tokens: - model_type_instance = model_config.provider_model_bundle.model_type_instance - model_type_instance = cast(LargeLanguageModel, model_type_instance) + model_instance = ModelInstance( + provider_model_bundle=model_config.provider_model_bundle, + model=model_config.model + ) - curr_message_tokens = model_type_instance.get_num_tokens( - model_config.model, - model_config.credentials, + curr_message_tokens = model_instance.get_llm_num_tokens( prompt_messages ) diff --git a/api/core/prompt/utils/prompt_message_util.py b/api/core/prompt/utils/prompt_message_util.py index 5fceeb3595..befdceeda5 100644 --- a/api/core/prompt/utils/prompt_message_util.py +++ b/api/core/prompt/utils/prompt_message_util.py @@ -1,6 +1,7 @@ from typing import cast from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, ImagePromptMessageContent, PromptMessage, PromptMessageContentType, @@ -21,13 +22,25 @@ class PromptMessageUtil: """ prompts = [] if model_mode == ModelMode.CHAT.value: + tool_calls = [] for prompt_message in prompt_messages: if prompt_message.role == PromptMessageRole.USER: role = 'user' elif prompt_message.role == PromptMessageRole.ASSISTANT: role = 'assistant' + if isinstance(prompt_message, AssistantPromptMessage): + tool_calls = [{ + 'id': tool_call.id, + 'type': 'function', + 'function': { + 'name': tool_call.function.name, + 'arguments': tool_call.function.arguments, + } + } for tool_call in prompt_message.tool_calls] elif prompt_message.role == PromptMessageRole.SYSTEM: role = 'system' + elif prompt_message.role == PromptMessageRole.TOOL: + role = 'tool' 
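                # For reference (illustrative, values elided): with the change above,
                # an assistant turn that invoked a tool is saved roughly as
                #   {"role": "assistant", "text": "...", "files": [],
                #    "tool_calls": [{"id": "...", "type": "function",
                #                    "function": {"name": "...", "arguments": "..."}}]}
                # and the tool's response is saved as a separate entry with role 'tool'.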
else: continue @@ -48,11 +61,16 @@ class PromptMessageUtil: else: text = prompt_message.content - prompts.append({ + prompt = { "role": role, "text": text, "files": files - }) + } + + if tool_calls: + prompt['tool_calls'] = tool_calls + + prompts.append(prompt) else: prompt_message = prompt_messages[0] text = '' diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 0db84d3b69..c9447a79df 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -11,6 +11,8 @@ from core.entities.provider_entities import ( CustomConfiguration, CustomModelConfiguration, CustomProviderConfiguration, + ModelLoadBalancingConfiguration, + ModelSettings, QuotaConfiguration, SystemConfiguration, ) @@ -25,14 +27,18 @@ from core.model_runtime.entities.provider_entities import ( from core.model_runtime.model_providers import model_provider_factory from extensions import ext_hosting_provider from extensions.ext_database import db +from extensions.ext_redis import redis_client from models.provider import ( + LoadBalancingModelConfig, Provider, ProviderModel, + ProviderModelSetting, ProviderQuotaType, ProviderType, TenantDefaultModel, TenantPreferredModelProvider, ) +from services.feature_service import FeatureService class ProviderManager: @@ -98,6 +104,13 @@ class ProviderManager: # Get All preferred provider types of the workspace provider_name_to_preferred_model_provider_records_dict = self._get_all_preferred_model_providers(tenant_id) + # Get All provider model settings + provider_name_to_provider_model_settings_dict = self._get_all_provider_model_settings(tenant_id) + + # Get All load balancing configs + provider_name_to_provider_load_balancing_model_configs_dict \ + = self._get_all_provider_load_balancing_configs(tenant_id) + provider_configurations = ProviderConfigurations( tenant_id=tenant_id ) @@ -105,14 +118,8 @@ class ProviderManager: # Construct ProviderConfiguration objects for each provider for provider_entity in provider_entities: provider_name = provider_entity.provider - - provider_records = provider_name_to_provider_records_dict.get(provider_entity.provider) - if not provider_records: - provider_records = [] - - provider_model_records = provider_name_to_provider_model_records_dict.get(provider_entity.provider) - if not provider_model_records: - provider_model_records = [] + provider_records = provider_name_to_provider_records_dict.get(provider_entity.provider, []) + provider_model_records = provider_name_to_provider_model_records_dict.get(provider_entity.provider, []) # Convert to custom configuration custom_configuration = self._to_custom_configuration( @@ -134,38 +141,38 @@ class ProviderManager: if preferred_provider_type_record: preferred_provider_type = ProviderType.value_of(preferred_provider_type_record.preferred_provider_type) + elif custom_configuration.provider or custom_configuration.models: + preferred_provider_type = ProviderType.CUSTOM + elif system_configuration.enabled: + preferred_provider_type = ProviderType.SYSTEM else: - if custom_configuration.provider or custom_configuration.models: - preferred_provider_type = ProviderType.CUSTOM - elif system_configuration.enabled: - preferred_provider_type = ProviderType.SYSTEM - else: - preferred_provider_type = ProviderType.CUSTOM + preferred_provider_type = ProviderType.CUSTOM using_provider_type = preferred_provider_type + has_valid_quota = any(quota_conf.is_valid for quota_conf in system_configuration.quota_configurations) + if preferred_provider_type == ProviderType.SYSTEM: - if not 
system_configuration.enabled:
+                if not system_configuration.enabled or not has_valid_quota:
                     using_provider_type = ProviderType.CUSTOM
-
-                has_valid_quota = False
-                for quota_configuration in system_configuration.quota_configurations:
-                    if quota_configuration.is_valid:
-                        has_valid_quota = True
-                        break
-
-                if not has_valid_quota:
-                    using_provider_type = ProviderType.CUSTOM
             else:
                 if not custom_configuration.provider and not custom_configuration.models:
-                    if system_configuration.enabled:
-                        has_valid_quota = False
-                        for quota_configuration in system_configuration.quota_configurations:
-                            if quota_configuration.is_valid:
-                                has_valid_quota = True
-                                break
+                    if system_configuration.enabled and has_valid_quota:
+                        using_provider_type = ProviderType.SYSTEM

-                    if has_valid_quota:
-                        using_provider_type = ProviderType.SYSTEM
+            # Get provider model settings
+            provider_model_settings = provider_name_to_provider_model_settings_dict.get(provider_name)
+
+            # Get provider load balancing configs
+            provider_load_balancing_configs \
+                = provider_name_to_provider_load_balancing_model_configs_dict.get(provider_name)
+
+            # Convert to model settings
+            model_settings = self._to_model_settings(
+                provider_entity=provider_entity,
+                provider_model_settings=provider_model_settings,
+                load_balancing_model_configs=provider_load_balancing_configs
+            )

             provider_configuration = ProviderConfiguration(
@@ -173,7 +180,8 @@ class ProviderManager:
                 preferred_provider_type=preferred_provider_type,
                 using_provider_type=using_provider_type,
                 system_configuration=system_configuration,
-                custom_configuration=custom_configuration
+                custom_configuration=custom_configuration,
+                model_settings=model_settings
             )

             provider_configurations[provider_name] = provider_configuration
@@ -233,30 +241,17 @@ class ProviderManager:
             )

             if available_models:
-                found = False
-                for available_model in available_models:
-                    if available_model.model == "gpt-4":
-                        default_model = TenantDefaultModel(
-                            tenant_id=tenant_id,
-                            model_type=model_type.to_origin_model_type(),
-                            provider_name=available_model.provider.provider,
-                            model_name=available_model.model
-                        )
-                        db.session.add(default_model)
-                        db.session.commit()
-                        found = True
-                        break
+                available_model = next((model for model in available_models if model.model == "gpt-4"),
+                                       available_models[0])

-                if not found:
-                    available_model = available_models[0]
-                    default_model = TenantDefaultModel(
-                        tenant_id=tenant_id,
-                        model_type=model_type.to_origin_model_type(),
-                        provider_name=available_model.provider.provider,
-                        model_name=available_model.model
-                    )
-                    db.session.add(default_model)
-                    db.session.commit()
+                default_model = TenantDefaultModel(
+                    tenant_id=tenant_id,
+                    model_type=model_type.to_origin_model_type(),
+                    provider_name=available_model.provider.provider,
+                    model_name=available_model.model
+                )
+                db.session.add(default_model)
+                db.session.commit()

             if not default_model:
                 return None
@@ -371,7 +366,7 @@ class ProviderManager:
         """
         Get All preferred provider types of the workspace.

-        :param tenant_id:
+        :param tenant_id: workspace id
         :return:
         """
         preferred_provider_types = db.session.query(TenantPreferredModelProvider) \
@@ -386,6 +381,56 @@ class ProviderManager:

         return provider_name_to_preferred_provider_type_records_dict

+    def _get_all_provider_model_settings(self, tenant_id: str) -> dict[str, list[ProviderModelSetting]]:
+        """
+        Get All provider model settings of the workspace.
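        Illustrative return shape (rows grouped by provider name; model names
        are made up here):

            {
                "openai": [ProviderModelSetting(model_name="gpt-4", ...)],
                "anthropic": [ProviderModelSetting(model_name="claude-3", ...)],
            }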
+ + :param tenant_id: workspace id + :return: + """ + provider_model_settings = db.session.query(ProviderModelSetting) \ + .filter( + ProviderModelSetting.tenant_id == tenant_id + ).all() + + provider_name_to_provider_model_settings_dict = defaultdict(list) + for provider_model_setting in provider_model_settings: + (provider_name_to_provider_model_settings_dict[provider_model_setting.provider_name] + .append(provider_model_setting)) + + return provider_name_to_provider_model_settings_dict + + def _get_all_provider_load_balancing_configs(self, tenant_id: str) -> dict[str, list[LoadBalancingModelConfig]]: + """ + Get All provider load balancing configs of the workspace. + + :param tenant_id: workspace id + :return: + """ + cache_key = f"tenant:{tenant_id}:model_load_balancing_enabled" + cache_result = redis_client.get(cache_key) + if cache_result is None: + model_load_balancing_enabled = FeatureService.get_features(tenant_id).model_load_balancing_enabled + redis_client.setex(cache_key, 120, str(model_load_balancing_enabled)) + else: + cache_result = cache_result.decode('utf-8') + model_load_balancing_enabled = cache_result == 'True' + + if not model_load_balancing_enabled: + return dict() + + provider_load_balancing_configs = db.session.query(LoadBalancingModelConfig) \ + .filter( + LoadBalancingModelConfig.tenant_id == tenant_id + ).all() + + provider_name_to_provider_load_balancing_model_configs_dict = defaultdict(list) + for provider_load_balancing_config in provider_load_balancing_configs: + (provider_name_to_provider_load_balancing_model_configs_dict[provider_load_balancing_config.provider_name] + .append(provider_load_balancing_config)) + + return provider_name_to_provider_load_balancing_model_configs_dict + def _init_trial_provider_records(self, tenant_id: str, provider_name_to_provider_records_dict: dict[str, list]) -> dict[str, list]: """ @@ -759,3 +804,97 @@ class ProviderManager: secret_input_form_variables.append(credential_form_schema.variable) return secret_input_form_variables + + def _to_model_settings(self, provider_entity: ProviderEntity, + provider_model_settings: Optional[list[ProviderModelSetting]] = None, + load_balancing_model_configs: Optional[list[LoadBalancingModelConfig]] = None) \ + -> list[ModelSettings]: + """ + Convert to model settings. 
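        Illustrative example: a ProviderModelSetting row for ("gpt-4", "llm")
        with load balancing enabled and two enabled LoadBalancingModelConfig
        rows maps to

            ModelSettings(model="gpt-4", model_type=ModelType.LLM, enabled=True,
                          load_balancing_configs=[<config 1>, <config 2>])

        Note that load_balancing_configs is only kept when more than one usable
        config exists; with a single candidate the list comes back empty, since
        there is nothing to balance across.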
+ + :param provider_model_settings: provider model settings include enabled, load balancing enabled + :param load_balancing_model_configs: load balancing model configs + :return: + """ + # Get provider model credential secret variables + model_credential_secret_variables = self._extract_secret_variables( + provider_entity.model_credential_schema.credential_form_schemas + if provider_entity.model_credential_schema else [] + ) + + model_settings = [] + if not provider_model_settings: + return model_settings + + for provider_model_setting in provider_model_settings: + load_balancing_configs = [] + if provider_model_setting.load_balancing_enabled and load_balancing_model_configs: + for load_balancing_model_config in load_balancing_model_configs: + if (load_balancing_model_config.model_name == provider_model_setting.model_name + and load_balancing_model_config.model_type == provider_model_setting.model_type): + if not load_balancing_model_config.enabled: + continue + + if not load_balancing_model_config.encrypted_config: + if load_balancing_model_config.name == "__inherit__": + load_balancing_configs.append(ModelLoadBalancingConfiguration( + id=load_balancing_model_config.id, + name=load_balancing_model_config.name, + credentials={} + )) + continue + + provider_model_credentials_cache = ProviderCredentialsCache( + tenant_id=load_balancing_model_config.tenant_id, + identity_id=load_balancing_model_config.id, + cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL + ) + + # Get cached provider model credentials + cached_provider_model_credentials = provider_model_credentials_cache.get() + + if not cached_provider_model_credentials: + try: + provider_model_credentials = json.loads(load_balancing_model_config.encrypted_config) + except JSONDecodeError: + continue + + # Get decoding rsa key and cipher for decrypting credentials + if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None: + self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding( + load_balancing_model_config.tenant_id) + + for variable in model_credential_secret_variables: + if variable in provider_model_credentials: + try: + provider_model_credentials[variable] = encrypter.decrypt_token_with_decoding( + provider_model_credentials.get(variable), + self.decoding_rsa_key, + self.decoding_cipher_rsa + ) + except ValueError: + pass + + # cache provider model credentials + provider_model_credentials_cache.set( + credentials=provider_model_credentials + ) + else: + provider_model_credentials = cached_provider_model_credentials + + load_balancing_configs.append(ModelLoadBalancingConfiguration( + id=load_balancing_model_config.id, + name=load_balancing_model_config.name, + credentials=provider_model_credentials + )) + + model_settings.append( + ModelSettings( + model=provider_model_setting.model_name, + model_type=ModelType.value_of(provider_model_setting.model_type), + enabled=provider_model_setting.enabled, + load_balancing_configs=load_balancing_configs if len(load_balancing_configs) > 1 else [] + ) + ) + + return model_settings diff --git a/api/core/rag/data_post_processor/data_post_processor.py b/api/core/rag/data_post_processor/data_post_processor.py index bdd69c27b1..a0f2947784 100644 --- a/api/core/rag/data_post_processor/data_post_processor.py +++ b/api/core/rag/data_post_processor/data_post_processor.py @@ -5,7 +5,7 @@ from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.invoke import InvokeAuthorizationError from 
core.rag.data_post_processor.reorder import ReorderRunner from core.rag.models.document import Document -from core.rerank.rerank import RerankRunner +from core.rag.rerank.rerank import RerankRunner class DataPostProcessor: diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index 0f9c753056..dd74406f30 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -33,6 +33,7 @@ class RetrievalService: return [] all_documents = [] threads = [] + exceptions = [] # retrieval_model source with keyword if retrival_method == 'keyword_search': keyword_thread = threading.Thread(target=RetrievalService.keyword_search, kwargs={ @@ -40,7 +41,8 @@ class RetrievalService: 'dataset_id': dataset_id, 'query': query, 'top_k': top_k, - 'all_documents': all_documents + 'all_documents': all_documents, + 'exceptions': exceptions, }) threads.append(keyword_thread) keyword_thread.start() @@ -54,7 +56,8 @@ class RetrievalService: 'score_threshold': score_threshold, 'reranking_model': reranking_model, 'all_documents': all_documents, - 'retrival_method': retrival_method + 'retrival_method': retrival_method, + 'exceptions': exceptions, }) threads.append(embedding_thread) embedding_thread.start() @@ -69,7 +72,8 @@ class RetrievalService: 'score_threshold': score_threshold, 'top_k': top_k, 'reranking_model': reranking_model, - 'all_documents': all_documents + 'all_documents': all_documents, + 'exceptions': exceptions, }) threads.append(full_text_index_thread) full_text_index_thread.start() @@ -77,6 +81,10 @@ class RetrievalService: for thread in threads: thread.join() + if exceptions: + exception_message = ';\n'.join(exceptions) + raise Exception(exception_message) + if retrival_method == 'hybrid_search': data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False) all_documents = data_post_processor.invoke( @@ -89,82 +97,91 @@ class RetrievalService: @classmethod def keyword_search(cls, flask_app: Flask, dataset_id: str, query: str, - top_k: int, all_documents: list): + top_k: int, all_documents: list, exceptions: list): with flask_app.app_context(): - dataset = db.session.query(Dataset).filter( - Dataset.id == dataset_id - ).first() + try: + dataset = db.session.query(Dataset).filter( + Dataset.id == dataset_id + ).first() - keyword = Keyword( - dataset=dataset - ) + keyword = Keyword( + dataset=dataset + ) - documents = keyword.search( - query, - top_k=top_k - ) - all_documents.extend(documents) + documents = keyword.search( + query, + top_k=top_k + ) + all_documents.extend(documents) + except Exception as e: + exceptions.append(str(e)) @classmethod def embedding_search(cls, flask_app: Flask, dataset_id: str, query: str, top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict], - all_documents: list, retrival_method: str): + all_documents: list, retrival_method: str, exceptions: list): with flask_app.app_context(): - dataset = db.session.query(Dataset).filter( - Dataset.id == dataset_id - ).first() + try: + dataset = db.session.query(Dataset).filter( + Dataset.id == dataset_id + ).first() - vector = Vector( - dataset=dataset - ) + vector = Vector( + dataset=dataset + ) - documents = vector.search_by_vector( - query, - search_type='similarity_score_threshold', - top_k=top_k, - score_threshold=score_threshold, - filter={ - 'group_id': [dataset.id] - } - ) + documents = vector.search_by_vector( + query, + search_type='similarity_score_threshold', + top_k=top_k, + 
score_threshold=score_threshold, + filter={ + 'group_id': [dataset.id] + } + ) - if documents: - if reranking_model and retrival_method == 'semantic_search': - data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False) - all_documents.extend(data_post_processor.invoke( - query=query, - documents=documents, - score_threshold=score_threshold, - top_n=len(documents) - )) - else: - all_documents.extend(documents) + if documents: + if reranking_model and retrival_method == 'semantic_search': + data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False) + all_documents.extend(data_post_processor.invoke( + query=query, + documents=documents, + score_threshold=score_threshold, + top_n=len(documents) + )) + else: + all_documents.extend(documents) + except Exception as e: + exceptions.append(str(e)) @classmethod def full_text_index_search(cls, flask_app: Flask, dataset_id: str, query: str, top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict], - all_documents: list, retrival_method: str): + all_documents: list, retrival_method: str, exceptions: list): with flask_app.app_context(): - dataset = db.session.query(Dataset).filter( - Dataset.id == dataset_id - ).first() + try: + dataset = db.session.query(Dataset).filter( + Dataset.id == dataset_id + ).first() - vector_processor = Vector( - dataset=dataset, - ) + vector_processor = Vector( + dataset=dataset, + ) - documents = vector_processor.search_by_full_text( - query, - top_k=top_k - ) - if documents: - if reranking_model and retrival_method == 'full_text_search': - data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False) - all_documents.extend(data_post_processor.invoke( - query=query, - documents=documents, - score_threshold=score_threshold, - top_n=len(documents) - )) - else: - all_documents.extend(documents) + documents = vector_processor.search_by_full_text( + query, + top_k=top_k + ) + if documents: + if reranking_model and retrival_method == 'full_text_search': + data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False) + all_documents.extend(data_post_processor.invoke( + query=query, + documents=documents, + score_threshold=score_threshold, + top_n=len(documents) + )) + else: + all_documents.extend(documents) + except Exception as e: + exceptions.append(str(e)) diff --git a/api/core/rag/datasource/vdb/field.py b/api/core/rag/datasource/vdb/field.py index dc400dafbb..1c16e4d9cd 100644 --- a/api/core/rag/datasource/vdb/field.py +++ b/api/core/rag/datasource/vdb/field.py @@ -8,3 +8,4 @@ class Field(Enum): VECTOR = "vector" TEXT_KEY = "text" PRIMARY_KEY = "id" + DOC_ID = "metadata.doc_id" diff --git a/api/core/rag/datasource/vdb/milvus/milvus_vector.py b/api/core/rag/datasource/vdb/milvus/milvus_vector.py index c90fe3b188..0586e279d3 100644 --- a/api/core/rag/datasource/vdb/milvus/milvus_vector.py +++ b/api/core/rag/datasource/vdb/milvus/milvus_vector.py @@ -259,5 +259,5 @@ class MilvusVector(BaseVector): uri = "https://" + str(config.host) + ":" + str(config.port) else: uri = "http://" + str(config.host) + ":" + str(config.port) - client = MilvusClient(uri=uri, user=config.user, password=config.password) + client = MilvusClient(uri=uri, user=config.user, password=config.password,db_name=config.database) return client diff --git a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py b/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py index 5735b79b6e..3842aee6c7 100644 --- 
a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py +++ b/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py @@ -54,7 +54,7 @@ class PGVectoRS(BaseVector): class _Table(CollectionORM): __tablename__ = collection_name - __table_args__ = {"extend_existing": True} # noqa: RUF012 + __table_args__ = {"extend_existing": True} id: Mapped[UUID] = mapped_column( postgresql.UUID(as_uuid=True), primary_key=True, diff --git a/api/core/rag/datasource/vdb/pgvector/__init__.py b/api/core/rag/datasource/vdb/pgvector/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/rag/datasource/vdb/pgvector/pgvector.py b/api/core/rag/datasource/vdb/pgvector/pgvector.py new file mode 100644 index 0000000000..22cf790bfa --- /dev/null +++ b/api/core/rag/datasource/vdb/pgvector/pgvector.py @@ -0,0 +1,169 @@ +import json +import uuid +from contextlib import contextmanager +from typing import Any + +import psycopg2.extras +import psycopg2.pool +from pydantic import BaseModel, root_validator + +from core.rag.datasource.vdb.vector_base import BaseVector +from core.rag.models.document import Document +from extensions.ext_redis import redis_client + + +class PGVectorConfig(BaseModel): + host: str + port: int + user: str + password: str + database: str + + @root_validator() + def validate_config(cls, values: dict) -> dict: + if not values["host"]: + raise ValueError("config PGVECTOR_HOST is required") + if not values["port"]: + raise ValueError("config PGVECTOR_PORT is required") + if not values["user"]: + raise ValueError("config PGVECTOR_USER is required") + if not values["password"]: + raise ValueError("config PGVECTOR_PASSWORD is required") + if not values["database"]: + raise ValueError("config PGVECTOR_DATABASE is required") + return values + + +SQL_CREATE_TABLE = """ +CREATE TABLE IF NOT EXISTS {table_name} ( + id UUID PRIMARY KEY, + text TEXT NOT NULL, + meta JSONB NOT NULL, + embedding vector({dimension}) NOT NULL +) using heap; +""" + + +class PGVector(BaseVector): + def __init__(self, collection_name: str, config: PGVectorConfig): + super().__init__(collection_name) + self.pool = self._create_connection_pool(config) + self.table_name = f"embedding_{collection_name}" + + def get_type(self) -> str: + return "pgvector" + + def _create_connection_pool(self, config: PGVectorConfig): + return psycopg2.pool.SimpleConnectionPool( + 1, + 5, + host=config.host, + port=config.port, + user=config.user, + password=config.password, + database=config.database, + ) + + @contextmanager + def _get_cursor(self): + conn = self.pool.getconn() + cur = conn.cursor() + try: + yield cur + finally: + cur.close() + conn.commit() + self.pool.putconn(conn) + + def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): + dimension = len(embeddings[0]) + self._create_collection(dimension) + return self.add_texts(texts, embeddings) + + def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): + values = [] + pks = [] + for i, doc in enumerate(documents): + doc_id = doc.metadata.get("doc_id", str(uuid.uuid4())) + pks.append(doc_id) + values.append( + ( + doc_id, + doc.page_content, + json.dumps(doc.metadata), + embeddings[i], + ) + ) + with self._get_cursor() as cur: + psycopg2.extras.execute_values( + cur, f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES %s", values + ) + return pks + + def text_exists(self, id: str) -> bool: + with self._get_cursor() as cur: + cur.execute(f"SELECT id FROM {self.table_name} WHERE id = %s", (id,)) + return 
cur.fetchone() is not None + + def get_by_ids(self, ids: list[str]) -> list[Document]: + with self._get_cursor() as cur: + cur.execute(f"SELECT meta, text FROM {self.table_name} WHERE id IN %s", (tuple(ids),)) + docs = [] + for record in cur: + docs.append(Document(page_content=record[1], metadata=record[0])) + return docs + + def delete_by_ids(self, ids: list[str]) -> None: + with self._get_cursor() as cur: + cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s", (tuple(ids),)) + + def delete_by_metadata_field(self, key: str, value: str) -> None: + with self._get_cursor() as cur: + cur.execute(f"DELETE FROM {self.table_name} WHERE meta->>%s = %s", (key, value)) + + def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: + """ + Search the nearest neighbors to a vector. + + :param query_vector: The input vector to search for similar items. + :param top_k: The number of nearest neighbors to return, default is 5. + :return: List of Documents that are nearest to the query vector. + """ + top_k = kwargs.get("top_k", 5) + + with self._get_cursor() as cur: + cur.execute( + f"SELECT meta, text, embedding <=> %s AS distance FROM {self.table_name} ORDER BY distance LIMIT {top_k}", + (json.dumps(query_vector),), + ) + docs = [] + score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0 + for record in cur: + metadata, text, distance = record + score = 1 - distance + metadata["score"] = score + if score > score_threshold: + docs.append(Document(page_content=text, metadata=metadata)) + return docs + + def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: + # do not support bm25 search + return [] + + def delete(self) -> None: + with self._get_cursor() as cur: + cur.execute(f"DROP TABLE IF EXISTS {self.table_name}") + + def _create_collection(self, dimension: int): + cache_key = f"vector_indexing_{self._collection_name}" + lock_name = f"{cache_key}_lock" + with redis_client.lock(lock_name, timeout=20): + collection_exist_cache_key = f"vector_indexing_{self._collection_name}" + if redis_client.get(collection_exist_cache_key): + return + + with self._get_cursor() as cur: + cur.execute("CREATE EXTENSION IF NOT EXISTS vector") + cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name, dimension=dimension)) + # TODO: create index https://github.com/pgvector/pgvector?tab=readme-ov-file#indexing + redis_client.set(collection_exist_cache_key, 1, ex=3600) diff --git a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py index e6e83c66d8..7a92314542 100644 --- a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py +++ b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py @@ -115,9 +115,12 @@ class QdrantVector(BaseVector): timeout=int(self._client_config.timeout), ) - # create payload index + # create group_id payload index self._client.create_payload_index(collection_name, Field.GROUP_KEY.value, field_schema=PayloadSchemaType.KEYWORD) + # create doc_id payload index + self._client.create_payload_index(collection_name, Field.DOC_ID.value, + field_schema=PayloadSchemaType.KEYWORD) # creat full text index text_index_params = TextIndexParams( type=TextIndexType.TEXT, diff --git a/api/core/rag/datasource/vdb/relyt/relyt_vector.py b/api/core/rag/datasource/vdb/relyt/relyt_vector.py index 74b91db27e..ee88d9fa29 100644 --- a/api/core/rag/datasource/vdb/relyt/relyt_vector.py +++ b/api/core/rag/datasource/vdb/relyt/relyt_vector.py @@ -190,7 +190,7 @@ class 
RelytVector(BaseVector): conn.execute(chunks_table.delete().where(delete_condition)) return True except Exception as e: - print("Delete operation failed:", str(e)) # noqa: T201 + print("Delete operation failed:", str(e)) return False def delete_by_metadata_field(self, key: str, value: str): diff --git a/api/core/rag/datasource/vdb/tidb_vector/__init__.py b/api/core/rag/datasource/vdb/tidb_vector/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py b/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py new file mode 100644 index 0000000000..107d17bb47 --- /dev/null +++ b/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py @@ -0,0 +1,216 @@ +import json +import logging +from typing import Any + +import sqlalchemy +from pydantic import BaseModel, root_validator +from sqlalchemy import JSON, TEXT, Column, DateTime, String, Table, create_engine, insert +from sqlalchemy import text as sql_text +from sqlalchemy.orm import Session, declarative_base + +from core.rag.datasource.vdb.vector_base import BaseVector +from core.rag.models.document import Document +from extensions.ext_redis import redis_client + +logger = logging.getLogger(__name__) + + +class TiDBVectorConfig(BaseModel): + host: str + port: int + user: str + password: str + database: str + + @root_validator() + def validate_config(cls, values: dict) -> dict: + if not values['host']: + raise ValueError("config TIDB_VECTOR_HOST is required") + if not values['port']: + raise ValueError("config TIDB_VECTOR_PORT is required") + if not values['user']: + raise ValueError("config TIDB_VECTOR_USER is required") + if not values['password']: + raise ValueError("config TIDB_VECTOR_PASSWORD is required") + if not values['database']: + raise ValueError("config TIDB_VECTOR_DATABASE is required") + return values + + +class TiDBVector(BaseVector): + + def _table(self, dim: int) -> Table: + from tidb_vector.sqlalchemy import VectorType + return Table( + self._collection_name, + self._orm_base.metadata, + Column('id', String(36), primary_key=True, nullable=False), + Column("vector", VectorType(dim), nullable=False, comment="" if self._distance_func is None else f"hnsw(distance={self._distance_func})"), + Column("text", TEXT, nullable=False), + Column("meta", JSON, nullable=False), + Column("create_time", DateTime, server_default=sqlalchemy.text("CURRENT_TIMESTAMP")), + Column("update_time", DateTime, server_default=sqlalchemy.text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP")), + extend_existing=True + ) + + def __init__(self, collection_name: str, config: TiDBVectorConfig, distance_func: str = 'cosine'): + super().__init__(collection_name) + self._client_config = config + self._url = (f"mysql+pymysql://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}?" 
+ f"ssl_verify_cert=true&ssl_verify_identity=true") + self._distance_func = distance_func.lower() + self._engine = create_engine(self._url) + self._orm_base = declarative_base() + self._dimension = 1536 + + def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): + logger.info("create collection and add texts, collection_name: " + self._collection_name) + self._create_collection(len(embeddings[0])) + self.add_texts(texts, embeddings) + self._dimension = len(embeddings[0]) + pass + + def _create_collection(self, dimension: int): + logger.info("_create_collection, collection_name " + self._collection_name) + lock_name = 'vector_indexing_lock_{}'.format(self._collection_name) + with redis_client.lock(lock_name, timeout=20): + collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name) + if redis_client.get(collection_exist_cache_key): + return + with Session(self._engine) as session: + session.begin() + create_statement = sql_text(f""" + CREATE TABLE IF NOT EXISTS {self._collection_name} ( + id CHAR(36) PRIMARY KEY, + text TEXT NOT NULL, + meta JSON NOT NULL, + doc_id VARCHAR(64) AS (JSON_UNQUOTE(JSON_EXTRACT(meta, '$.doc_id'))) STORED, + KEY (doc_id), + vector VECTOR({dimension}) NOT NULL COMMENT "hnsw(distance={self._distance_func})", + create_time DATETIME DEFAULT CURRENT_TIMESTAMP, + update_time DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP + ); + """) + session.execute(create_statement) + # tidb vector not support 'CREATE/ADD INDEX' now + session.commit() + redis_client.set(collection_exist_cache_key, 1, ex=3600) + + def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): + table = self._table(len(embeddings[0])) + ids = self._get_uuids(documents) + metas = [d.metadata for d in documents] + texts = [d.page_content for d in documents] + + chunks_table_data = [] + with self._engine.connect() as conn: + with conn.begin(): + for id, text, meta, embedding in zip( + ids, texts, metas, embeddings + ): + chunks_table_data.append({"id": id, "vector": embedding, "text": text, "meta": meta}) + + # Execute the batch insert when the batch size is reached + if len(chunks_table_data) == 500: + conn.execute(insert(table).values(chunks_table_data)) + # Clear the chunks_table_data list for the next batch + chunks_table_data.clear() + + # Insert any remaining records that didn't make up a full batch + if chunks_table_data: + conn.execute(insert(table).values(chunks_table_data)) + return ids + + def text_exists(self, id: str) -> bool: + result = self.get_ids_by_metadata_field('doc_id', id) + return bool(result) + + def delete_by_ids(self, ids: list[str]) -> None: + with Session(self._engine) as session: + ids_str = ','.join(f"'{doc_id}'" for doc_id in ids) + select_statement = sql_text( + f"""SELECT id FROM {self._collection_name} WHERE meta->>'$.doc_id' in ({ids_str}); """ + ) + result = session.execute(select_statement).fetchall() + if result: + ids = [item[0] for item in result] + self._delete_by_ids(ids) + + def _delete_by_ids(self, ids: list[str]) -> bool: + if ids is None: + raise ValueError("No ids provided to delete.") + table = self._table(self._dimension) + try: + with self._engine.connect() as conn: + with conn.begin(): + delete_condition = table.c.id.in_(ids) + conn.execute(table.delete().where(delete_condition)) + return True + except Exception as e: + print("Delete operation failed:", str(e)) + return False + + def delete_by_document_id(self, document_id: str): + ids = 
self.get_ids_by_metadata_field('document_id', document_id) + if ids: + self._delete_by_ids(ids) + + def get_ids_by_metadata_field(self, key: str, value: str): + with Session(self._engine) as session: + select_statement = sql_text( + f"""SELECT id FROM {self._collection_name} WHERE meta->>'$.{key}' = '{value}'; """ + ) + result = session.execute(select_statement).fetchall() + if result: + return [item[0] for item in result] + else: + return None + + def delete_by_metadata_field(self, key: str, value: str) -> None: + ids = self.get_ids_by_metadata_field(key, value) + if ids: + self._delete_by_ids(ids) + + def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: + top_k = kwargs.get("top_k", 5) + score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0 + filter = kwargs.get('filter') + distance = 1 - score_threshold + + query_vector_str = ", ".join(format(x) for x in query_vector) + query_vector_str = "[" + query_vector_str + "]" + logger.debug(f"_collection_name: {self._collection_name}, score_threshold: {score_threshold}, distance: {distance}") + + docs = [] + if self._distance_func == 'l2': + tidb_func = 'Vec_l2_distance' + elif self._distance_func == 'cosine': + tidb_func = 'Vec_Cosine_distance' + else: + tidb_func = 'Vec_Cosine_distance' + + with Session(self._engine) as session: + select_statement = sql_text( + f"""SELECT meta, text, distance FROM ( + SELECT meta, text, {tidb_func}(vector, "{query_vector_str}") as distance + FROM {self._collection_name} + ORDER BY distance + LIMIT {top_k} + ) t WHERE distance < {distance};""" + ) + res = session.execute(select_statement) + results = [(row[0], row[1], row[2]) for row in res] + for meta, text, distance in results: + metadata = json.loads(meta) + metadata['score'] = 1 - distance + docs.append(Document(page_content=text, metadata=metadata)) + return docs + + def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: + # tidb doesn't support bm25 search + return [] + + def delete(self) -> None: + with Session(self._engine) as session: + session.execute(sql_text(f"""DROP TABLE IF EXISTS {self._collection_name};""")) + session.commit() diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py index 2405d16b1d..b500b37d60 100644 --- a/api/core/rag/datasource/vdb/vector_factory.py +++ b/api/core/rag/datasource/vdb/vector_factory.py @@ -164,6 +164,54 @@ class Vector: ), dim=dim ) + elif vector_type == "pgvector": + from core.rag.datasource.vdb.pgvector.pgvector import PGVector, PGVectorConfig + + if self._dataset.index_struct_dict: + class_prefix: str = self._dataset.index_struct_dict["vector_store"]["class_prefix"] + collection_name = class_prefix + else: + dataset_id = self._dataset.id + collection_name = Dataset.gen_collection_name_by_id(dataset_id) + index_struct_dict = { + "type": "pgvector", + "vector_store": {"class_prefix": collection_name}} + self._dataset.index_struct = json.dumps(index_struct_dict) + return PGVector( + collection_name=collection_name, + config=PGVectorConfig( + host=config.get("PGVECTOR_HOST"), + port=config.get("PGVECTOR_PORT"), + user=config.get("PGVECTOR_USER"), + password=config.get("PGVECTOR_PASSWORD"), + database=config.get("PGVECTOR_DATABASE"), + ), + ) + elif vector_type == "tidb_vector": + from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVector, TiDBVectorConfig + + if self._dataset.index_struct_dict: + class_prefix: str = 
self._dataset.index_struct_dict['vector_store']['class_prefix'] + collection_name = class_prefix.lower() + else: + dataset_id = self._dataset.id + collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower() + index_struct_dict = { + "type": 'tidb_vector', + "vector_store": {"class_prefix": collection_name} + } + self._dataset.index_struct = json.dumps(index_struct_dict) + + return TiDBVector( + collection_name=collection_name, + config=TiDBVectorConfig( + host=config.get('TIDB_VECTOR_HOST'), + port=config.get('TIDB_VECTOR_PORT'), + user=config.get('TIDB_VECTOR_USER'), + password=config.get('TIDB_VECTOR_PASSWORD'), + database=config.get('TIDB_VECTOR_DATABASE'), + ), + ) else: raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.") diff --git a/api/core/docstore/dataset_docstore.py b/api/core/rag/docstore/dataset_docstore.py similarity index 91% rename from api/core/docstore/dataset_docstore.py rename to api/core/rag/docstore/dataset_docstore.py index 7567493b9f..96a15be742 100644 --- a/api/core/docstore/dataset_docstore.py +++ b/api/core/rag/docstore/dataset_docstore.py @@ -1,11 +1,10 @@ from collections.abc import Sequence -from typing import Any, Optional, cast +from typing import Any, Optional from sqlalchemy import func from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from core.rag.models.document import Document from extensions.ext_database import db from models.dataset import Dataset, DocumentSegment @@ -95,11 +94,7 @@ class DatasetDocumentStore: # calc embedding use tokens if embedding_model: - model_type_instance = embedding_model.model_type_instance - model_type_instance = cast(TextEmbeddingModel, model_type_instance) - tokens = model_type_instance.get_num_tokens( - model=embedding_model.model, - credentials=embedding_model.credentials, + tokens = embedding_model.get_text_embedding_num_tokens( texts=[doc.page_content] ) else: @@ -121,13 +116,13 @@ class DatasetDocumentStore: enabled=False, created_by=self._user_id, ) - if 'answer' in doc.metadata and doc.metadata['answer']: + if doc.metadata.get('answer'): segment_document.answer = doc.metadata.pop('answer', '') db.session.add(segment_document) else: segment_document.content = doc.page_content - if 'answer' in doc.metadata and doc.metadata['answer']: + if doc.metadata.get('answer'): segment_document.answer = doc.metadata.pop('answer', '') segment_document.index_node_hash = doc.metadata['doc_hash'] segment_document.word_count = len(doc.page_content) diff --git a/api/core/rag/extractor/csv_extractor.py b/api/core/rag/extractor/csv_extractor.py index 09a1cddd1e..0470569f39 100644 --- a/api/core/rag/extractor/csv_extractor.py +++ b/api/core/rag/extractor/csv_extractor.py @@ -57,7 +57,7 @@ class CSVExtractor(BaseExtractor): docs = [] try: # load csv file into pandas dataframe - df = pd.read_csv(csvfile, error_bad_lines=False, **self.csv_args) + df = pd.read_csv(csvfile, on_bad_lines='skip', **self.csv_args) # check source column exists if self.source_column and self.source_column not in df.columns: diff --git a/api/core/rag/extractor/excel_extractor.py b/api/core/rag/extractor/excel_extractor.py index 2b0066448e..4d2f61139a 100644 --- a/api/core/rag/extractor/excel_extractor.py +++ b/api/core/rag/extractor/excel_extractor.py @@ -39,8 +39,8 @@ class ExcelExtractor(BaseExtractor): documents = [] # loop over all sheets for sheet in wb.sheets(): - for 
row_index, row in enumerate(sheet.get_rows(), start=1): - row_header = None + row_header = None + for row_index, row in enumerate(sheet.get_rows(), start=1): if self.is_blank_row(row): continue if row_header is None: @@ -49,8 +49,8 @@ class ExcelExtractor(BaseExtractor): item_arr = [] for index, cell in enumerate(row): txt_value = str(cell.value) - item_arr.append(f'{row_header[index].value}:{txt_value}') - item_str = "\n".join(item_arr) + item_arr.append(f'"{row_header[index].value}":"{txt_value}"') + item_str = ",".join(item_arr) document = Document(page_content=item_str, metadata={'source': self._file_path}) documents.append(document) return documents @@ -68,7 +68,7 @@ class ExcelExtractor(BaseExtractor): # transform each row into a Document for _, row in df.iterrows(): - item = ';'.join(f'{k}:{v}' for k, v in row.items() if pd.notna(v)) + item = ';'.join(f'"{k}":"{v}"' for k, v in row.items() if pd.notna(v)) document = Document(page_content=item, metadata={'source': self._file_path}) data.append(document) return data diff --git a/api/core/rag/extractor/extract_processor.py b/api/core/rag/extractor/extract_processor.py index 1136e11f76..09d192d410 100644 --- a/api/core/rag/extractor/extract_processor.py +++ b/api/core/rag/extractor/extract_processor.py @@ -1,6 +1,8 @@ +import re import tempfile from pathlib import Path from typing import Union +from urllib.parse import unquote import requests from flask import current_app @@ -14,7 +16,6 @@ from core.rag.extractor.markdown_extractor import MarkdownExtractor from core.rag.extractor.notion_extractor import NotionExtractor from core.rag.extractor.pdf_extractor import PdfExtractor from core.rag.extractor.text_extractor import TextExtractor -from core.rag.extractor.unstructured.unstructured_doc_extractor import UnstructuredWordExtractor from core.rag.extractor.unstructured.unstructured_eml_extractor import UnstructuredEmailExtractor from core.rag.extractor.unstructured.unstructured_epub_extractor import UnstructuredEpubExtractor from core.rag.extractor.unstructured.unstructured_markdown_extractor import UnstructuredMarkdownExtractor @@ -28,7 +29,7 @@ from core.rag.models.document import Document from extensions.ext_storage import storage from models.model import UploadFile -SUPPORT_URL_CONTENT_TYPES = ['application/pdf', 'text/plain'] +SUPPORT_URL_CONTENT_TYPES = ['application/pdf', 'text/plain', 'application/json'] USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36" @@ -55,6 +56,17 @@ class ExtractProcessor: with tempfile.TemporaryDirectory() as temp_dir: suffix = Path(url).suffix + if not suffix and suffix != '.': + # get content-type + if response.headers.get('Content-Type'): + suffix = '.' + response.headers.get('Content-Type').split('/')[-1] + else: + content_disposition = response.headers.get('Content-Disposition') + filename_match = re.search(r'filename="([^"]+)"', content_disposition) + if filename_match: + filename = unquote(filename_match.group(1)) + suffix = '.' 
+ re.search(r'\.(\w+)$', filename).group(1) + file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" with open(file_path, 'wb') as file: file.write(response.content) @@ -83,6 +95,7 @@ class ExtractProcessor: file_extension = input_file.suffix.lower() etl_type = current_app.config['ETL_TYPE'] unstructured_api_url = current_app.config['UNSTRUCTURED_API_URL'] + unstructured_api_key = current_app.config['UNSTRUCTURED_API_KEY'] if etl_type == 'Unstructured': if file_extension == '.xlsx' or file_extension == '.xls': extractor = ExcelExtractor(file_path) @@ -94,7 +107,7 @@ class ExtractProcessor: elif file_extension in ['.htm', '.html']: extractor = HtmlExtractor(file_path) elif file_extension in ['.docx']: - extractor = UnstructuredWordExtractor(file_path, unstructured_api_url) + extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by) elif file_extension == '.csv': extractor = CSVExtractor(file_path, autodetect_encoding=True) elif file_extension == '.msg': @@ -102,7 +115,7 @@ class ExtractProcessor: elif file_extension == '.eml': extractor = UnstructuredEmailExtractor(file_path, unstructured_api_url) elif file_extension == '.ppt': - extractor = UnstructuredPPTExtractor(file_path, unstructured_api_url) + extractor = UnstructuredPPTExtractor(file_path, unstructured_api_url, unstructured_api_key) elif file_extension == '.pptx': extractor = UnstructuredPPTXExtractor(file_path, unstructured_api_url) elif file_extension == '.xml': @@ -123,7 +136,7 @@ class ExtractProcessor: elif file_extension in ['.htm', '.html']: extractor = HtmlExtractor(file_path) elif file_extension in ['.docx']: - extractor = WordExtractor(file_path) + extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by) elif file_extension == '.csv': extractor = CSVExtractor(file_path, autodetect_encoding=True) elif file_extension == 'epub': diff --git a/api/core/rag/extractor/notion_extractor.py b/api/core/rag/extractor/notion_extractor.py index c40064fd1d..1885ad3aca 100644 --- a/api/core/rag/extractor/notion_extractor.py +++ b/api/core/rag/extractor/notion_extractor.py @@ -19,8 +19,12 @@ SEARCH_URL = "https://api.notion.com/v1/search" RETRIEVE_PAGE_URL_TMPL = "https://api.notion.com/v1/pages/{page_id}" RETRIEVE_DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}" -HEADING_TYPE = ['heading_1', 'heading_2', 'heading_3'] - +# if user want split by headings, use the corresponding splitter +HEADING_SPLITTER = { + 'heading_1': '# ', + 'heading_2': '## ', + 'heading_3': '### ', +} class NotionExtractor(BaseExtractor): @@ -73,8 +77,7 @@ class NotionExtractor(BaseExtractor): docs.extend(page_text_documents) elif notion_page_type == 'page': page_text_list = self._get_notion_block_data(notion_obj_id) - for page_text in page_text_list: - docs.append(Document(page_content=page_text)) + docs.append(Document(page_content='\n'.join(page_text_list))) else: raise ValueError("notion page type not supported") @@ -96,7 +99,7 @@ class NotionExtractor(BaseExtractor): data = res.json() - database_content_list = [] + database_content = [] if 'results' not in data or data["results"] is None: return [] for result in data["results"]: @@ -131,10 +134,9 @@ class NotionExtractor(BaseExtractor): row_content = row_content + f'{key}:{value_content}\n' else: row_content = row_content + f'{key}:{value}\n' - document = Document(page_content=row_content) - database_content_list.append(document) + database_content.append(row_content) - return database_content_list + return 
[Document(page_content='\n'.join(database_content))] def _get_notion_block_data(self, page_id: str) -> list[str]: result_lines_arr = [] @@ -154,8 +156,6 @@ class NotionExtractor(BaseExtractor): json=query_dict ) data = res.json() - # current block's heading - heading = '' for result in data["results"]: result_type = result["type"] result_obj = result[result_type] @@ -172,8 +172,6 @@ class NotionExtractor(BaseExtractor): if "text" in rich_text: text = rich_text["text"]["content"] cur_result_text_arr.append(text) - if result_type in HEADING_TYPE: - heading = text result_block_id = result["id"] has_children = result["has_children"] @@ -185,11 +183,10 @@ class NotionExtractor(BaseExtractor): cur_result_text_arr.append(children_text) cur_result_text = "\n".join(cur_result_text_arr) - cur_result_text += "\n\n" - if result_type in HEADING_TYPE: - result_lines_arr.append(cur_result_text) + if result_type in HEADING_SPLITTER: + result_lines_arr.append(f"{HEADING_SPLITTER[result_type]}{cur_result_text}") else: - result_lines_arr.append(f'{heading}\n{cur_result_text}') + result_lines_arr.append(cur_result_text + '\n\n') if data["next_cursor"] is None: break @@ -218,7 +215,6 @@ class NotionExtractor(BaseExtractor): data = res.json() if 'results' not in data or data["results"] is None: break - heading = '' for result in data["results"]: result_type = result["type"] result_obj = result[result_type] @@ -235,8 +231,6 @@ class NotionExtractor(BaseExtractor): text = rich_text["text"]["content"] prefix = "\t" * num_tabs cur_result_text_arr.append(prefix + text) - if result_type in HEADING_TYPE: - heading = text result_block_id = result["id"] has_children = result["has_children"] block_type = result["type"] @@ -247,10 +241,10 @@ class NotionExtractor(BaseExtractor): cur_result_text_arr.append(children_text) cur_result_text = "\n".join(cur_result_text_arr) - if result_type in HEADING_TYPE: - result_lines_arr.append(cur_result_text) + if result_type in HEADING_SPLITTER: + result_lines_arr.append(f'{HEADING_SPLITTER[result_type]}{cur_result_text}') else: - result_lines_arr.append(f'{heading}\n{cur_result_text}') + result_lines_arr.append(cur_result_text + '\n\n') if data["next_cursor"] is None: break diff --git a/api/core/rag/extractor/unstructured/unstructured_ppt_extractor.py b/api/core/rag/extractor/unstructured/unstructured_ppt_extractor.py index 6d3ffe6589..d354b593ed 100644 --- a/api/core/rag/extractor/unstructured/unstructured_ppt_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_ppt_extractor.py @@ -17,16 +17,18 @@ class UnstructuredPPTExtractor(BaseExtractor): def __init__( self, file_path: str, - api_url: str + api_url: str, + api_key: str ): """Initialize with file path.""" self._file_path = file_path self._api_url = api_url + self._api_key = api_key def extract(self) -> list[Document]: from unstructured.partition.api import partition_via_api - elements = partition_via_api(filename=self._file_path, api_url=self._api_url) + elements = partition_via_api(filename=self._file_path, api_url=self._api_url, api_key=self._api_key) text_by_page = {} for element in elements: page = element.metadata.page_number diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py index 6bf47a76f0..5b858c6c4c 100644 --- a/api/core/rag/extractor/word_extractor.py +++ b/api/core/rag/extractor/word_extractor.py @@ -1,12 +1,20 @@ """Abstract interface for document loader implementations.""" +import datetime +import mimetypes import os import tempfile +import uuid from 
urllib.parse import urlparse import requests +from docx import Document as DocxDocument +from flask import current_app from core.rag.extractor.extractor_base import BaseExtractor from core.rag.models.document import Document +from extensions.ext_database import db +from extensions.ext_storage import storage +from models.model import UploadFile class WordExtractor(BaseExtractor): @@ -17,9 +25,12 @@ file_path: Path to the file to load. """ - def __init__(self, file_path: str): + def __init__(self, file_path: str, tenant_id: str, user_id: str): """Initialize with file path.""" self.file_path = file_path + self.tenant_id = tenant_id + self.user_id = user_id + if "~" in self.file_path: self.file_path = os.path.expanduser(self.file_path) @@ -45,12 +56,7 @@ class WordExtractor(BaseExtractor): def extract(self) -> list[Document]: """Load given path as single page.""" - from docx import Document as docx_Document - - document = docx_Document(self.file_path) - doc_texts = [paragraph.text for paragraph in document.paragraphs] - content = '\n'.join(doc_texts) - + content = self.parse_docx(self.file_path, 'storage') return [Document( page_content=content, metadata={"source": self.file_path}, @@ -61,3 +67,111 @@ """Check if the url is valid.""" parsed = urlparse(url) return bool(parsed.netloc) and bool(parsed.scheme) + + def _extract_images_from_docx(self, doc, image_folder): + os.makedirs(image_folder, exist_ok=True) + image_count = 0 + image_map = {} + + for rel in doc.part.rels.values(): + if "image" in rel.target_ref: + image_count += 1 + image_ext = rel.target_ref.split('.')[-1] + # use uuid as file name + file_uuid = str(uuid.uuid4()) + file_key = 'image_files/' + self.tenant_id + '/' + file_uuid + '.'
+ image_ext + mime_type, _ = mimetypes.guess_type(file_key) + + storage.save(file_key, rel.target_part.blob) + # save file to db + config = current_app.config + upload_file = UploadFile( + tenant_id=self.tenant_id, + storage_type=config['STORAGE_TYPE'], + key=file_key, + name=file_key, + size=0, + extension=image_ext, + mime_type=mime_type, + created_by=self.user_id, + created_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), + used=True, + used_by=self.user_id, + used_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + ) + + db.session.add(upload_file) + db.session.commit() + image_map[rel.target_part] = f"![image]({current_app.config.get('CONSOLE_API_URL')}/files/{upload_file.id}/image-preview)" + + return image_map + + def _table_to_markdown(self, table): + markdown = "" + # deal with table headers + header_row = table.rows[0] + headers = [cell.text for cell in header_row.cells] + markdown += "| " + " | ".join(headers) + " |\n" + markdown += "| " + " | ".join(["---"] * len(headers)) + " |\n" + # deal with table rows + for row in table.rows[1:]: + row_cells = [cell.text for cell in row.cells] + markdown += "| " + " | ".join(row_cells) + " |\n" + + return markdown + + def _parse_paragraph(self, paragraph, image_map): + paragraph_content = [] + for run in paragraph.runs: + if run.element.xpath('.//a:blip'): + for blip in run.element.xpath('.//a:blip'): + embed_id = blip.get('{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed') + if embed_id: + rel_target = run.part.rels[embed_id].target_ref + if rel_target in image_map: + paragraph_content.append(image_map[rel_target]) + if run.text.strip(): + paragraph_content.append(run.text.strip()) + return ' '.join(paragraph_content) if paragraph_content else '' + + def parse_docx(self, docx_path, image_folder): + doc = DocxDocument(docx_path) + os.makedirs(image_folder, exist_ok=True) + + content = [] + + image_map = self._extract_images_from_docx(doc, image_folder) + + def parse_paragraph(paragraph): + paragraph_content = [] + for run in paragraph.runs: + if run.element.tag.endswith('r'): + drawing_elements = run.element.findall( + './/{http://schemas.openxmlformats.org/wordprocessingml/2006/main}drawing') + for drawing in drawing_elements: + blip_elements = drawing.findall( + './/{http://schemas.openxmlformats.org/drawingml/2006/main}blip') + for blip in blip_elements: + embed_id = blip.get( + '{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed') + if embed_id: + image_part = doc.part.related_parts.get(embed_id) + if image_part in image_map: + paragraph_content.append(image_map[image_part]) + if run.text.strip(): + paragraph_content.append(run.text.strip()) + return ''.join(paragraph_content) if paragraph_content else '' + + paragraphs = doc.paragraphs.copy() + tables = doc.tables.copy() + for element in doc.element.body: + if element.tag.endswith('p'): # paragraph + para = paragraphs.pop(0) + parsed_paragraph = parse_paragraph(para) + if parsed_paragraph: + content.append(parsed_paragraph) + elif element.tag.endswith('tbl'): # table + table = tables.pop(0) + content.append(self._table_to_markdown(table)) + return '\n'.join(content) + diff --git a/api/core/rag/index_processor/index_processor_base.py b/api/core/rag/index_processor/index_processor_base.py index 509a1a189b..edc16c821a 100644 --- a/api/core/rag/index_processor/index_processor_base.py +++ b/api/core/rag/index_processor/index_processor_base.py @@ -2,11 +2,16 @@ from abc import ABC, abstractmethod from 
typing import Optional +from flask import current_app + from core.model_manager import ModelInstance from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.models.document import Document -from core.splitter.fixed_text_splitter import EnhanceRecursiveCharacterTextSplitter, FixedRecursiveCharacterTextSplitter -from core.splitter.text_splitter import TextSplitter +from core.rag.splitter.fixed_text_splitter import ( + EnhanceRecursiveCharacterTextSplitter, + FixedRecursiveCharacterTextSplitter, +) +from core.rag.splitter.text_splitter import TextSplitter from models.dataset import Dataset, DatasetProcessRule @@ -43,8 +48,9 @@ class BaseIndexProcessor(ABC): # The user-defined segmentation rule rules = processing_rule['rules'] segmentation = rules["segmentation"] - if segmentation["max_tokens"] < 50 or segmentation["max_tokens"] > 1000: - raise ValueError("Custom segment length should be between 50 and 1000.") + max_segmentation_tokens_length = int(current_app.config['INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH']) + if segmentation["max_tokens"] < 50 or segmentation["max_tokens"] > max_segmentation_tokens_length: + raise ValueError(f"Custom segment length should be between 50 and {max_segmentation_tokens_length}.") separator = segmentation["separator"] if separator: @@ -54,7 +60,7 @@ class BaseIndexProcessor(ABC): chunk_size=segmentation["max_tokens"], chunk_overlap=segmentation.get('chunk_overlap', 0), fixed_separator=separator, - separators=["\n\n", "。", ".", " ", ""], + separators=["\n\n", "。", ". ", " ", ""], embedding_model_instance=embedding_model_instance ) else: @@ -62,7 +68,7 @@ class BaseIndexProcessor(ABC): character_splitter = EnhanceRecursiveCharacterTextSplitter.from_encoder( chunk_size=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens'], chunk_overlap=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['chunk_overlap'], - separators=["\n\n", "。", ".", " ", ""], + separators=["\n\n", "。", ". 
", " ", ""], embedding_model_instance=embedding_model_instance ) diff --git a/api/core/rag/models/document.py b/api/core/rag/models/document.py index 221318c2c3..7bb675b149 100644 --- a/api/core/rag/models/document.py +++ b/api/core/rag/models/document.py @@ -50,7 +50,7 @@ class BaseDocumentTransformer(ABC): ) -> Sequence[Document]: raise NotImplementedError - """ # noqa: E501 + """ @abstractmethod def transform_documents( diff --git a/api/core/rag/rerank/__init__.py b/api/core/rag/rerank/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/rerank/rerank.py b/api/core/rag/rerank/rerank.py similarity index 100% rename from api/core/rerank/rerank.py rename to api/core/rag/rerank/rerank.py diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 155b8be06c..b42a441a3f 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -14,9 +14,9 @@ from core.model_runtime.entities.model_entities import ModelFeature, ModelType from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.rag.datasource.retrieval_service import RetrievalService from core.rag.models.document import Document +from core.rag.rerank.rerank import RerankRunner from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter -from core.rerank.rerank import RerankRunner from core.tools.tool.dataset_retriever.dataset_multi_retriever_tool import DatasetMultiRetrieverTool from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool @@ -124,7 +124,7 @@ class DatasetRetrieval: document_score_list = {} for item in all_documents: - if 'score' in item.metadata and item.metadata['score']: + if item.metadata.get('score'): document_score_list[item.metadata['doc_id']] = item.metadata['score'] document_context_list = [] @@ -144,9 +144,9 @@ class DatasetRetrieval: float('inf'))) for segment in sorted_segments: if segment.answer: - document_context_list.append(f'question:{segment.content} answer:{segment.answer}') + document_context_list.append(f'question:{segment.get_sign_content()} answer:{segment.answer}') else: - document_context_list.append(segment.content) + document_context_list.append(segment.get_sign_content()) if show_retrieve_source: context_list = [] resource_number = 1 @@ -329,6 +329,7 @@ class DatasetRetrieval: """ if not query: return + dataset_queries = [] for dataset_id in dataset_ids: dataset_query = DatasetQuery( dataset_id=dataset_id, @@ -338,7 +339,9 @@ class DatasetRetrieval: created_by_role=user_from, created_by=user_id ) - db.session.add(dataset_query) + dataset_queries.append(dataset_query) + if dataset_queries: + db.session.add_all(dataset_queries) db.session.commit() def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list): diff --git a/api/core/splitter/fixed_text_splitter.py b/api/core/rag/splitter/fixed_text_splitter.py similarity index 86% rename from api/core/splitter/fixed_text_splitter.py rename to api/core/rag/splitter/fixed_text_splitter.py index a1510259ac..fd714edf5e 100644 --- a/api/core/splitter/fixed_text_splitter.py +++ b/api/core/rag/splitter/fixed_text_splitter.py @@ -1,12 +1,11 @@ """Functionality for splitting text.""" from 
__future__ import annotations -from typing import Any, Optional, cast +from typing import Any, Optional from core.model_manager import ModelInstance -from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer -from core.splitter.text_splitter import ( +from core.rag.splitter.text_splitter import ( TS, Collection, Literal, @@ -35,11 +34,7 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter): return 0 if embedding_model_instance: - embedding_model_type_instance = embedding_model_instance.model_type_instance - embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance) - return embedding_model_type_instance.get_num_tokens( - model=embedding_model_instance.model, - credentials=embedding_model_instance.credentials, + return embedding_model_instance.get_text_embedding_num_tokens( texts=[text] ) else: diff --git a/api/core/splitter/text_splitter.py b/api/core/rag/splitter/text_splitter.py similarity index 64% rename from api/core/splitter/text_splitter.py rename to api/core/rag/splitter/text_splitter.py index 5eeb237a96..b3adcedc76 100644 --- a/api/core/splitter/text_splitter.py +++ b/api/core/rag/splitter/text_splitter.py @@ -6,7 +6,6 @@ import re from abc import ABC, abstractmethod from collections.abc import Callable, Collection, Iterable, Sequence, Set from dataclasses import dataclass -from enum import Enum from typing import ( Any, Literal, @@ -94,7 +93,7 @@ class TextSplitter(BaseDocumentTransformer, ABC): documents.append(new_doc) return documents - def split_documents(self, documents: Iterable[Document] ) -> list[Document]: + def split_documents(self, documents: Iterable[Document]) -> list[Document]: """Split documents.""" texts, metadatas = [], [] for doc in documents: @@ -477,27 +476,6 @@ class TokenTextSplitter(TextSplitter): return split_text_on_tokens(text=text, tokenizer=tokenizer) -class Language(str, Enum): - """Enum of the programming languages.""" - - CPP = "cpp" - GO = "go" - JAVA = "java" - JS = "js" - PHP = "php" - PROTO = "proto" - PYTHON = "python" - RST = "rst" - RUBY = "ruby" - RUST = "rust" - SCALA = "scala" - SWIFT = "swift" - MARKDOWN = "markdown" - LATEX = "latex" - HTML = "html" - SOL = "sol" - - class RecursiveCharacterTextSplitter(TextSplitter): """Splitting text by recursively looking at characters.
@@ -554,350 +532,3 @@ class RecursiveCharacterTextSplitter(TextSplitter): def split_text(self, text: str) -> list[str]: return self._split_text(text, self._separators) - - @classmethod - def from_language( - cls, language: Language, **kwargs: Any - ) -> RecursiveCharacterTextSplitter: - separators = cls.get_separators_for_language(language) - return cls(separators=separators, **kwargs) - - @staticmethod - def get_separators_for_language(language: Language) -> list[str]: - if language == Language.CPP: - return [ - # Split along class definitions - "\nclass ", - # Split along function definitions - "\nvoid ", - "\nint ", - "\nfloat ", - "\ndouble ", - # Split along control flow statements - "\nif ", - "\nfor ", - "\nwhile ", - "\nswitch ", - "\ncase ", - # Split by the normal type of lines - "\n\n", - "\n", - " ", - "", - ] - elif language == Language.GO: - return [ - # Split along function definitions - "\nfunc ", - "\nvar ", - "\nconst ", - "\ntype ", - # Split along control flow statements - "\nif ", - "\nfor ", - "\nswitch ", - "\ncase ", - # Split by the normal type of lines - "\n\n", - "\n", - " ", - "", - ] - elif language == Language.JAVA: - return [ - # Split along class definitions - "\nclass ", - # Split along method definitions - "\npublic ", - "\nprotected ", - "\nprivate ", - "\nstatic ", - # Split along control flow statements - "\nif ", - "\nfor ", - "\nwhile ", - "\nswitch ", - "\ncase ", - # Split by the normal type of lines - "\n\n", - "\n", - " ", - "", - ] - elif language == Language.JS: - return [ - # Split along function definitions - "\nfunction ", - "\nconst ", - "\nlet ", - "\nvar ", - "\nclass ", - # Split along control flow statements - "\nif ", - "\nfor ", - "\nwhile ", - "\nswitch ", - "\ncase ", - "\ndefault ", - # Split by the normal type of lines - "\n\n", - "\n", - " ", - "", - ] - elif language == Language.PHP: - return [ - # Split along function definitions - "\nfunction ", - # Split along class definitions - "\nclass ", - # Split along control flow statements - "\nif ", - "\nforeach ", - "\nwhile ", - "\ndo ", - "\nswitch ", - "\ncase ", - # Split by the normal type of lines - "\n\n", - "\n", - " ", - "", - ] - elif language == Language.PROTO: - return [ - # Split along message definitions - "\nmessage ", - # Split along service definitions - "\nservice ", - # Split along enum definitions - "\nenum ", - # Split along option definitions - "\noption ", - # Split along import statements - "\nimport ", - # Split along syntax declarations - "\nsyntax ", - # Split by the normal type of lines - "\n\n", - "\n", - " ", - "", - ] - elif language == Language.PYTHON: - return [ - # First, try to split along class definitions - "\nclass ", - "\ndef ", - "\n\tdef ", - # Now split by the normal type of lines - "\n\n", - "\n", - " ", - "", - ] - elif language == Language.RST: - return [ - # Split along section titles - "\n=+\n", - "\n-+\n", - "\n\*+\n", - # Split along directive markers - "\n\n.. 
*\n\n", - # Split by the normal type of lines - "\n\n", - "\n", - " ", - "", - ] - elif language == Language.RUBY: - return [ - # Split along method definitions - "\ndef ", - "\nclass ", - # Split along control flow statements - "\nif ", - "\nunless ", - "\nwhile ", - "\nfor ", - "\ndo ", - "\nbegin ", - "\nrescue ", - # Split by the normal type of lines - "\n\n", - "\n", - " ", - "", - ] - elif language == Language.RUST: - return [ - # Split along function definitions - "\nfn ", - "\nconst ", - "\nlet ", - # Split along control flow statements - "\nif ", - "\nwhile ", - "\nfor ", - "\nloop ", - "\nmatch ", - "\nconst ", - # Split by the normal type of lines - "\n\n", - "\n", - " ", - "", - ] - elif language == Language.SCALA: - return [ - # Split along class definitions - "\nclass ", - "\nobject ", - # Split along method definitions - "\ndef ", - "\nval ", - "\nvar ", - # Split along control flow statements - "\nif ", - "\nfor ", - "\nwhile ", - "\nmatch ", - "\ncase ", - # Split by the normal type of lines - "\n\n", - "\n", - " ", - "", - ] - elif language == Language.SWIFT: - return [ - # Split along function definitions - "\nfunc ", - # Split along class definitions - "\nclass ", - "\nstruct ", - "\nenum ", - # Split along control flow statements - "\nif ", - "\nfor ", - "\nwhile ", - "\ndo ", - "\nswitch ", - "\ncase ", - # Split by the normal type of lines - "\n\n", - "\n", - " ", - "", - ] - elif language == Language.MARKDOWN: - return [ - # First, try to split along Markdown headings (starting with level 2) - "\n#{1,6} ", - # Note the alternative syntax for headings (below) is not handled here - # Heading level 2 - # --------------- - # End of code block - "```\n", - # Horizontal lines - "\n\*\*\*+\n", - "\n---+\n", - "\n___+\n", - # Note that this splitter doesn't handle horizontal lines defined - # by *three or more* of ***, ---, or ___, but this is not handled - "\n\n", - "\n", - " ", - "", - ] - elif language == Language.LATEX: - return [ - # First, try to split along Latex sections - "\n\\\chapter{", - "\n\\\section{", - "\n\\\subsection{", - "\n\\\subsubsection{", - # Now split by environments - "\n\\\begin{enumerate}", - "\n\\\begin{itemize}", - "\n\\\begin{description}", - "\n\\\begin{list}", - "\n\\\begin{quote}", - "\n\\\begin{quotation}", - "\n\\\begin{verse}", - "\n\\\begin{verbatim}", - # Now split by math environments - "\n\\\begin{align}", - "$$", - "$", - # Now split by the normal type of lines - " ", - "", - ] - elif language == Language.HTML: - return [ - # First, try to split along HTML tags - " dict: + # ------------- + # overwrite tool parameter types for temp fix + tools = jsonable_encoder(self.tools) + for tool in tools: + if tool.get('parameters'): + for parameter in tool.get('parameters'): + if parameter.get('type') == ToolParameter.ToolParameterType.FILE.value: + parameter['type'] = 'files' + # ------------- + return { 'id': self.id, 'author': self.author, @@ -47,7 +57,8 @@ class UserToolProvider(BaseModel): 'team_credentials': self.masked_credentials, 'is_team_authorization': self.is_team_authorization, 'allow_delete': self.allow_delete, - 'tools': self.tools + 'tools': tools, + 'labels': self.labels, } class UserToolProviderCredentials(BaseModel): diff --git a/api/core/tools/entities/constant.py b/api/core/tools/entities/constant.py deleted file mode 100644 index 2e75fedf99..0000000000 --- a/api/core/tools/entities/constant.py +++ /dev/null @@ -1,3 +0,0 @@ -class DEFAULT_PROVIDERS: - API_BASED = '__api_based' - APP_BASED = '__app_based' \ No newline at 
end of file diff --git a/api/core/tools/entities/tool_bundle.py b/api/core/tools/entities/tool_bundle.py index efa10e792c..d18d27fb02 100644 --- a/api/core/tools/entities/tool_bundle.py +++ b/api/core/tools/entities/tool_bundle.py @@ -1,11 +1,11 @@ -from typing import Any, Optional +from typing import Optional from pydantic import BaseModel -from core.tools.entities.tool_entities import ToolParameter, ToolProviderType +from core.tools.entities.tool_entities import ToolParameter -class ApiBasedToolBundle(BaseModel): +class ApiToolBundle(BaseModel): """ This class is used to store the schema information of an api based tool, such as the url, the method, the parameters, etc. """ @@ -25,12 +25,3 @@ icon: Optional[str] = None # openapi operation openapi: dict - -class AppToolBundle(BaseModel): - """ - This class is used to store the schema information of an tool for an app. - """ - type: ToolProviderType - credential: Optional[dict[str, Any]] = None - provider_id: str - tool_name: str \ No newline at end of file diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index fad91baf83..55ef8e8291 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -6,14 +6,33 @@ from pydantic import BaseModel, Field from core.tools.entities.common_entities import I18nObject +class ToolLabelEnum(Enum): + SEARCH = 'search' + IMAGE = 'image' + VIDEOS = 'videos' + WEATHER = 'weather' + FINANCE = 'finance' + DESIGN = 'design' + TRAVEL = 'travel' + SOCIAL = 'social' + NEWS = 'news' + MEDICAL = 'medical' + PRODUCTIVITY = 'productivity' + EDUCATION = 'education' + BUSINESS = 'business' + ENTERTAINMENT = 'entertainment' + UTILITIES = 'utilities' + OTHER = 'other' + class ToolProviderType(Enum): """ Enum class for tool provider """ - BUILT_IN = "built-in" + BUILT_IN = "builtin" + WORKFLOW = "workflow" + API = "api" + APP = "app" DATASET_RETRIEVAL = "dataset-retrieval" - APP_BASED = "app-based" - API_BASED = "api-based" @classmethod def value_of(cls, value: str) -> 'ToolProviderType': @@ -77,6 +96,7 @@ class ToolInvokeMessage(BaseModel): LINK = "link" BLOB = "blob" IMAGE_LINK = "image_link" + FILE_VAR = "file_var" type: MessageType = MessageType.TEXT """ @@ -90,18 +110,21 @@ class ToolInvokeMessageBinary(BaseModel): mimetype: str = Field(..., description="The mimetype of the binary") url: str = Field(..., description="The url of the binary") save_as: str = '' + file_var: Optional[dict[str, Any]] = None class ToolParameterOption(BaseModel): value: str = Field(..., description="The value of the option") label: I18nObject = Field(..., description="The label of the option") + class ToolParameter(BaseModel): - class ToolParameterType(Enum): + class ToolParameterType(str, Enum): STRING = "string" NUMBER = "number" BOOLEAN = "boolean" SELECT = "select" SECRET_INPUT = "secret-input" + FILE = "file" class ToolParameterForm(Enum): SCHEMA = "schema" # should be set while adding tool @@ -153,6 +176,7 @@ class ToolProviderIdentity(BaseModel): description: I18nObject = Field(..., description="The description of the tool") icon: str = Field(..., description="The icon of the tool") label: I18nObject = Field(..., description="The label of the tool") + tags: Optional[list[ToolLabelEnum]] = Field(default=[], description="The tags of the tool", ) class ToolDescription(BaseModel): human: I18nObject = Field(..., description="The description presented to the user") @@ -331,6 +355,15 @@ class
ModelToolProviderConfiguration(BaseModel): models: list[ModelToolConfiguration] = Field(..., description="The models of the model tool") label: I18nObject = Field(..., description="The label of the model tool") + +class WorkflowToolParameterConfiguration(BaseModel): + """ + Workflow tool configuration + """ + name: str = Field(..., description="The name of the parameter") + description: str = Field(..., description="The description of the parameter") + form: ToolParameter.ToolParameterForm = Field(..., description="The form of the parameter") + class ToolInvokeMeta(BaseModel): """ Tool invoke meta @@ -358,4 +391,19 @@ class ToolInvokeMeta(BaseModel): 'time_cost': self.time_cost, 'error': self.error, 'tool_config': self.tool_config, - } \ No newline at end of file + } + +class ToolLabel(BaseModel): + """ + Tool label + """ + name: str = Field(..., description="The name of the tool") + label: I18nObject = Field(..., description="The label of the tool") + icon: str = Field(..., description="The icon of the tool") + +class ToolInvokeFrom(Enum): + """ + Enum class for tool invoke + """ + WORKFLOW = "workflow" + AGENT = "agent" \ No newline at end of file diff --git a/api/core/tools/entities/values.py b/api/core/tools/entities/values.py new file mode 100644 index 0000000000..d0be5e9355 --- /dev/null +++ b/api/core/tools/entities/values.py @@ -0,0 +1,75 @@ +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_entities import ToolLabel, ToolLabelEnum + +ICONS = { + ToolLabelEnum.SEARCH: ''' + +''', + ToolLabelEnum.IMAGE: ''' + +''', + ToolLabelEnum.VIDEOS: ''' + +''', + ToolLabelEnum.WEATHER: ''' + +''', + ToolLabelEnum.FINANCE: ''' + +''', + ToolLabelEnum.DESIGN: ''' + +''', + ToolLabelEnum.TRAVEL: ''' + +''', + ToolLabelEnum.SOCIAL: ''' + +''', + ToolLabelEnum.NEWS: ''' + +''', + ToolLabelEnum.MEDICAL: ''' + +''', + ToolLabelEnum.PRODUCTIVITY: ''' + +''', + ToolLabelEnum.EDUCATION: ''' + +''', + ToolLabelEnum.BUSINESS: ''' + +''', + ToolLabelEnum.ENTERTAINMENT: ''' + +''', + ToolLabelEnum.UTILITIES: ''' + +''', + ToolLabelEnum.OTHER: ''' + +''' +} + +default_tool_label_dict = { + ToolLabelEnum.SEARCH: ToolLabel(name='search', label=I18nObject(en_US='Search', zh_Hans='搜索'), icon=ICONS[ToolLabelEnum.SEARCH]), + ToolLabelEnum.IMAGE: ToolLabel(name='image', label=I18nObject(en_US='Image', zh_Hans='图片'), icon=ICONS[ToolLabelEnum.IMAGE]), + ToolLabelEnum.VIDEOS: ToolLabel(name='videos', label=I18nObject(en_US='Videos', zh_Hans='视频'), icon=ICONS[ToolLabelEnum.VIDEOS]), + ToolLabelEnum.WEATHER: ToolLabel(name='weather', label=I18nObject(en_US='Weather', zh_Hans='天气'), icon=ICONS[ToolLabelEnum.WEATHER]), + ToolLabelEnum.FINANCE: ToolLabel(name='finance', label=I18nObject(en_US='Finance', zh_Hans='金融'), icon=ICONS[ToolLabelEnum.FINANCE]), + ToolLabelEnum.DESIGN: ToolLabel(name='design', label=I18nObject(en_US='Design', zh_Hans='设计'), icon=ICONS[ToolLabelEnum.DESIGN]), + ToolLabelEnum.TRAVEL: ToolLabel(name='travel', label=I18nObject(en_US='Travel', zh_Hans='旅行'), icon=ICONS[ToolLabelEnum.TRAVEL]), + ToolLabelEnum.SOCIAL: ToolLabel(name='social', label=I18nObject(en_US='Social', zh_Hans='社交'), icon=ICONS[ToolLabelEnum.SOCIAL]), + ToolLabelEnum.NEWS: ToolLabel(name='news', label=I18nObject(en_US='News', zh_Hans='新闻'), icon=ICONS[ToolLabelEnum.NEWS]), + ToolLabelEnum.MEDICAL: ToolLabel(name='medical', label=I18nObject(en_US='Medical', zh_Hans='医疗'), icon=ICONS[ToolLabelEnum.MEDICAL]), + ToolLabelEnum.PRODUCTIVITY: ToolLabel(name='productivity', 
label=I18nObject(en_US='Productivity', zh_Hans='生产力'), icon=ICONS[ToolLabelEnum.PRODUCTIVITY]), + ToolLabelEnum.EDUCATION: ToolLabel(name='education', label=I18nObject(en_US='Education', zh_Hans='教育'), icon=ICONS[ToolLabelEnum.EDUCATION]), + ToolLabelEnum.BUSINESS: ToolLabel(name='business', label=I18nObject(en_US='Business', zh_Hans='商业'), icon=ICONS[ToolLabelEnum.BUSINESS]), + ToolLabelEnum.ENTERTAINMENT: ToolLabel(name='entertainment', label=I18nObject(en_US='Entertainment', zh_Hans='娱乐'), icon=ICONS[ToolLabelEnum.ENTERTAINMENT]), + ToolLabelEnum.UTILITIES: ToolLabel(name='utilities', label=I18nObject(en_US='Utilities', zh_Hans='工具'), icon=ICONS[ToolLabelEnum.UTILITIES]), + ToolLabelEnum.OTHER: ToolLabel(name='other', label=I18nObject(en_US='Other', zh_Hans='其他'), icon=ICONS[ToolLabelEnum.OTHER]), +} + +default_tool_labels = [v for k, v in default_tool_label_dict.items()] +default_tool_label_name_list = [label.name for label in default_tool_labels] diff --git a/api/core/tools/model/errors.py b/api/core/tools/model/errors.py deleted file mode 100644 index 6e242b349a..0000000000 --- a/api/core/tools/model/errors.py +++ /dev/null @@ -1,2 +0,0 @@ -class InvokeModelError(Exception): - pass \ No newline at end of file diff --git a/api/core/tools/model_tools/anthropic.yaml b/api/core/tools/model_tools/anthropic.yaml deleted file mode 100644 index 4ccb973df5..0000000000 --- a/api/core/tools/model_tools/anthropic.yaml +++ /dev/null @@ -1,20 +0,0 @@ -provider: anthropic -label: - en_US: Anthropic Model Tools - zh_Hans: Anthropic 模型能力 - pt_BR: Anthropic Model Tools -models: - - type: llm - model: claude-3-sonnet-20240229 - label: - zh_Hans: Claude3 Sonnet 视觉 - en_US: Claude3 Sonnet Vision - properties: - image_parameter_name: image_id - - type: llm - model: claude-3-opus-20240229 - label: - zh_Hans: Claude3 Opus 视觉 - en_US: Claude3 Opus Vision - properties: - image_parameter_name: image_id diff --git a/api/core/tools/model_tools/google.yaml b/api/core/tools/model_tools/google.yaml deleted file mode 100644 index d81e1b0735..0000000000 --- a/api/core/tools/model_tools/google.yaml +++ /dev/null @@ -1,13 +0,0 @@ -provider: google -label: - en_US: Google Model Tools - zh_Hans: Google 模型能力 - pt_BR: Google Model Tools -models: - - type: llm - model: gemini-pro-vision - label: - zh_Hans: Gemini Pro 视觉 - en_US: Gemini Pro Vision - properties: - image_parameter_name: image_id diff --git a/api/core/tools/model_tools/openai.yaml b/api/core/tools/model_tools/openai.yaml deleted file mode 100644 index 45cbb295a9..0000000000 --- a/api/core/tools/model_tools/openai.yaml +++ /dev/null @@ -1,13 +0,0 @@ -provider: openai -label: - en_US: OpenAI Model Tools - zh_Hans: OpenAI 模型能力 - pt_BR: OpenAI Model Tools -models: - - type: llm - model: gpt-4-vision-preview - label: - zh_Hans: GPT-4 视觉 - en_US: GPT-4 Vision - properties: - image_parameter_name: image_id diff --git a/api/core/tools/model_tools/zhipuai.yaml b/api/core/tools/model_tools/zhipuai.yaml deleted file mode 100644 index 19a932eb89..0000000000 --- a/api/core/tools/model_tools/zhipuai.yaml +++ /dev/null @@ -1,13 +0,0 @@ -provider: zhipuai -label: - en_US: ZhipuAI Model Tools - zh_Hans: ZhipuAI 模型能力 - pt_BR: ZhipuAI Model Tools -models: - - type: llm - model: glm-4v - label: - zh_Hans: GLM-4 视觉 - en_US: GLM-4 Vision - properties: - image_parameter_name: image_id diff --git a/api/core/tools/provider/_position.yaml b/api/core/tools/provider/_position.yaml index 5c9454c11c..3b0f78cc76 100644 --- a/api/core/tools/provider/_position.yaml +++ 
b/api/core/tools/provider/_position.yaml @@ -1,21 +1,18 @@ - google - bing - duckduckgo +- searchapi - searxng - dalle - azuredalle - stability - wikipedia -- model.openai -- model.google -- model.anthropic - yahoo - arxiv - pubmed - stablediffusion - webscraper - jina -- model.zhipuai - aippt - youtube - code diff --git a/api/core/tools/provider/api_tool_provider.py b/api/core/tools/provider/api_tool_provider.py index 11e6e892c9..ae80ad2114 100644 --- a/api/core/tools/provider/api_tool_provider.py +++ b/api/core/tools/provider/api_tool_provider.py @@ -1,7 +1,6 @@ -from typing import Any from core.tools.entities.common_entities import I18nObject -from core.tools.entities.tool_bundle import ApiBasedToolBundle +from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_entities import ( ApiProviderAuthType, ToolCredentialsOption, @@ -15,11 +14,11 @@ from extensions.ext_database import db from models.tools import ApiToolProvider -class ApiBasedToolProviderController(ToolProviderController): +class ApiToolProviderController(ToolProviderController): provider_id: str @staticmethod - def from_db(db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> 'ApiBasedToolProviderController': + def from_db(db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> 'ApiToolProviderController': credentials_schema = { 'auth_type': ToolProviderCredentials( name='auth_type', @@ -79,9 +78,11 @@ class ApiBasedToolProviderController(ToolProviderController): else: raise ValueError(f'invalid auth type {auth_type}') - return ApiBasedToolProviderController(**{ + user_name = db_provider.user.name if db_provider.user_id else '' + + return ApiToolProviderController(**{ 'identity': { - 'author': db_provider.user.name if db_provider.user_id and db_provider.user else '', + 'author': user_name, 'name': db_provider.name, 'label': { 'en_US': db_provider.name, @@ -98,16 +99,10 @@ class ApiBasedToolProviderController(ToolProviderController): }) @property - def app_type(self) -> ToolProviderType: - return ToolProviderType.API_BASED - - def _validate_credentials(self, tool_name: str, credentials: dict[str, Any]) -> None: - pass + def provider_type(self) -> ToolProviderType: + return ToolProviderType.API - def validate_parameters(self, tool_name: str, tool_parameters: dict[str, Any]) -> None: - pass - - def _parse_tool_bundle(self, tool_bundle: ApiBasedToolBundle) -> ApiTool: + def _parse_tool_bundle(self, tool_bundle: ApiToolBundle) -> ApiTool: """ parse tool bundle to tool @@ -136,7 +131,7 @@ class ApiBasedToolProviderController(ToolProviderController): 'parameters' : tool_bundle.parameters if tool_bundle.parameters else [], }) - def load_bundled_tools(self, tools: list[ApiBasedToolBundle]) -> list[ApiTool]: + def load_bundled_tools(self, tools: list[ApiToolBundle]) -> list[ApiTool]: """ load bundled tools diff --git a/api/core/tools/provider/app_tool_provider.py b/api/core/tools/provider/app_tool_provider.py index 159c94bbf3..2d472e0a93 100644 --- a/api/core/tools/provider/app_tool_provider.py +++ b/api/core/tools/provider/app_tool_provider.py @@ -11,10 +11,10 @@ from models.tools import PublishedAppTool logger = logging.getLogger(__name__) -class AppBasedToolProviderEntity(ToolProviderController): +class AppToolProviderEntity(ToolProviderController): @property - def app_type(self) -> ToolProviderType: - return ToolProviderType.APP_BASED + def provider_type(self) -> ToolProviderType: + return ToolProviderType.APP def _validate_credentials(self, tool_name: str, credentials: dict[str, 
Any]) -> None: pass diff --git a/api/core/tools/provider/builtin/_positions.py b/api/core/tools/provider/builtin/_positions.py index 2bf70bd356..ae806eaff4 100644 --- a/api/core/tools/provider/builtin/_positions.py +++ b/api/core/tools/provider/builtin/_positions.py @@ -1,7 +1,7 @@ import os.path -from core.tools.entities.user_entities import UserToolProvider -from core.utils.position_helper import get_position_map, sort_by_position_map +from core.helper.position_helper import get_position_map, sort_by_position_map +from core.tools.entities.api_entities import UserToolProvider class BuiltinToolProviderSort: @@ -13,10 +13,7 @@ class BuiltinToolProviderSort: cls._position = get_position_map(os.path.join(os.path.dirname(__file__), '..')) def name_func(provider: UserToolProvider) -> str: - if provider.type == UserToolProvider.ProviderType.MODEL: - return f'model.{provider.name}' - else: - return provider.name + return provider.name sorted_providers = sort_by_position_map(cls._position, providers, name_func) diff --git a/api/core/tools/provider/builtin/aippt/aippt.yaml b/api/core/tools/provider/builtin/aippt/aippt.yaml index b3ff1f6d98..9b1b45d0f2 100644 --- a/api/core/tools/provider/builtin/aippt/aippt.yaml +++ b/api/core/tools/provider/builtin/aippt/aippt.yaml @@ -8,6 +8,9 @@ identity: en_US: AI-generated PPT with one click, input your content topic, and let AI serve you one-stop zh_Hans: AI一键生成PPT,输入你的内容主题,让AI为你一站式服务到底 icon: icon.png + tags: + - productivity + - design credentials_for_provider: aippt_access_key: type: secret-input diff --git a/api/core/tools/provider/builtin/arxiv/arxiv.py b/api/core/tools/provider/builtin/arxiv/arxiv.py index 998128522e..707fc69be3 100644 --- a/api/core/tools/provider/builtin/arxiv/arxiv.py +++ b/api/core/tools/provider/builtin/arxiv/arxiv.py @@ -7,7 +7,7 @@ class ArxivProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict) -> None: try: ArxivSearchTool().fork_tool_runtime( - meta={ + runtime={ "credentials": credentials, } ).invoke( @@ -17,4 +17,5 @@ class ArxivProvider(BuiltinToolProviderController): }, ) except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) + \ No newline at end of file diff --git a/api/core/tools/provider/builtin/arxiv/arxiv.yaml b/api/core/tools/provider/builtin/arxiv/arxiv.yaml index 78c5c161af..d26993b336 100644 --- a/api/core/tools/provider/builtin/arxiv/arxiv.yaml +++ b/api/core/tools/provider/builtin/arxiv/arxiv.yaml @@ -8,3 +8,5 @@ identity: en_US: Access to a vast repository of scientific papers and articles in various fields of research. 
zh_Hans: 访问各个研究领域大量科学论文和文章的存储库。 icon: icon.svg + tags: + - search diff --git a/api/core/tools/provider/builtin/arxiv/tools/arxiv_search.py b/api/core/tools/provider/builtin/arxiv/tools/arxiv_search.py index fb64e07a8c..448d2e8f84 100644 --- a/api/core/tools/provider/builtin/arxiv/tools/arxiv_search.py +++ b/api/core/tools/provider/builtin/arxiv/tools/arxiv_search.py @@ -68,7 +68,7 @@ class ArxivAPIWrapper(BaseModel): Args: query: a plaintext search query - """ # noqa: E501 + """ try: results = self.arxiv_search( # type: ignore query[: self.ARXIV_MAX_QUERY_LENGTH], max_results=self.top_k_results diff --git a/api/core/tools/provider/builtin/azuredalle/azuredalle.py b/api/core/tools/provider/builtin/azuredalle/azuredalle.py index 4278da54ba..2981a54d3c 100644 --- a/api/core/tools/provider/builtin/azuredalle/azuredalle.py +++ b/api/core/tools/provider/builtin/azuredalle/azuredalle.py @@ -9,7 +9,7 @@ class AzureDALLEProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: try: DallE3Tool().fork_tool_runtime( - meta={ + runtime={ "credentials": credentials, } ).invoke( diff --git a/api/core/tools/provider/builtin/azuredalle/azuredalle.yaml b/api/core/tools/provider/builtin/azuredalle/azuredalle.yaml index 62ce2c21fe..4353e0c486 100644 --- a/api/core/tools/provider/builtin/azuredalle/azuredalle.yaml +++ b/api/core/tools/provider/builtin/azuredalle/azuredalle.yaml @@ -10,6 +10,9 @@ identity: zh_Hans: Azure DALL-E 绘画 pt_BR: Azure DALL-E art icon: icon.png + tags: + - image + - productivity credentials_for_provider: azure_openai_api_key: type: secret-input diff --git a/api/core/tools/provider/builtin/bing/bing.py b/api/core/tools/provider/builtin/bing/bing.py index 6e62abfc10..c71128be4a 100644 --- a/api/core/tools/provider/builtin/bing/bing.py +++ b/api/core/tools/provider/builtin/bing/bing.py @@ -9,7 +9,7 @@ class BingProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: try: BingSearchTool().fork_tool_runtime( - meta={ + runtime={ "credentials": credentials, } ).validate_credentials( diff --git a/api/core/tools/provider/builtin/bing/bing.yaml b/api/core/tools/provider/builtin/bing/bing.yaml index 35cd729208..1ab17d5294 100644 --- a/api/core/tools/provider/builtin/bing/bing.yaml +++ b/api/core/tools/provider/builtin/bing/bing.yaml @@ -10,6 +10,8 @@ identity: zh_Hans: Bing 搜索 pt_BR: Bing Search icon: icon.svg + tags: + - search credentials_for_provider: subscription_key: type: secret-input diff --git a/api/core/tools/provider/builtin/brave/brave.py b/api/core/tools/provider/builtin/brave/brave.py index e26b28b46a..e5eada80ee 100644 --- a/api/core/tools/provider/builtin/brave/brave.py +++ b/api/core/tools/provider/builtin/brave/brave.py @@ -9,7 +9,7 @@ class BraveProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: try: BraveSearchTool().fork_tool_runtime( - meta={ + runtime={ "credentials": credentials, } ).invoke( @@ -19,4 +19,5 @@ class BraveProvider(BuiltinToolProviderController): }, ) except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) + \ No newline at end of file diff --git a/api/core/tools/provider/builtin/brave/brave.yaml b/api/core/tools/provider/builtin/brave/brave.yaml index d1b7ff1086..93d315f839 100644 --- a/api/core/tools/provider/builtin/brave/brave.yaml +++ b/api/core/tools/provider/builtin/brave/brave.yaml @@ 
-10,6 +10,8 @@ identity: zh_Hans: Brave pt_BR: Brave icon: icon.svg + tags: + - search credentials_for_provider: brave_search_api_key: type: secret-input diff --git a/api/core/tools/provider/builtin/chart/chart.py b/api/core/tools/provider/builtin/chart/chart.py index f5e42e766d..0865bc700a 100644 --- a/api/core/tools/provider/builtin/chart/chart.py +++ b/api/core/tools/provider/builtin/chart/chart.py @@ -44,7 +44,7 @@ class ChartProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict) -> None: try: LinearChartTool().fork_tool_runtime( - meta={ + runtime={ "credentials": credentials, } ).invoke( @@ -54,4 +54,5 @@ class ChartProvider(BuiltinToolProviderController): }, ) except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) + \ No newline at end of file diff --git a/api/core/tools/provider/builtin/chart/chart.yaml b/api/core/tools/provider/builtin/chart/chart.yaml index 7aac32b3bb..ad0d9a6cd6 100644 --- a/api/core/tools/provider/builtin/chart/chart.yaml +++ b/api/core/tools/provider/builtin/chart/chart.yaml @@ -10,4 +10,8 @@ identity: zh_Hans: 图表生成是一个用于生成可视化图表的工具,你可以通过它来生成柱状图、折线图、饼图等各类图表 pt_BR: O Gerador de gráficos é uma ferramenta para gerar gráficos estatísticos como gráfico de barras, gráfico de linhas, gráfico de pizza, etc. icon: icon.png + tags: + - design + - productivity + - utilities credentials_for_provider: diff --git a/api/core/tools/provider/builtin/chart/tools/bar.yaml b/api/core/tools/provider/builtin/chart/tools/bar.yaml index bc34f2a5ec..ee7405f681 100644 --- a/api/core/tools/provider/builtin/chart/tools/bar.yaml +++ b/api/core/tools/provider/builtin/chart/tools/bar.yaml @@ -21,9 +21,9 @@ parameters: zh_Hans: 数据 pt_BR: dados human_description: - en_US: data for generating bar chart - zh_Hans: 用于生成柱状图的数据 - pt_BR: dados para gerar gráfico de barras + en_US: data for generating chart, each number should be separated by ";" + zh_Hans: 用于生成柱状图的数据,每个数字之间用 ";" 分隔 + pt_BR: dados para gerar gráfico de barras, cada número deve ser separado por ";" llm_description: data for generating bar chart, data should be a string contains a list of numbers like "1;2;3;4;5" form: llm - name: x_axis @@ -34,8 +34,8 @@ parameters: zh_Hans: x 轴 pt_BR: Eixo X human_description: - en_US: X axis for bar chart - zh_Hans: 柱状图的 x 轴 - pt_BR: Eixo X para gráfico de barras + en_US: X axis for chart, each text should be separated by ";" + zh_Hans: 柱状图的 x 轴,每个文本之间用 ";" 分隔 + pt_BR: Eixo X para gráfico de barras, cada texto deve ser separado por ";" llm_description: x axis for bar chart, x axis should be a string contains a list of texts like "a;b;c;1;2" in order to match the data form: llm diff --git a/api/core/tools/provider/builtin/chart/tools/line.yaml b/api/core/tools/provider/builtin/chart/tools/line.yaml index 9994cbb80b..35ebe3b68b 100644 --- a/api/core/tools/provider/builtin/chart/tools/line.yaml +++ b/api/core/tools/provider/builtin/chart/tools/line.yaml @@ -21,9 +21,9 @@ parameters: zh_Hans: 数据 pt_BR: dados human_description: - en_US: data for generating linear chart - zh_Hans: 用于生成线性图表的数据 - pt_BR: dados para gerar gráfico linear + en_US: data for generating chart, each number should be separated by ";" + zh_Hans: 用于生成线性图表的数据,每个数字之间用 ";" 分隔 + pt_BR: dados para gerar gráfico linear, cada número deve ser separado por ";" llm_description: data for generating linear chart, data should be a string contains a list of numbers like "1;2;3;4;5" form: llm - name: x_axis @@ 
-34,8 +34,8 @@ parameters: zh_Hans: x 轴 pt_BR: Eixo X human_description: - en_US: X axis for linear chart - zh_Hans: 线性图表的 x 轴 - pt_BR: Eixo X para gráfico linear + en_US: X axis for chart, each text should be separated by ";" + zh_Hans: 线性图表的 x 轴,每个文本之间用 ";" 分隔 + pt_BR: Eixo X para gráfico linear, cada texto deve ser separado por ";" llm_description: x axis for linear chart, x axis should be a string contains a list of texts like "a;b;c;1;2" in order to match the data form: llm diff --git a/api/core/tools/provider/builtin/chart/tools/pie.yaml b/api/core/tools/provider/builtin/chart/tools/pie.yaml index 7d6647ef4a..541715cb7d 100644 --- a/api/core/tools/provider/builtin/chart/tools/pie.yaml +++ b/api/core/tools/provider/builtin/chart/tools/pie.yaml @@ -21,9 +21,9 @@ parameters: zh_Hans: 数据 pt_BR: dados human_description: - en_US: data for generating pie chart - zh_Hans: 用于生成饼图的数据 - pt_BR: dados para gerar gráfico de pizza + en_US: data for generating chart, each number should be separated by ";" + zh_Hans: 用于生成饼图的数据,每个数字之间用 ";" 分隔 + pt_BR: dados para gerar gráfico de pizza, cada número deve ser separado por ";" llm_description: data for generating pie chart, data should be a string contains a list of numbers like "1;2;3;4;5" form: llm - name: categories @@ -34,8 +34,8 @@ parameters: zh_Hans: 分类 pt_BR: Categorias human_description: - en_US: Categories for pie chart - zh_Hans: 饼图的分类 - pt_BR: Categorias para gráfico de pizza + en_US: Categories for chart, each category should be separated by ";" + zh_Hans: 饼图的分类,每个分类之间用 ";" 分隔 + pt_BR: Categorias para gráfico de pizza, cada categoria deve ser separada por ";" llm_description: categories for pie chart, categories should be a string contains a list of texts like "a;b;c;1;2" in order to match the data, each category should be split by ";" form: llm diff --git a/api/core/tools/provider/builtin/code/code.py b/api/core/tools/provider/builtin/code/code.py index fae5ecf769..211417c9a4 100644 --- a/api/core/tools/provider/builtin/code/code.py +++ b/api/core/tools/provider/builtin/code/code.py @@ -5,4 +5,4 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl class CodeToolProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: - pass \ No newline at end of file + pass diff --git a/api/core/tools/provider/builtin/code/code.yaml b/api/core/tools/provider/builtin/code/code.yaml index b0fd0dd587..2640a7087e 100644 --- a/api/core/tools/provider/builtin/code/code.yaml +++ b/api/core/tools/provider/builtin/code/code.yaml @@ -10,4 +10,6 @@ identity: zh_Hans: 运行一段代码并返回结果。 pt_BR: Execute um trecho de código e obtenha o resultado de volta. 
icon: icon.svg + tags: + - productivity credentials_for_provider: diff --git a/api/core/tools/provider/builtin/code/tools/simple_code.py b/api/core/tools/provider/builtin/code/tools/simple_code.py index ae9b1cb612..37645bf0d0 100644 --- a/api/core/tools/provider/builtin/code/tools/simple_code.py +++ b/api/core/tools/provider/builtin/code/tools/simple_code.py @@ -1,6 +1,6 @@ from typing import Any -from core.helper.code_executor.code_executor import CodeExecutor +from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool @@ -11,10 +11,10 @@ class SimpleCode(BuiltinTool): invoke simple code """ - language = tool_parameters.get('language', 'python3') + language = tool_parameters.get('language', CodeLanguage.PYTHON3) code = tool_parameters.get('code', '') - if language not in ['python3', 'javascript']: + if language not in [CodeLanguage.PYTHON3, CodeLanguage.JAVASCRIPT]: raise ValueError(f'Only python3 and javascript are supported, not {language}') result = CodeExecutor.execute_code(language, '', code) diff --git a/api/core/tools/provider/builtin/dalle/dalle.py b/api/core/tools/provider/builtin/dalle/dalle.py index 34a24a7425..1c8019364d 100644 --- a/api/core/tools/provider/builtin/dalle/dalle.py +++ b/api/core/tools/provider/builtin/dalle/dalle.py @@ -9,7 +9,7 @@ class DALLEProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: try: DallE2Tool().fork_tool_runtime( - meta={ + runtime={ "credentials": credentials, } ).invoke( @@ -21,4 +21,5 @@ class DALLEProvider(BuiltinToolProviderController): }, ) except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) + \ No newline at end of file diff --git a/api/core/tools/provider/builtin/dalle/dalle.yaml b/api/core/tools/provider/builtin/dalle/dalle.yaml index a99401d82d..f09a9177f2 100644 --- a/api/core/tools/provider/builtin/dalle/dalle.yaml +++ b/api/core/tools/provider/builtin/dalle/dalle.yaml @@ -10,6 +10,9 @@ identity: zh_Hans: DALL-E 绘画 pt_BR: DALL-E art icon: icon.png + tags: + - image + - productivity credentials_for_provider: openai_api_key: type: secret-input diff --git a/api/core/tools/provider/builtin/dalle/tools/dalle2.py b/api/core/tools/provider/builtin/dalle/tools/dalle2.py index e41cbd9f65..450e782281 100644 --- a/api/core/tools/provider/builtin/dalle/tools/dalle2.py +++ b/api/core/tools/provider/builtin/dalle/tools/dalle2.py @@ -1,8 +1,8 @@ from base64 import b64decode -from os.path import join from typing import Any, Union from openai import OpenAI +from yarl import URL from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool @@ -23,7 +23,7 @@ class DallE2Tool(BuiltinTool): if not openai_base_url: openai_base_url = None else: - openai_base_url = join(openai_base_url, 'v1') + openai_base_url = str(URL(openai_base_url) / 'v1') client = OpenAI( api_key=self.runtime.credentials['openai_api_key'], diff --git a/api/core/tools/provider/builtin/dalle/tools/dalle3.py b/api/core/tools/provider/builtin/dalle/tools/dalle3.py index dc53025b02..87d18f68e0 100644 --- a/api/core/tools/provider/builtin/dalle/tools/dalle3.py +++ b/api/core/tools/provider/builtin/dalle/tools/dalle3.py @@ -1,8 +1,8 @@ from base64 import b64decode -from os.path import join from typing import Any, Union from openai import 
OpenAI +from yarl import URL from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool @@ -23,7 +23,7 @@ class DallE3Tool(BuiltinTool): if not openai_base_url: openai_base_url = None else: - openai_base_url = join(openai_base_url, 'v1') + openai_base_url = str(URL(openai_base_url) / 'v1') client = OpenAI( api_key=self.runtime.credentials['openai_api_key'], diff --git a/api/core/tools/provider/builtin/devdocs/devdocs.py b/api/core/tools/provider/builtin/devdocs/devdocs.py index 25cbe4d053..95d7939d0d 100644 --- a/api/core/tools/provider/builtin/devdocs/devdocs.py +++ b/api/core/tools/provider/builtin/devdocs/devdocs.py @@ -7,7 +7,7 @@ class DevDocsProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict) -> None: try: SearchDevDocsTool().fork_tool_runtime( - meta={ + runtime={ "credentials": credentials, } ).invoke( @@ -18,4 +18,5 @@ class DevDocsProvider(BuiltinToolProviderController): }, ) except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) + \ No newline at end of file diff --git a/api/core/tools/provider/builtin/devdocs/devdocs.yaml b/api/core/tools/provider/builtin/devdocs/devdocs.yaml index 1db226fc4b..7552f5a497 100644 --- a/api/core/tools/provider/builtin/devdocs/devdocs.yaml +++ b/api/core/tools/provider/builtin/devdocs/devdocs.yaml @@ -8,3 +8,6 @@ identity: en_US: Get official developer documentations on DevDocs. zh_Hans: 从DevDocs获取官方开发者文档。 icon: icon.svg + tags: + - search + - productivity diff --git a/api/core/tools/provider/builtin/dingtalk/dingtalk.yaml b/api/core/tools/provider/builtin/dingtalk/dingtalk.yaml index ebe2e4fbaf..c922c140a8 100644 --- a/api/core/tools/provider/builtin/dingtalk/dingtalk.yaml +++ b/api/core/tools/provider/builtin/dingtalk/dingtalk.yaml @@ -10,4 +10,7 @@ identity: zh_Hans: 钉钉群机器人 pt_BR: DingTalk group robot icon: icon.svg + tags: + - social + - productivity credentials_for_provider: diff --git a/api/core/tools/provider/builtin/duckduckgo/duckduckgo.py b/api/core/tools/provider/builtin/duckduckgo/duckduckgo.py index 3e9b57ece7..6df8678d30 100644 --- a/api/core/tools/provider/builtin/duckduckgo/duckduckgo.py +++ b/api/core/tools/provider/builtin/duckduckgo/duckduckgo.py @@ -7,7 +7,7 @@ class DuckDuckGoProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict) -> None: try: DuckDuckGoSearchTool().fork_tool_runtime( - meta={ + runtime={ "credentials": credentials, } ).invoke( @@ -17,4 +17,5 @@ class DuckDuckGoProvider(BuiltinToolProviderController): }, ) except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) + \ No newline at end of file diff --git a/api/core/tools/provider/builtin/duckduckgo/duckduckgo.yaml b/api/core/tools/provider/builtin/duckduckgo/duckduckgo.yaml index 8778dde625..f3faa06045 100644 --- a/api/core/tools/provider/builtin/duckduckgo/duckduckgo.yaml +++ b/api/core/tools/provider/builtin/duckduckgo/duckduckgo.yaml @@ -8,3 +8,5 @@ identity: en_US: A privacy-focused search engine. 
zh_Hans: 一个注重隐私的搜索引擎。 icon: icon.svg + tags: + - search diff --git a/api/core/tools/provider/builtin/feishu/feishu.py b/api/core/tools/provider/builtin/feishu/feishu.py index 13303dbe64..72a9333619 100644 --- a/api/core/tools/provider/builtin/feishu/feishu.py +++ b/api/core/tools/provider/builtin/feishu/feishu.py @@ -5,4 +5,3 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl class FeishuProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict) -> None: FeishuGroupBotTool() - pass diff --git a/api/core/tools/provider/builtin/feishu/feishu.yaml b/api/core/tools/provider/builtin/feishu/feishu.yaml index a1fcd38047..a029c7edb8 100644 --- a/api/core/tools/provider/builtin/feishu/feishu.yaml +++ b/api/core/tools/provider/builtin/feishu/feishu.yaml @@ -10,4 +10,7 @@ identity: zh_Hans: 飞书群机器人 pt_BR: Feishu group bot icon: icon.svg + tags: + - social + - productivity credentials_for_provider: diff --git a/api/core/tools/provider/builtin/firecrawl/firecrawl.py b/api/core/tools/provider/builtin/firecrawl/firecrawl.py index 20ab978b8d..adcb7ebdd6 100644 --- a/api/core/tools/provider/builtin/firecrawl/firecrawl.py +++ b/api/core/tools/provider/builtin/firecrawl/firecrawl.py @@ -8,7 +8,7 @@ class FirecrawlProvider(BuiltinToolProviderController): try: # Example validation using the Crawl tool CrawlTool().fork_tool_runtime( - meta={"credentials": credentials} + runtime={"credentials": credentials} ).invoke( user_id='', tool_parameters={ @@ -20,4 +20,5 @@ class FirecrawlProvider(BuiltinToolProviderController): } ) except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) + \ No newline at end of file diff --git a/api/core/tools/provider/builtin/firecrawl/firecrawl.yaml b/api/core/tools/provider/builtin/firecrawl/firecrawl.yaml index 67fad19159..311283dcb5 100644 --- a/api/core/tools/provider/builtin/firecrawl/firecrawl.yaml +++ b/api/core/tools/provider/builtin/firecrawl/firecrawl.yaml @@ -8,6 +8,9 @@ identity: en_US: Firecrawl API integration for web crawling and scraping. zh_CN: Firecrawl API 集成,用于网页爬取和数据抓取。 icon: icon.svg + tags: + - search + - utilities credentials_for_provider: firecrawl_api_key: type: secret-input diff --git a/api/core/tools/provider/builtin/gaode/gaode.yaml b/api/core/tools/provider/builtin/gaode/gaode.yaml index bca53b22e9..2eb3b161a2 100644 --- a/api/core/tools/provider/builtin/gaode/gaode.yaml +++ b/api/core/tools/provider/builtin/gaode/gaode.yaml @@ -10,6 +10,11 @@ identity: zh_Hans: 高德开放平台服务工具包。 pt_BR: Kit de ferramentas de serviço Autonavi Open Platform. icon: icon.svg + tags: + - utilities + - productivity + - travel + - weather credentials_for_provider: api_key: type: secret-input diff --git a/api/core/tools/provider/builtin/github/github.yaml b/api/core/tools/provider/builtin/github/github.yaml index d529e639cc..c3d85fc3f6 100644 --- a/api/core/tools/provider/builtin/github/github.yaml +++ b/api/core/tools/provider/builtin/github/github.yaml @@ -10,6 +10,8 @@ identity: zh_Hans: GitHub是一个在线软件源代码托管服务平台。 pt_BR: GitHub é uma plataforma online para serviços de hospedagem de código fonte de software. 
icon: icon.svg + tags: + - utilities credentials_for_provider: access_tokens: type: secret-input diff --git a/api/core/tools/provider/builtin/google/google.py b/api/core/tools/provider/builtin/google/google.py index 3900804b45..8f4b9a4a4e 100644 --- a/api/core/tools/provider/builtin/google/google.py +++ b/api/core/tools/provider/builtin/google/google.py @@ -9,7 +9,7 @@ class GoogleProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: try: GoogleSearchTool().fork_tool_runtime( - meta={ + runtime={ "credentials": credentials, } ).invoke( @@ -20,4 +20,5 @@ class GoogleProvider(BuiltinToolProviderController): }, ) except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) + \ No newline at end of file diff --git a/api/core/tools/provider/builtin/google/google.yaml b/api/core/tools/provider/builtin/google/google.yaml index 43b75b51cd..afb4d5b214 100644 --- a/api/core/tools/provider/builtin/google/google.yaml +++ b/api/core/tools/provider/builtin/google/google.yaml @@ -10,6 +10,8 @@ identity: zh_Hans: GoogleSearch pt_BR: Google icon: icon.svg + tags: + - search credentials_for_provider: serpapi_api_key: type: secret-input diff --git a/api/core/tools/provider/builtin/google/tools/google_search.py b/api/core/tools/provider/builtin/google/tools/google_search.py index 0b1978ad3e..87c2cc5796 100644 --- a/api/core/tools/provider/builtin/google/tools/google_search.py +++ b/api/core/tools/provider/builtin/google/tools/google_search.py @@ -1,39 +1,20 @@ -import os -import sys from typing import Any, Union -from serpapi import GoogleSearch +import requests from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool - -class HiddenPrints: - """Context manager to hide prints.""" - - def __enter__(self) -> None: - """Open file to pipe stdout to.""" - self._original_stdout = sys.stdout - sys.stdout = open(os.devnull, "w") - - def __exit__(self, *_: Any) -> None: - """Close file that stdout was piped to.""" - sys.stdout.close() - sys.stdout = self._original_stdout +SERP_API_URL = "https://serpapi.com/search" class SerpAPI: """ SerpAPI tool provider. """ - - search_engine: Any #: :meta private: - serpapi_api_key: str = None - def __init__(self, api_key: str) -> None: """Initialize SerpAPI tool provider.""" self.serpapi_api_key = api_key - self.search_engine = GoogleSearch def run(self, query: str, **kwargs: Any) -> str: """Run query through SerpAPI and parse result.""" @@ -43,114 +24,76 @@ class SerpAPI: def results(self, query: str) -> dict: """Run query through SerpAPI and return the raw result.""" params = self.get_params(query) - with HiddenPrints(): - search = self.search_engine(params) - res = search.get_dict() - return res + response = requests.get(url=SERP_API_URL, params=params) + response.raise_for_status() + return response.json() def get_params(self, query: str) -> dict[str, str]: """Get parameters for SerpAPI.""" - _params = { + params = { "api_key": self.serpapi_api_key, "q": query, - } - params = { "engine": "google", "google_domain": "google.com", "gl": "us", - "hl": "en", - **_params + "hl": "en" } return params @staticmethod def _process_response(res: dict, typ: str) -> str: - """Process response from SerpAPI.""" - if "error" in res.keys(): + """ + Process response from SerpAPI. 
+ SerpAPI doc: https://serpapi.com/search-api + Google search main results are called organic results + """ + if "error" in res: raise ValueError(f"Got error from SerpAPI: {res['error']}") - + toret = "" if typ == "text": - toret = "" - if "answer_box" in res.keys() and type(res["answer_box"]) == list: - res["answer_box"] = res["answer_box"][0] + "\n" - if "answer_box" in res.keys() and "answer" in res["answer_box"].keys(): - toret += res["answer_box"]["answer"] + "\n" - if "answer_box" in res.keys() and "snippet" in res["answer_box"].keys(): - toret += res["answer_box"]["snippet"] + "\n" - if ( - "answer_box" in res.keys() - and "snippet_highlighted_words" in res["answer_box"].keys() - ): - for item in res["answer_box"]["snippet_highlighted_words"]: - toret += item + "\n" - if ( - "sports_results" in res.keys() - and "game_spotlight" in res["sports_results"].keys() - ): - toret += res["sports_results"]["game_spotlight"] + "\n" - if ( - "shopping_results" in res.keys() - and "title" in res["shopping_results"][0].keys() - ): - toret += res["shopping_results"][:3] + "\n" - if ( - "knowledge_graph" in res.keys() - and "description" in res["knowledge_graph"].keys() - ): - toret = res["knowledge_graph"]["description"] + "\n" - if "snippet" in res["organic_results"][0].keys(): - for item in res["organic_results"]: - toret += "content: " + item["snippet"] + "\n" + "link: " + item["link"] + "\n" - if ( - "images_results" in res.keys() - and "thumbnail" in res["images_results"][0].keys() - ): - thumbnails = [item["thumbnail"] for item in res["images_results"][:10]] - toret = thumbnails - if toret == "": - toret = "No good search result found" + if "knowledge_graph" in res and "description" in res["knowledge_graph"]: + toret += res["knowledge_graph"]["description"] + "\n" + if "organic_results" in res: + snippets = [ + f"content: {item.get('snippet')}\nlink: {item.get('link')}" + for item in res["organic_results"] + if "snippet" in item + ] + toret += "\n".join(snippets) elif typ == "link": - if "knowledge_graph" in res.keys() and "title" in res["knowledge_graph"].keys() \ - and "description_link" in res["knowledge_graph"].keys(): - toret = res["knowledge_graph"]["description_link"] - elif "knowledge_graph" in res.keys() and "see_results_about" in res["knowledge_graph"].keys() \ - and len(res["knowledge_graph"]["see_results_about"]) > 0: - see_result_about = res["knowledge_graph"]["see_results_about"] - toret = "" - for item in see_result_about: - if "name" not in item.keys() or "link" not in item.keys(): - continue - toret += f"[{item['name']}]({item['link']})\n" - elif "organic_results" in res.keys() and len(res["organic_results"]) > 0: - organic_results = res["organic_results"] - toret = "" - for item in organic_results: - if "title" not in item.keys() or "link" not in item.keys(): - continue - toret += f"[{item['title']}]({item['link']})\n" - elif "related_questions" in res.keys() and len(res["related_questions"]) > 0: - related_questions = res["related_questions"] - toret = "" - for item in related_questions: - if "question" not in item.keys() or "link" not in item.keys(): - continue - toret += f"[{item['question']}]({item['link']})\n" - elif "related_searches" in res.keys() and len(res["related_searches"]) > 0: - related_searches = res["related_searches"] - toret = "" - for item in related_searches: - if "query" not in item.keys() or "link" not in item.keys(): - continue - toret += f"[{item['query']}]({item['link']})\n" - else: - toret = "No good search result found" + if "knowledge_graph" in 
res and "source" in res["knowledge_graph"]: + toret += res["knowledge_graph"]["source"]["link"] + elif "organic_results" in res: + links = [ + f"[{item['title']}]({item['link']})\n" + for item in res["organic_results"] + if "title" in item and "link" in item + ] + toret += "\n".join(links) + elif "related_questions" in res: + questions = [ + f"[{item['question']}]({item['link']})\n" + for item in res["related_questions"] + if "question" in item and "link" in item + ] + toret += "\n".join(questions) + elif "related_searches" in res: + searches = [ + f"[{item['query']}]({item['link']})\n" + for item in res["related_searches"] + if "query" in item and "link" in item + ] + toret += "\n".join(searches) + if not toret: + toret = "No good search result found" return toret + class GoogleSearchTool(BuiltinTool): - def _invoke(self, + def _invoke(self, user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ invoke tools """ @@ -161,4 +104,3 @@ class GoogleSearchTool(BuiltinTool): if result_type == 'text': return self.create_text_message(text=result) return self.create_link_message(link=result) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/jina/jina.py b/api/core/tools/provider/builtin/jina/jina.py index ed1de6f6c1..b1a8d62138 100644 --- a/api/core/tools/provider/builtin/jina/jina.py +++ b/api/core/tools/provider/builtin/jina/jina.py @@ -1,5 +1,6 @@ from typing import Any +from core.tools.entities.values import ToolLabelEnum from core.tools.errors import ToolProviderCredentialValidationError from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController @@ -9,4 +10,9 @@ class GoogleProvider(BuiltinToolProviderController): try: pass except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) + + def _get_tool_labels(self) -> list[ToolLabelEnum]: + return [ + ToolLabelEnum.SEARCH, ToolLabelEnum.PRODUCTIVITY + ] \ No newline at end of file diff --git a/api/core/tools/provider/builtin/jina/jina.yaml b/api/core/tools/provider/builtin/jina/jina.yaml index 6ae3330f40..67ad32a47a 100644 --- a/api/core/tools/provider/builtin/jina/jina.yaml +++ b/api/core/tools/provider/builtin/jina/jina.yaml @@ -2,12 +2,15 @@ identity: author: Dify name: jina label: - en_US: JinaReader - zh_Hans: JinaReader - pt_BR: JinaReader + en_US: Jina + zh_Hans: Jina + pt_BR: Jina description: - en_US: Convert any URL to an LLM-friendly input. Experience improved output for your agent and RAG systems at no cost. - zh_Hans: 将任何 URL 转换为 LLM 友好的输入。无需付费即可体验为您的 Agent 和 RAG 系统提供的改进输出。 - pt_BR: Converta qualquer URL em uma entrada amigável ao LLM. Experimente uma saída aprimorada para seus sistemas de agente e RAG sem custo. + en_US: Convert any URL to an LLM-friendly input or perform searches on the web for grounding information. Experience improved output for your agent and RAG systems at no cost. + zh_Hans: 将任何URL转换为LLM易读的输入或在网页上搜索引擎上搜索引擎。 + pt_BR: Converte qualquer URL em uma entrada LLm-fácil de ler ou realize pesquisas na web para obter informação de grounding. Tenha uma experiência melhor para seu agente e sistemas RAG sem custo. 
icon: icon.svg + tags: + - search + - productivity credentials_for_provider: diff --git a/api/core/tools/provider/builtin/jina/tools/jina_reader.py b/api/core/tools/provider/builtin/jina/tools/jina_reader.py index fd29a00aa5..beb05717ea 100644 --- a/api/core/tools/provider/builtin/jina/tools/jina_reader.py +++ b/api/core/tools/provider/builtin/jina/tools/jina_reader.py @@ -23,6 +23,14 @@ class JinaReaderTool(BuiltinTool): 'Accept': 'application/json' } + target_selector = tool_parameters.get('target_selector', None) + if target_selector is not None: + headers['X-Target-Selector'] = target_selector + + wait_for_selector = tool_parameters.get('wait_for_selector', None) + if wait_for_selector is not None: + headers['X-Wait-For-Selector'] = wait_for_selector + response = ssrf_proxy.get( str(URL(self._jina_reader_endpoint + url)), headers=headers, diff --git a/api/core/tools/provider/builtin/jina/tools/jina_reader.yaml b/api/core/tools/provider/builtin/jina/tools/jina_reader.yaml index 38d66292df..73cacb7fde 100644 --- a/api/core/tools/provider/builtin/jina/tools/jina_reader.yaml +++ b/api/core/tools/provider/builtin/jina/tools/jina_reader.yaml @@ -25,6 +25,32 @@ parameters: pt_BR: used for linking to webpages llm_description: url for scraping form: llm + - name: target_selector + type: string + required: false + label: + en_US: Target selector + zh_Hans: 目标选择器 + pt_BR: Seletor de destino + human_description: + en_US: css selector for scraping specific elements + zh_Hans: css 选择器用于抓取特定元素 + pt_BR: css selector for scraping specific elements + llm_description: css selector of the target element to scrape + form: form + - name: wait_for_selector + type: string + required: false + label: + en_US: Wait for selector + zh_Hans: 等待选择器 + pt_BR: Aguardar por seletor + human_description: + en_US: css selector for waiting for specific elements + zh_Hans: css 选择器用于等待特定元素 + pt_BR: css selector for waiting for specific elements + llm_description: css selector of the target element to wait for + form: form - name: summary type: boolean required: false diff --git a/api/core/tools/provider/builtin/jina/tools/jina_search.py b/api/core/tools/provider/builtin/jina/tools/jina_search.py new file mode 100644 index 0000000000..cfe36e6a3c --- /dev/null +++ b/api/core/tools/provider/builtin/jina/tools/jina_search.py @@ -0,0 +1,30 @@ +from typing import Any, Union + +from yarl import URL + +from core.helper import ssrf_proxy +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + + +class JinaSearchTool(BuiltinTool): + _jina_search_endpoint = 'https://s.jina.ai/' + + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + query = tool_parameters['query'] + + headers = { + 'Accept': 'application/json' + } + + response = ssrf_proxy.get( + str(URL(self._jina_search_endpoint + query)), + headers=headers, + timeout=(10, 60) + ) + + return self.create_text_message(response.text) diff --git a/api/core/tools/provider/builtin/jina/tools/jina_search.yaml b/api/core/tools/provider/builtin/jina/tools/jina_search.yaml new file mode 100644 index 0000000000..5ad70c03f3 --- /dev/null +++ b/api/core/tools/provider/builtin/jina/tools/jina_search.yaml @@ -0,0 +1,21 @@ +identity: + name: jina_search + author: Dify + label: + en_US: JinaSearch + zh_Hans: JinaSearch + pt_BR: JinaSearch +description: + human: + en_US: Search on the web and get the top 5 results. 
Useful for grounding using information from the web. + llm: A tool for searching results on the web for grounding. Input should be a simple question. +parameters: + - name: query + type: string + required: true + label: + en_US: Question (Query) + human_description: + en_US: used to find information on the web + llm_description: simple question to ask on the web + form: llm diff --git a/api/core/tools/provider/builtin/judge0ce/judge0ce.py b/api/core/tools/provider/builtin/judge0ce/judge0ce.py index c00747868b..bac6576797 100644 --- a/api/core/tools/provider/builtin/judge0ce/judge0ce.py +++ b/api/core/tools/provider/builtin/judge0ce/judge0ce.py @@ -9,7 +9,7 @@ class Judge0CEProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: try: ExecuteCodeTool().fork_tool_runtime( - meta={ + runtime={ "credentials": credentials, } ).invoke( @@ -20,4 +20,5 @@ class Judge0CEProvider(BuiltinToolProviderController): }, ) except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) + \ No newline at end of file diff --git a/api/core/tools/provider/builtin/judge0ce/judge0ce.yaml b/api/core/tools/provider/builtin/judge0ce/judge0ce.yaml index 5f0a471827..9ff8aaac6d 100644 --- a/api/core/tools/provider/builtin/judge0ce/judge0ce.yaml +++ b/api/core/tools/provider/builtin/judge0ce/judge0ce.yaml @@ -10,6 +10,9 @@ identity: zh_Hans: Judge0 CE 是一个开源的代码执行系统。支持多种语言,包括 C、C++、Java、Python、Ruby 等。 pt_BR: Judge0 CE é um sistema de execução de código de código aberto. Suporta várias linguagens, incluindo C, C++, Java, Python, Ruby, etc. icon: icon.svg + tags: + - utilities + - other credentials_for_provider: X-RapidAPI-Key: type: secret-input diff --git a/api/core/tools/provider/builtin/maths/maths.yaml b/api/core/tools/provider/builtin/maths/maths.yaml index 5bd892a927..35c2380e29 100644 --- a/api/core/tools/provider/builtin/maths/maths.yaml +++ b/api/core/tools/provider/builtin/maths/maths.yaml @@ -10,3 +10,6 @@ identity: zh_Hans: 一个用于数学计算的工具。 pt_BR: A tool for maths. 
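Both Jina tools in the hunks above funnel their requests through the project's ssrf_proxy helper and pass options as X-* headers rather than query parameters. A rough sketch of the equivalent raw call for the reader tool, with plain requests standing in for the internal helper and assuming Jina's public r.jina.ai reader endpoint:

import requests
from yarl import URL

url = "https://example.com/article"
headers = {
    "Accept": "application/json",
    "X-Target-Selector": "article",      # optional: scrape only nodes matching this CSS selector
    "X-Wait-For-Selector": "#comments",  # optional: wait for this selector before scraping
}
response = requests.get(str(URL("https://r.jina.ai/" + url)), headers=headers, timeout=(10, 60))
print(response.text)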
icon: icon.svg + tags: + - utilities + - productivity diff --git a/api/core/tools/provider/builtin/openweather/openweather.yaml b/api/core/tools/provider/builtin/openweather/openweather.yaml index 60bb33c36d..d4b66f87f9 100644 --- a/api/core/tools/provider/builtin/openweather/openweather.yaml +++ b/api/core/tools/provider/builtin/openweather/openweather.yaml @@ -10,6 +10,8 @@ identity: zh_Hans: 基于open weather的天气查询工具包 pt_BR: Kit de consulta de clima baseado no Open Weather icon: icon.svg + tags: + - weather credentials_for_provider: api_key: type: secret-input diff --git a/api/core/tools/provider/builtin/pubmed/pubmed.py b/api/core/tools/provider/builtin/pubmed/pubmed.py index 663617c0c1..05cd171b87 100644 --- a/api/core/tools/provider/builtin/pubmed/pubmed.py +++ b/api/core/tools/provider/builtin/pubmed/pubmed.py @@ -7,7 +7,7 @@ class PubMedProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict) -> None: try: PubMedSearchTool().fork_tool_runtime( - meta={ + runtime={ "credentials": credentials, } ).invoke( @@ -17,4 +17,5 @@ class PubMedProvider(BuiltinToolProviderController): }, ) except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) + \ No newline at end of file diff --git a/api/core/tools/provider/builtin/pubmed/pubmed.yaml b/api/core/tools/provider/builtin/pubmed/pubmed.yaml index 971a6fb204..5f8303147c 100644 --- a/api/core/tools/provider/builtin/pubmed/pubmed.yaml +++ b/api/core/tools/provider/builtin/pubmed/pubmed.yaml @@ -8,3 +8,6 @@ identity: en_US: A search engine for biomedical literature. zh_Hans: 一款生物医学文献搜索引擎。 icon: icon.svg + tags: + - medical + - search diff --git a/api/core/tools/provider/builtin/qrcode/qrcode.yaml b/api/core/tools/provider/builtin/qrcode/qrcode.yaml index c117c3de74..82e2a06e15 100644 --- a/api/core/tools/provider/builtin/qrcode/qrcode.yaml +++ b/api/core/tools/provider/builtin/qrcode/qrcode.yaml @@ -10,3 +10,5 @@ identity: zh_Hans: 一个二维码工具 pt_BR: A tool for generating QR code (quick-response code) image. 
icon: icon.svg + tags: + - utilities diff --git a/api/core/tools/provider/builtin/searchapi/_assets/icon.svg b/api/core/tools/provider/builtin/searchapi/_assets/icon.svg new file mode 100644 index 0000000000..7660b2f351 --- /dev/null +++ b/api/core/tools/provider/builtin/searchapi/_assets/icon.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/api/core/tools/provider/builtin/searchapi/searchapi.py b/api/core/tools/provider/builtin/searchapi/searchapi.py new file mode 100644 index 0000000000..6fa4f05acd --- /dev/null +++ b/api/core/tools/provider/builtin/searchapi/searchapi.py @@ -0,0 +1,23 @@ +from typing import Any + +from core.tools.errors import ToolProviderCredentialValidationError +from core.tools.provider.builtin.searchapi.tools.google import GoogleTool +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController + + +class SearchAPIProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: dict[str, Any]) -> None: + try: + GoogleTool().fork_tool_runtime( + runtime={ + "credentials": credentials, + } + ).invoke( + user_id='', + tool_parameters={ + "query": "SearchApi dify", + "result_type": "link" + }, + ) + except Exception as e: + raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/searchapi/searchapi.yaml b/api/core/tools/provider/builtin/searchapi/searchapi.yaml new file mode 100644 index 0000000000..c2fa3f398e --- /dev/null +++ b/api/core/tools/provider/builtin/searchapi/searchapi.yaml @@ -0,0 +1,34 @@ +identity: + author: SearchApi + name: searchapi + label: + en_US: SearchApi + zh_Hans: SearchApi + pt_BR: SearchApi + description: + en_US: SearchApi is a robust real-time SERP API delivering structured data from a collection of search engines including Google Search, Google Jobs, YouTube, Google News, and many more. + zh_Hans: SearchApi 是一个强大的实时 SERP API,可提供来自 Google 搜索、Google 招聘、YouTube、Google 新闻等搜索引擎集合的结构化数据。 + pt_BR: SearchApi is a robust real-time SERP API delivering structured data from a collection of search engines including Google Search, Google Jobs, YouTube, Google News, and many more. + icon: icon.svg + tags: + - search + - business + - news + - productivity +credentials_for_provider: + searchapi_api_key: + type: secret-input + required: true + label: + en_US: SearchApi API key + zh_Hans: SearchApi API key + pt_BR: SearchApi API key + placeholder: + en_US: Please input your SearchApi API key + zh_Hans: 请输入你的 SearchApi API key + pt_BR: Please input your SearchApi API key + help: + en_US: Get your SearchApi API key from SearchApi + zh_Hans: 从 SearchApi 获取您的 SearchApi API key + pt_BR: Get your SearchApi API key from SearchApi + url: https://www.searchapi.io/ diff --git a/api/core/tools/provider/builtin/searchapi/tools/google.py b/api/core/tools/provider/builtin/searchapi/tools/google.py new file mode 100644 index 0000000000..d019fe7134 --- /dev/null +++ b/api/core/tools/provider/builtin/searchapi/tools/google.py @@ -0,0 +1,104 @@ +from typing import Any, Union + +import requests + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + +SEARCH_API_URL = "https://www.searchapi.io/api/v1/search" + +class SearchAPI: + """ + SearchAPI tool provider. 
+ """ + + def __init__(self, api_key: str) -> None: + """Initialize SearchAPI tool provider.""" + self.searchapi_api_key = api_key + + def run(self, query: str, **kwargs: Any) -> str: + """Run query through SearchAPI and parse result.""" + type = kwargs.get("result_type", "text") + return self._process_response(self.results(query, **kwargs), type=type) + + def results(self, query: str, **kwargs: Any) -> dict: + """Run query through SearchAPI and return the raw result.""" + params = self.get_params(query, **kwargs) + response = requests.get( + url=SEARCH_API_URL, + params=params, + headers={"Authorization": f"Bearer {self.searchapi_api_key}"}, + ) + response.raise_for_status() + return response.json() + + def get_params(self, query: str, **kwargs: Any) -> dict[str, str]: + """Get parameters for SearchAPI.""" + return { + "engine": "google", + "q": query, + **{key: value for key, value in kwargs.items() if value not in [None, ""]}, + } + + @staticmethod + def _process_response(res: dict, type: str) -> str: + """Process response from SearchAPI.""" + if "error" in res.keys(): + raise ValueError(f"Got error from SearchApi: {res['error']}") + + toret = "" + if type == "text": + if "answer_box" in res.keys() and "answer" in res["answer_box"].keys(): + toret += res["answer_box"]["answer"] + "\n" + if "answer_box" in res.keys() and "snippet" in res["answer_box"].keys(): + toret += res["answer_box"]["snippet"] + "\n" + if "knowledge_graph" in res.keys() and "description" in res["knowledge_graph"].keys(): + toret += res["knowledge_graph"]["description"] + "\n" + if "organic_results" in res.keys() and "snippet" in res["organic_results"][0].keys(): + for item in res["organic_results"]: + toret += "content: " + item["snippet"] + "\n" + "link: " + item["link"] + "\n" + if toret == "": + toret = "No good search result found" + + elif type == "link": + if "answer_box" in res.keys() and "organic_result" in res["answer_box"].keys(): + if "title" in res["answer_box"]["organic_result"].keys(): + toret = f"[{res['answer_box']['organic_result']['title']}]({res['answer_box']['organic_result']['link']})\n" + elif "organic_results" in res.keys() and "link" in res["organic_results"][0].keys(): + toret = "" + for item in res["organic_results"]: + toret += f"[{item['title']}]({item['link']})\n" + elif "related_questions" in res.keys() and "link" in res["related_questions"][0].keys(): + toret = "" + for item in res["related_questions"]: + toret += f"[{item['title']}]({item['link']})\n" + elif "related_searches" in res.keys() and "link" in res["related_searches"][0].keys(): + toret = "" + for item in res["related_searches"]: + toret += f"[{item['title']}]({item['link']})\n" + else: + toret = "No good search result found" + return toret + +class GoogleTool(BuiltinTool): + def _invoke(self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + """ + Invoke the SearchApi tool. 
+ """ + query = tool_parameters['query'] + result_type = tool_parameters['result_type'] + num = tool_parameters.get("num", 10) + google_domain = tool_parameters.get("google_domain", "google.com") + gl = tool_parameters.get("gl", "us") + hl = tool_parameters.get("hl", "en") + location = tool_parameters.get("location", None) + + api_key = self.runtime.credentials['searchapi_api_key'] + result = SearchAPI(api_key).run(query, result_type=result_type, num=num, google_domain=google_domain, gl=gl, hl=hl, location=location) + + if result_type == 'text': + return self.create_text_message(text=result) + return self.create_link_message(link=result) diff --git a/api/core/tools/provider/builtin/searchapi/tools/google.yaml b/api/core/tools/provider/builtin/searchapi/tools/google.yaml new file mode 100644 index 0000000000..566de84b13 --- /dev/null +++ b/api/core/tools/provider/builtin/searchapi/tools/google.yaml @@ -0,0 +1,481 @@ +identity: + name: google_search_api + author: SearchApi + label: + en_US: Google Search API + zh_Hans: Google Search API +description: + human: + en_US: A tool to retrieve answer boxes, knowledge graphs, snippets, and webpages from Google Search engine. + zh_Hans: 一种从 Google 搜索引擎检索答案框、知识图、片段和网页的工具。 + llm: A tool to retrieve answer boxes, knowledge graphs, snippets, and webpages from Google Search engine. +parameters: + - name: query + type: string + required: true + label: + en_US: Query + zh_Hans: 询问 + human_description: + en_US: Defines the query you want to search. + zh_Hans: 定义您要搜索的查询。 + llm_description: Defines the search query you want to search. + form: llm + - name: result_type + type: select + required: true + options: + - value: text + label: + en_US: text + zh_Hans: 文本 + - value: link + label: + en_US: link + zh_Hans: 链接 + default: text + label: + en_US: Result type + zh_Hans: 结果类型 + human_description: + en_US: used for selecting the result type, text or link + zh_Hans: 用于选择结果类型,使用文本还是链接进行展示 + form: form + - name: location + type: string + required: false + label: + en_US: Location + zh_Hans: 询问 + human_description: + en_US: Defines from where you want the search to originate. (For example - New York) + zh_Hans: 定义您想要搜索的起始位置。 (例如 - 纽约) + llm_description: Defines from where you want the search to originate. (For example - New York) + form: llm + - name: gl + type: select + label: + en_US: Country + zh_Hans: 国家 + required: false + human_description: + en_US: Defines the country of the search. Default is "US". + zh_Hans: 定义搜索的国家/地区。默认为“美国”。 + llm_description: Defines the gl parameter of the Google search. 
+ form: form + default: US + options: + - value: AR + label: + en_US: Argentina + zh_Hans: 阿根廷 + pt_BR: Argentina + - value: AU + label: + en_US: Australia + zh_Hans: 澳大利亚 + pt_BR: Australia + - value: AT + label: + en_US: Austria + zh_Hans: 奥地利 + pt_BR: Austria + - value: BE + label: + en_US: Belgium + zh_Hans: 比利时 + pt_BR: Belgium + - value: BR + label: + en_US: Brazil + zh_Hans: 巴西 + pt_BR: Brazil + - value: CA + label: + en_US: Canada + zh_Hans: 加拿大 + pt_BR: Canada + - value: CL + label: + en_US: Chile + zh_Hans: 智利 + pt_BR: Chile + - value: CO + label: + en_US: Colombia + zh_Hans: 哥伦比亚 + pt_BR: Colombia + - value: CN + label: + en_US: China + zh_Hans: 中国 + pt_BR: China + - value: CZ + label: + en_US: Czech Republic + zh_Hans: 捷克共和国 + pt_BR: Czech Republic + - value: DK + label: + en_US: Denmark + zh_Hans: 丹麦 + pt_BR: Denmark + - value: FI + label: + en_US: Finland + zh_Hans: 芬兰 + pt_BR: Finland + - value: FR + label: + en_US: France + zh_Hans: 法国 + pt_BR: France + - value: DE + label: + en_US: Germany + zh_Hans: 德国 + pt_BR: Germany + - value: HK + label: + en_US: Hong Kong + zh_Hans: 香港 + pt_BR: Hong Kong + - value: IN + label: + en_US: India + zh_Hans: 印度 + pt_BR: India + - value: ID + label: + en_US: Indonesia + zh_Hans: 印度尼西亚 + pt_BR: Indonesia + - value: IT + label: + en_US: Italy + zh_Hans: 意大利 + pt_BR: Italy + - value: JP + label: + en_US: Japan + zh_Hans: 日本 + pt_BR: Japan + - value: KR + label: + en_US: Korea + zh_Hans: 韩国 + pt_BR: Korea + - value: MY + label: + en_US: Malaysia + zh_Hans: 马来西亚 + pt_BR: Malaysia + - value: MX + label: + en_US: Mexico + zh_Hans: 墨西哥 + pt_BR: Mexico + - value: NL + label: + en_US: Netherlands + zh_Hans: 荷兰 + pt_BR: Netherlands + - value: NZ + label: + en_US: New Zealand + zh_Hans: 新西兰 + pt_BR: New Zealand + - value: NO + label: + en_US: Norway + zh_Hans: 挪威 + pt_BR: Norway + - value: PH + label: + en_US: Philippines + zh_Hans: 菲律宾 + pt_BR: Philippines + - value: PL + label: + en_US: Poland + zh_Hans: 波兰 + pt_BR: Poland + - value: PT + label: + en_US: Portugal + zh_Hans: 葡萄牙 + pt_BR: Portugal + - value: RU + label: + en_US: Russia + zh_Hans: 俄罗斯 + pt_BR: Russia + - value: SA + label: + en_US: Saudi Arabia + zh_Hans: 沙特阿拉伯 + pt_BR: Saudi Arabia + - value: SG + label: + en_US: Singapore + zh_Hans: 新加坡 + pt_BR: Singapore + - value: ZA + label: + en_US: South Africa + zh_Hans: 南非 + pt_BR: South Africa + - value: ES + label: + en_US: Spain + zh_Hans: 西班牙 + pt_BR: Spain + - value: SE + label: + en_US: Sweden + zh_Hans: 瑞典 + pt_BR: Sweden + - value: CH + label: + en_US: Switzerland + zh_Hans: 瑞士 + pt_BR: Switzerland + - value: TW + label: + en_US: Taiwan + zh_Hans: 台湾 + pt_BR: Taiwan + - value: TH + label: + en_US: Thailand + zh_Hans: 泰国 + pt_BR: Thailand + - value: TR + label: + en_US: Turkey + zh_Hans: 土耳其 + pt_BR: Turkey + - value: GB + label: + en_US: United Kingdom + zh_Hans: 英国 + pt_BR: United Kingdom + - value: US + label: + en_US: United States + zh_Hans: 美国 + pt_BR: United States + - name: hl + type: select + label: + en_US: Language + zh_Hans: 语言 + human_description: + en_US: Defines the interface language of the search. Default is "en". 
+ zh_Hans: 定义搜索的界面语言。默认为“en”。 + required: false + default: en + form: form + options: + - value: ar + label: + en_US: Arabic + zh_Hans: 阿拉伯语 + - value: bg + label: + en_US: Bulgarian + zh_Hans: 保加利亚语 + - value: ca + label: + en_US: Catalan + zh_Hans: 加泰罗尼亚语 + - value: zh-cn + label: + en_US: Chinese (Simplified) + zh_Hans: 中文(简体) + - value: zh-tw + label: + en_US: Chinese (Traditional) + zh_Hans: 中文(繁体) + - value: cs + label: + en_US: Czech + zh_Hans: 捷克语 + - value: da + label: + en_US: Danish + zh_Hans: 丹麦语 + - value: nl + label: + en_US: Dutch + zh_Hans: 荷兰语 + - value: en + label: + en_US: English + zh_Hans: 英语 + - value: et + label: + en_US: Estonian + zh_Hans: 爱沙尼亚语 + - value: fi + label: + en_US: Finnish + zh_Hans: 芬兰语 + - value: fr + label: + en_US: French + zh_Hans: 法语 + - value: de + label: + en_US: German + zh_Hans: 德语 + - value: el + label: + en_US: Greek + zh_Hans: 希腊语 + - value: iw + label: + en_US: Hebrew + zh_Hans: 希伯来语 + - value: hi + label: + en_US: Hindi + zh_Hans: 印地语 + - value: hu + label: + en_US: Hungarian + zh_Hans: 匈牙利语 + - value: id + label: + en_US: Indonesian + zh_Hans: 印尼语 + - value: it + label: + en_US: Italian + zh_Hans: 意大利语 + - value: jp + label: + en_US: Japanese + zh_Hans: 日语 + - value: kn + label: + en_US: Kannada + zh_Hans: 卡纳达语 + - value: ko + label: + en_US: Korean + zh_Hans: 韩语 + - value: lv + label: + en_US: Latvian + zh_Hans: 拉脱维亚语 + - value: lt + label: + en_US: Lithuanian + zh_Hans: 立陶宛语 + - value: my + label: + en_US: Malay + zh_Hans: 马来语 + - value: ml + label: + en_US: Malayalam + zh_Hans: 马拉雅拉姆语 + - value: mr + label: + en_US: Marathi + zh_Hans: 马拉地语 + - value: "no" + label: + en_US: Norwegian + zh_Hans: 挪威语 + - value: pl + label: + en_US: Polish + zh_Hans: 波兰语 + - value: pt-br + label: + en_US: Portuguese (Brazil) + zh_Hans: 葡萄牙语(巴西) + - value: pt-pt + label: + en_US: Portuguese (Portugal) + zh_Hans: 葡萄牙语(葡萄牙) + - value: pa + label: + en_US: Punjabi + zh_Hans: 旁遮普语 + - value: ro + label: + en_US: Romanian + zh_Hans: 罗马尼亚语 + - value: ru + label: + en_US: Russian + zh_Hans: 俄语 + - value: sr + label: + en_US: Serbian + zh_Hans: 塞尔维亚语 + - value: sk + label: + en_US: Slovak + zh_Hans: 斯洛伐克语 + - value: sl + label: + en_US: Slovenian + zh_Hans: 斯洛文尼亚语 + - value: es + label: + en_US: Spanish + zh_Hans: 西班牙语 + - value: sv + label: + en_US: Swedish + zh_Hans: 瑞典语 + - value: ta + label: + en_US: Tamil + zh_Hans: 泰米尔语 + - value: te + label: + en_US: Telugu + zh_Hans: 泰卢固语 + - value: th + label: + en_US: Thai + zh_Hans: 泰语 + - value: tr + label: + en_US: Turkish + zh_Hans: 土耳其语 + - value: uk + label: + en_US: Ukrainian + zh_Hans: 乌克兰语 + - value: vi + label: + en_US: Vietnamese + zh_Hans: 越南语 + - name: google_domain + type: string + required: false + label: + en_US: google_domain + zh_Hans: google_domain + human_description: + en_US: Defines the Google domain of the search. Default is "google.com". + zh_Hans: 定义搜索的 Google 域。默认为“google.com”。 + llm_description: Defines Google domain in which you want to search. + form: llm + - name: num + type: number + required: false + label: + en_US: num + zh_Hans: num + human_description: + en_US: Specifies the number of results to display per page. Default is 10. Max number - 100, min - 1. + zh_Hans: 指定每页显示的结果数。默认值为 10。最大数量 - 100,最小数量 - 1。 + llm_description: Specifies the num of results to display per page. 
+ form: llm diff --git a/api/core/tools/provider/builtin/searchapi/tools/google_jobs.py b/api/core/tools/provider/builtin/searchapi/tools/google_jobs.py new file mode 100644 index 0000000000..1b8cfa7e30 --- /dev/null +++ b/api/core/tools/provider/builtin/searchapi/tools/google_jobs.py @@ -0,0 +1,88 @@ +from typing import Any, Union + +import requests + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + +SEARCH_API_URL = "https://www.searchapi.io/api/v1/search" + +class SearchAPI: + """ + SearchAPI tool provider. + """ + + def __init__(self, api_key: str) -> None: + """Initialize SearchAPI tool provider.""" + self.searchapi_api_key = api_key + + def run(self, query: str, **kwargs: Any) -> str: + """Run query through SearchAPI and parse result.""" + type = kwargs.get("result_type", "text") + return self._process_response(self.results(query, **kwargs), type=type) + + def results(self, query: str, **kwargs: Any) -> dict: + """Run query through SearchAPI and return the raw result.""" + params = self.get_params(query, **kwargs) + response = requests.get( + url=SEARCH_API_URL, + params=params, + headers={"Authorization": f"Bearer {self.searchapi_api_key}"}, + ) + response.raise_for_status() + return response.json() + + def get_params(self, query: str, **kwargs: Any) -> dict[str, str]: + """Get parameters for SearchAPI.""" + return { + "engine": "google_jobs", + "q": query, + **{key: value for key, value in kwargs.items() if value not in [None, ""]}, + } + + @staticmethod + def _process_response(res: dict, type: str) -> str: + """Process response from SearchAPI.""" + if "error" in res.keys(): + raise ValueError(f"Got error from SearchApi: {res['error']}") + + toret = "" + if type == "text": + if "jobs" in res.keys() and "title" in res["jobs"][0].keys(): + for item in res["jobs"]: + toret += "title: " + item["title"] + "\n" + "company_name: " + item["company_name"] + "\n" + "content: " + item["description"] + "\n" + if toret == "": + toret = "No good search result found" + + elif type == "link": + if "jobs" in res.keys() and "apply_link" in res["jobs"][0].keys(): + for item in res["jobs"]: + toret += f"[{item['title']} - {item['company_name']}]({item['apply_link']})\n" + else: + toret = "No good search result found" + return toret + +class GoogleJobsTool(BuiltinTool): + def _invoke(self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + """ + Invoke the SearchApi tool.
+ """ + query = tool_parameters['query'] + result_type = tool_parameters['result_type'] + is_remote = tool_parameters.get("is_remote", None) + google_domain = tool_parameters.get("google_domain", "google.com") + gl = tool_parameters.get("gl", "us") + hl = tool_parameters.get("hl", "en") + location = tool_parameters.get("location", None) + + ltype = 1 if is_remote else None + + api_key = self.runtime.credentials['searchapi_api_key'] + result = SearchAPI(api_key).run(query, result_type=result_type, google_domain=google_domain, gl=gl, hl=hl, location=location, ltype=ltype) + + if result_type == 'text': + return self.create_text_message(text=result) + return self.create_link_message(link=result) diff --git a/api/core/tools/provider/builtin/searchapi/tools/google_jobs.yaml b/api/core/tools/provider/builtin/searchapi/tools/google_jobs.yaml new file mode 100644 index 0000000000..486b193efa --- /dev/null +++ b/api/core/tools/provider/builtin/searchapi/tools/google_jobs.yaml @@ -0,0 +1,478 @@ +identity: + name: google_jobs_api + author: SearchApi + label: + en_US: Google Jobs API + zh_Hans: Google Jobs API +description: + human: + en_US: A tool to retrieve job titles, company names and description from Google Jobs engine. + zh_Hans: 一个从 Google 招聘引擎检索职位名称、公司名称和描述的工具。 + llm: A tool to retrieve job titles, company names and description from Google Jobs engine. +parameters: + - name: query + type: string + required: true + label: + en_US: Query + zh_Hans: 询问 + human_description: + en_US: Defines the query you want to search. + zh_Hans: 定义您要搜索的查询。 + llm_description: Defines the search query you want to search. + form: llm + - name: result_type + type: select + required: true + options: + - value: text + label: + en_US: text + zh_Hans: 文本 + - value: link + label: + en_US: link + zh_Hans: 链接 + default: text + label: + en_US: Result type + zh_Hans: 结果类型 + human_description: + en_US: used for selecting the result type, text or link + zh_Hans: 用于选择结果类型,使用文本还是链接进行展示 + form: form + - name: location + type: string + required: false + label: + en_US: Location + zh_Hans: 询问 + human_description: + en_US: Defines from where you want the search to originate. (For example - New York) + zh_Hans: 定义您想要搜索的起始位置。 (例如 - 纽约) + llm_description: Defines from where you want the search to originate. (For example - New York) + form: llm + - name: gl + type: select + label: + en_US: Country + zh_Hans: 国家 + required: false + human_description: + en_US: Defines the country of the search. Default is "US". + zh_Hans: 定义搜索的国家/地区。默认为“美国”。 + llm_description: Defines the gl parameter of the Google search. 
+ form: form + default: US + options: + - value: AR + label: + en_US: Argentina + zh_Hans: 阿根廷 + pt_BR: Argentina + - value: AU + label: + en_US: Australia + zh_Hans: 澳大利亚 + pt_BR: Australia + - value: AT + label: + en_US: Austria + zh_Hans: 奥地利 + pt_BR: Austria + - value: BE + label: + en_US: Belgium + zh_Hans: 比利时 + pt_BR: Belgium + - value: BR + label: + en_US: Brazil + zh_Hans: 巴西 + pt_BR: Brazil + - value: CA + label: + en_US: Canada + zh_Hans: 加拿大 + pt_BR: Canada + - value: CL + label: + en_US: Chile + zh_Hans: 智利 + pt_BR: Chile + - value: CO + label: + en_US: Colombia + zh_Hans: 哥伦比亚 + pt_BR: Colombia + - value: CN + label: + en_US: China + zh_Hans: 中国 + pt_BR: China + - value: CZ + label: + en_US: Czech Republic + zh_Hans: 捷克共和国 + pt_BR: Czech Republic + - value: DK + label: + en_US: Denmark + zh_Hans: 丹麦 + pt_BR: Denmark + - value: FI + label: + en_US: Finland + zh_Hans: 芬兰 + pt_BR: Finland + - value: FR + label: + en_US: France + zh_Hans: 法国 + pt_BR: France + - value: DE + label: + en_US: Germany + zh_Hans: 德国 + pt_BR: Germany + - value: HK + label: + en_US: Hong Kong + zh_Hans: 香港 + pt_BR: Hong Kong + - value: IN + label: + en_US: India + zh_Hans: 印度 + pt_BR: India + - value: ID + label: + en_US: Indonesia + zh_Hans: 印度尼西亚 + pt_BR: Indonesia + - value: IT + label: + en_US: Italy + zh_Hans: 意大利 + pt_BR: Italy + - value: JP + label: + en_US: Japan + zh_Hans: 日本 + pt_BR: Japan + - value: KR + label: + en_US: Korea + zh_Hans: 韩国 + pt_BR: Korea + - value: MY + label: + en_US: Malaysia + zh_Hans: 马来西亚 + pt_BR: Malaysia + - value: MX + label: + en_US: Mexico + zh_Hans: 墨西哥 + pt_BR: Mexico + - value: NL + label: + en_US: Netherlands + zh_Hans: 荷兰 + pt_BR: Netherlands + - value: NZ + label: + en_US: New Zealand + zh_Hans: 新西兰 + pt_BR: New Zealand + - value: NO + label: + en_US: Norway + zh_Hans: 挪威 + pt_BR: Norway + - value: PH + label: + en_US: Philippines + zh_Hans: 菲律宾 + pt_BR: Philippines + - value: PL + label: + en_US: Poland + zh_Hans: 波兰 + pt_BR: Poland + - value: PT + label: + en_US: Portugal + zh_Hans: 葡萄牙 + pt_BR: Portugal + - value: RU + label: + en_US: Russia + zh_Hans: 俄罗斯 + pt_BR: Russia + - value: SA + label: + en_US: Saudi Arabia + zh_Hans: 沙特阿拉伯 + pt_BR: Saudi Arabia + - value: SG + label: + en_US: Singapore + zh_Hans: 新加坡 + pt_BR: Singapore + - value: ZA + label: + en_US: South Africa + zh_Hans: 南非 + pt_BR: South Africa + - value: ES + label: + en_US: Spain + zh_Hans: 西班牙 + pt_BR: Spain + - value: SE + label: + en_US: Sweden + zh_Hans: 瑞典 + pt_BR: Sweden + - value: CH + label: + en_US: Switzerland + zh_Hans: 瑞士 + pt_BR: Switzerland + - value: TW + label: + en_US: Taiwan + zh_Hans: 台湾 + pt_BR: Taiwan + - value: TH + label: + en_US: Thailand + zh_Hans: 泰国 + pt_BR: Thailand + - value: TR + label: + en_US: Turkey + zh_Hans: 土耳其 + pt_BR: Turkey + - value: GB + label: + en_US: United Kingdom + zh_Hans: 英国 + pt_BR: United Kingdom + - value: US + label: + en_US: United States + zh_Hans: 美国 + pt_BR: United States + - name: hl + type: select + label: + en_US: Language + zh_Hans: 语言 + human_description: + en_US: Defines the interface language of the search. Default is "en". 
+ zh_Hans: 定义搜索的界面语言。默认为“en”。 + required: false + default: en + form: form + options: + - value: ar + label: + en_US: Arabic + zh_Hans: 阿拉伯语 + - value: bg + label: + en_US: Bulgarian + zh_Hans: 保加利亚语 + - value: ca + label: + en_US: Catalan + zh_Hans: 加泰罗尼亚语 + - value: zh-cn + label: + en_US: Chinese (Simplified) + zh_Hans: 中文(简体) + - value: zh-tw + label: + en_US: Chinese (Traditional) + zh_Hans: 中文(繁体) + - value: cs + label: + en_US: Czech + zh_Hans: 捷克语 + - value: da + label: + en_US: Danish + zh_Hans: 丹麦语 + - value: nl + label: + en_US: Dutch + zh_Hans: 荷兰语 + - value: en + label: + en_US: English + zh_Hans: 英语 + - value: et + label: + en_US: Estonian + zh_Hans: 爱沙尼亚语 + - value: fi + label: + en_US: Finnish + zh_Hans: 芬兰语 + - value: fr + label: + en_US: French + zh_Hans: 法语 + - value: de + label: + en_US: German + zh_Hans: 德语 + - value: el + label: + en_US: Greek + zh_Hans: 希腊语 + - value: iw + label: + en_US: Hebrew + zh_Hans: 希伯来语 + - value: hi + label: + en_US: Hindi + zh_Hans: 印地语 + - value: hu + label: + en_US: Hungarian + zh_Hans: 匈牙利语 + - value: id + label: + en_US: Indonesian + zh_Hans: 印尼语 + - value: it + label: + en_US: Italian + zh_Hans: 意大利语 + - value: jp + label: + en_US: Japanese + zh_Hans: 日语 + - value: kn + label: + en_US: Kannada + zh_Hans: 卡纳达语 + - value: ko + label: + en_US: Korean + zh_Hans: 韩语 + - value: lv + label: + en_US: Latvian + zh_Hans: 拉脱维亚语 + - value: lt + label: + en_US: Lithuanian + zh_Hans: 立陶宛语 + - value: my + label: + en_US: Malay + zh_Hans: 马来语 + - value: ml + label: + en_US: Malayalam + zh_Hans: 马拉雅拉姆语 + - value: mr + label: + en_US: Marathi + zh_Hans: 马拉地语 + - value: "no" + label: + en_US: Norwegian + zh_Hans: 挪威语 + - value: pl + label: + en_US: Polish + zh_Hans: 波兰语 + - value: pt-br + label: + en_US: Portuguese (Brazil) + zh_Hans: 葡萄牙语(巴西) + - value: pt-pt + label: + en_US: Portuguese (Portugal) + zh_Hans: 葡萄牙语(葡萄牙) + - value: pa + label: + en_US: Punjabi + zh_Hans: 旁遮普语 + - value: ro + label: + en_US: Romanian + zh_Hans: 罗马尼亚语 + - value: ru + label: + en_US: Russian + zh_Hans: 俄语 + - value: sr + label: + en_US: Serbian + zh_Hans: 塞尔维亚语 + - value: sk + label: + en_US: Slovak + zh_Hans: 斯洛伐克语 + - value: sl + label: + en_US: Slovenian + zh_Hans: 斯洛文尼亚语 + - value: es + label: + en_US: Spanish + zh_Hans: 西班牙语 + - value: sv + label: + en_US: Swedish + zh_Hans: 瑞典语 + - value: ta + label: + en_US: Tamil + zh_Hans: 泰米尔语 + - value: te + label: + en_US: Telugu + zh_Hans: 泰卢固语 + - value: th + label: + en_US: Thai + zh_Hans: 泰语 + - value: tr + label: + en_US: Turkish + zh_Hans: 土耳其语 + - value: uk + label: + en_US: Ukrainian + zh_Hans: 乌克兰语 + - value: vi + label: + en_US: Vietnamese + zh_Hans: 越南语 + - name: is_remote + type: select + label: + en_US: is_remote + zh_Hans: 是否远程 + human_description: + en_US: Filter results based on the work arrangement. Set it to true to find jobs that offer work from home or remote work opportunities.
+ zh_Hans: 根据工作安排过滤结果。将其设置为 true 可查找提供在家工作或远程工作机会的工作。 + required: false + form: form + options: + - value: true + label: + en_US: "true" + zh_Hans: "true" + - value: false + label: + en_US: "false" + zh_Hans: "false" diff --git a/api/core/tools/provider/builtin/searchapi/tools/google_news.py b/api/core/tools/provider/builtin/searchapi/tools/google_news.py new file mode 100644 index 0000000000..d592dc25aa --- /dev/null +++ b/api/core/tools/provider/builtin/searchapi/tools/google_news.py @@ -0,0 +1,92 @@ +from typing import Any, Union + +import requests + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + +SEARCH_API_URL = "https://www.searchapi.io/api/v1/search" + +class SearchAPI: + """ + SearchAPI tool provider. + """ + + def __init__(self, api_key: str) -> None: + """Initialize SearchAPI tool provider.""" + self.searchapi_api_key = api_key + + def run(self, query: str, **kwargs: Any) -> str: + """Run query through SearchAPI and parse result.""" + type = kwargs.get("result_type", "text") + return self._process_response(self.results(query, **kwargs), type=type) + + def results(self, query: str, **kwargs: Any) -> dict: + """Run query through SearchAPI and return the raw result.""" + params = self.get_params(query, **kwargs) + response = requests.get( + url=SEARCH_API_URL, + params=params, + headers={"Authorization": f"Bearer {self.searchapi_api_key}"}, + ) + response.raise_for_status() + return response.json() + + def get_params(self, query: str, **kwargs: Any) -> dict[str, str]: + """Get parameters for SearchAPI.""" + return { + "engine": "google_news", + "q": query, + **{key: value for key, value in kwargs.items() if value not in [None, ""]}, + } + + @staticmethod + def _process_response(res: dict, type: str) -> str: + """Process response from SearchAPI.""" + if "error" in res.keys(): + raise ValueError(f"Got error from SearchApi: {res['error']}") + + toret = "" + if type == "text": + if "organic_results" in res.keys() and "snippet" in res["organic_results"][0].keys(): + for item in res["organic_results"]: + toret += "content: " + item["snippet"] + "\n" + "link: " + item["link"] + "\n" + if "top_stories" in res.keys() and "title" in res["top_stories"][0].keys(): + for item in res["top_stories"]: + toret += "title: " + item["title"] + "\n" + "link: " + item["link"] + "\n" + if toret == "": + toret = "No good search result found" + + elif type == "link": + if "organic_results" in res.keys() and "title" in res["organic_results"][0].keys(): + for item in res["organic_results"]: + toret += f"[{item['title']}]({item['link']})\n" + elif "top_stories" in res.keys() and "title" in res["top_stories"][0].keys(): + for item in res["top_stories"]: + toret += f"[{item['title']}]({item['link']})\n" + else: + toret = "No good search result found" + return toret + +class GoogleNewsTool(BuiltinTool): + def _invoke(self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + """ + Invoke the SearchApi tool. 
+ """ + query = tool_parameters['query'] + result_type = tool_parameters['result_type'] + num = tool_parameters.get("num", 10) + google_domain = tool_parameters.get("google_domain", "google.com") + gl = tool_parameters.get("gl", "us") + hl = tool_parameters.get("hl", "en") + location = tool_parameters.get("location", None) + + api_key = self.runtime.credentials['searchapi_api_key'] + result = SearchAPI(api_key).run(query, result_type=result_type, num=num, google_domain=google_domain, gl=gl, hl=hl, location=location) + + if result_type == 'text': + return self.create_text_message(text=result) + return self.create_link_message(link=result) diff --git a/api/core/tools/provider/builtin/searchapi/tools/google_news.yaml b/api/core/tools/provider/builtin/searchapi/tools/google_news.yaml new file mode 100644 index 0000000000..b0212952e6 --- /dev/null +++ b/api/core/tools/provider/builtin/searchapi/tools/google_news.yaml @@ -0,0 +1,482 @@ +identity: + name: google_news_api + author: SearchApi + label: + en_US: Google News API + zh_Hans: Google News API +description: + human: + en_US: A tool to retrieve organic search results snippets and links from Google News engine. + zh_Hans: 一种从 Google 新闻引擎检索有机搜索结果片段和链接的工具。 + llm: A tool to retrieve organic search results snippets and links from Google News engine. +parameters: + - name: query + type: string + required: true + label: + en_US: Query + zh_Hans: 询问 + human_description: + en_US: Defines the query you want to search. + zh_Hans: 定义您要搜索的查询。 + llm_description: Defines the search query you want to search. + form: llm + - name: result_type + type: select + required: true + options: + - value: text + label: + en_US: text + zh_Hans: 文本 + - value: link + label: + en_US: link + zh_Hans: 链接 + default: text + label: + en_US: Result type + zh_Hans: 结果类型 + human_description: + en_US: used for selecting the result type, text or link. + zh_Hans: 用于选择结果类型,使用文本还是链接进行展示。 + form: form + - name: location + type: string + required: false + label: + en_US: Location + zh_Hans: 询问 + human_description: + en_US: Defines from where you want the search to originate. (For example - New York) + zh_Hans: 定义您想要搜索的起始位置。 (例如 - 纽约) + llm_description: Defines from where you want the search to originate. (For example - New York) + form: llm + - name: gl + type: select + label: + en_US: Country + zh_Hans: 国家 + required: false + human_description: + en_US: Defines the country of the search. Default is "US". + zh_Hans: 定义搜索的国家/地区。默认为“美国”。 + llm_description: Defines the gl parameter of the Google search. 
+ form: form + default: US + options: + - value: AR + label: + en_US: Argentina + zh_Hans: 阿根廷 + pt_BR: Argentina + - value: AU + label: + en_US: Australia + zh_Hans: 澳大利亚 + pt_BR: Australia + - value: AT + label: + en_US: Austria + zh_Hans: 奥地利 + pt_BR: Austria + - value: BE + label: + en_US: Belgium + zh_Hans: 比利时 + pt_BR: Belgium + - value: BR + label: + en_US: Brazil + zh_Hans: 巴西 + pt_BR: Brazil + - value: CA + label: + en_US: Canada + zh_Hans: 加拿大 + pt_BR: Canada + - value: CL + label: + en_US: Chile + zh_Hans: 智利 + pt_BR: Chile + - value: CO + label: + en_US: Colombia + zh_Hans: 哥伦比亚 + pt_BR: Colombia + - value: CN + label: + en_US: China + zh_Hans: 中国 + pt_BR: China + - value: CZ + label: + en_US: Czech Republic + zh_Hans: 捷克共和国 + pt_BR: Czech Republic + - value: DK + label: + en_US: Denmark + zh_Hans: 丹麦 + pt_BR: Denmark + - value: FI + label: + en_US: Finland + zh_Hans: 芬兰 + pt_BR: Finland + - value: FR + label: + en_US: France + zh_Hans: 法国 + pt_BR: France + - value: DE + label: + en_US: Germany + zh_Hans: 德国 + pt_BR: Germany + - value: HK + label: + en_US: Hong Kong + zh_Hans: 香港 + pt_BR: Hong Kong + - value: IN + label: + en_US: India + zh_Hans: 印度 + pt_BR: India + - value: ID + label: + en_US: Indonesia + zh_Hans: 印度尼西亚 + pt_BR: Indonesia + - value: IT + label: + en_US: Italy + zh_Hans: 意大利 + pt_BR: Italy + - value: JP + label: + en_US: Japan + zh_Hans: 日本 + pt_BR: Japan + - value: KR + label: + en_US: Korea + zh_Hans: 韩国 + pt_BR: Korea + - value: MY + label: + en_US: Malaysia + zh_Hans: 马来西亚 + pt_BR: Malaysia + - value: MX + label: + en_US: Mexico + zh_Hans: 墨西哥 + pt_BR: Mexico + - value: NL + label: + en_US: Netherlands + zh_Hans: 荷兰 + pt_BR: Netherlands + - value: NZ + label: + en_US: New Zealand + zh_Hans: 新西兰 + pt_BR: New Zealand + - value: NO + label: + en_US: Norway + zh_Hans: 挪威 + pt_BR: Norway + - value: PH + label: + en_US: Philippines + zh_Hans: 菲律宾 + pt_BR: Philippines + - value: PL + label: + en_US: Poland + zh_Hans: 波兰 + pt_BR: Poland + - value: PT + label: + en_US: Portugal + zh_Hans: 葡萄牙 + pt_BR: Portugal + - value: RU + label: + en_US: Russia + zh_Hans: 俄罗斯 + pt_BR: Russia + - value: SA + label: + en_US: Saudi Arabia + zh_Hans: 沙特阿拉伯 + pt_BR: Saudi Arabia + - value: SG + label: + en_US: Singapore + zh_Hans: 新加坡 + pt_BR: Singapore + - value: ZA + label: + en_US: South Africa + zh_Hans: 南非 + pt_BR: South Africa + - value: ES + label: + en_US: Spain + zh_Hans: 西班牙 + pt_BR: Spain + - value: SE + label: + en_US: Sweden + zh_Hans: 瑞典 + pt_BR: Sweden + - value: CH + label: + en_US: Switzerland + zh_Hans: 瑞士 + pt_BR: Switzerland + - value: TW + label: + en_US: Taiwan + zh_Hans: 台湾 + pt_BR: Taiwan + - value: TH + label: + en_US: Thailand + zh_Hans: 泰国 + pt_BR: Thailand + - value: TR + label: + en_US: Turkey + zh_Hans: 土耳其 + pt_BR: Turkey + - value: GB + label: + en_US: United Kingdom + zh_Hans: 英国 + pt_BR: United Kingdom + - value: US + label: + en_US: United States + zh_Hans: 美国 + pt_BR: United States + - name: hl + type: select + label: + en_US: Language + zh_Hans: 语言 + human_description: + en_US: Defines the interface language of the search. Default is "en". 
+ zh_Hans: 定义搜索的界面语言。默认为“en”。 + required: false + default: en + form: form + options: + - value: ar + label: + en_US: Arabic + zh_Hans: 阿拉伯语 + - value: bg + label: + en_US: Bulgarian + zh_Hans: 保加利亚语 + - value: ca + label: + en_US: Catalan + zh_Hans: 加泰罗尼亚语 + - value: zh-cn + label: + en_US: Chinese (Simplified) + zh_Hans: 中文(简体) + - value: zh-tw + label: + en_US: Chinese (Traditional) + zh_Hans: 中文(繁体) + - value: cs + label: + en_US: Czech + zh_Hans: 捷克语 + - value: da + label: + en_US: Danish + zh_Hans: 丹麦语 + - value: nl + label: + en_US: Dutch + zh_Hans: 荷兰语 + - value: en + label: + en_US: English + zh_Hans: 英语 + - value: et + label: + en_US: Estonian + zh_Hans: 爱沙尼亚语 + - value: fi + label: + en_US: Finnish + zh_Hans: 芬兰语 + - value: fr + label: + en_US: French + zh_Hans: 法语 + - value: de + label: + en_US: German + zh_Hans: 德语 + - value: el + label: + en_US: Greek + zh_Hans: 希腊语 + - value: iw + label: + en_US: Hebrew + zh_Hans: 希伯来语 + - value: hi + label: + en_US: Hindi + zh_Hans: 印地语 + - value: hu + label: + en_US: Hungarian + zh_Hans: 匈牙利语 + - value: id + label: + en_US: Indonesian + zh_Hans: 印尼语 + - value: it + label: + en_US: Italian + zh_Hans: 意大利语 + - value: jp + label: + en_US: Japanese + zh_Hans: 日语 + - value: kn + label: + en_US: Kannada + zh_Hans: 卡纳达语 + - value: ko + label: + en_US: Korean + zh_Hans: 韩语 + - value: lv + label: + en_US: Latvian + zh_Hans: 拉脱维亚语 + - value: lt + label: + en_US: Lithuanian + zh_Hans: 立陶宛语 + - value: my + label: + en_US: Malay + zh_Hans: 马来语 + - value: ml + label: + en_US: Malayalam + zh_Hans: 马拉雅拉姆语 + - value: mr + label: + en_US: Marathi + zh_Hans: 马拉地语 + - value: "no" + label: + en_US: Norwegian + zh_Hans: 挪威语 + - value: pl + label: + en_US: Polish + zh_Hans: 波兰语 + - value: pt-br + label: + en_US: Portuguese (Brazil) + zh_Hans: 葡萄牙语(巴西) + - value: pt-pt + label: + en_US: Portuguese (Portugal) + zh_Hans: 葡萄牙语(葡萄牙) + - value: pa + label: + en_US: Punjabi + zh_Hans: 旁遮普语 + - value: ro + label: + en_US: Romanian + zh_Hans: 罗马尼亚语 + - value: ru + label: + en_US: Russian + zh_Hans: 俄语 + - value: sr + label: + en_US: Serbian + zh_Hans: 塞尔维亚语 + - value: sk + label: + en_US: Slovak + zh_Hans: 斯洛伐克语 + - value: sl + label: + en_US: Slovenian + zh_Hans: 斯洛文尼亚语 + - value: es + label: + en_US: Spanish + zh_Hans: 西班牙语 + - value: sv + label: + en_US: Swedish + zh_Hans: 瑞典语 + - value: ta + label: + en_US: Tamil + zh_Hans: 泰米尔语 + - value: te + label: + en_US: Telugu + zh_Hans: 泰卢固语 + - value: th + label: + en_US: Thai + zh_Hans: 泰语 + - value: tr + label: + en_US: Turkish + zh_Hans: 土耳其语 + - value: uk + label: + en_US: Ukrainian + zh_Hans: 乌克兰语 + - value: vi + label: + en_US: Vietnamese + zh_Hans: 越南语 + - name: google_domain + type: string + required: false + label: + en_US: google_domain + zh_Hans: google_domain + human_description: + en_US: Defines the Google domain of the search. Default is "google.com". + zh_Hans: 定义搜索的 Google 域。默认为“google.com”。 + llm_description: Defines Google domain in which you want to search. + form: llm + - name: num + type: number + required: false + label: + en_US: num + zh_Hans: num + human_description: + en_US: Specifies the number of results to display per page. Default is 10. Max number - 100, min - 1. + zh_Hans: 指定每页显示的结果数。默认值为 10。最大数量 - 100,最小数量 - 1。 + pt_BR: Specifies the number of results to display per page. Default is 10. Max number - 100, min - 1. + llm_description: Specifies the num of results to display per page. 
+ form: llm diff --git a/api/core/tools/provider/builtin/searchapi/tools/youtube_transcripts.py b/api/core/tools/provider/builtin/searchapi/tools/youtube_transcripts.py new file mode 100644 index 0000000000..6345b33801 --- /dev/null +++ b/api/core/tools/provider/builtin/searchapi/tools/youtube_transcripts.py @@ -0,0 +1,72 @@ +from typing import Any, Union + +import requests + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + +SEARCH_API_URL = "https://www.searchapi.io/api/v1/search" + +class SearchAPI: + """ + SearchAPI tool provider. + """ + + def __init__(self, api_key: str) -> None: + """Initialize SearchAPI tool provider.""" + self.searchapi_api_key = api_key + + def run(self, video_id: str, language: str, **kwargs: Any) -> str: + """Run video_id through SearchAPI and parse result.""" + return self._process_response(self.results(video_id, language, **kwargs)) + + def results(self, video_id: str, language: str, **kwargs: Any) -> dict: + """Run video_id through SearchAPI and return the raw result.""" + params = self.get_params(video_id, language, **kwargs) + response = requests.get( + url=SEARCH_API_URL, + params=params, + headers={"Authorization": f"Bearer {self.searchapi_api_key}"}, + ) + response.raise_for_status() + return response.json() + + def get_params(self, video_id: str, language: str, **kwargs: Any) -> dict[str, str]: + """Get parameters for SearchAPI.""" + return { + "engine": "youtube_transcripts", + "video_id": video_id, + "lang": language if language else "en", + **{key: value for key, value in kwargs.items() if value not in [None, ""]}, + } + + @staticmethod + def _process_response(res: dict) -> str: + """Process response from SearchAPI.""" + if "error" in res.keys(): + raise ValueError(f"Got error from SearchApi: {res['error']}") + + toret = "" + if "transcripts" in res.keys() and "text" in res["transcripts"][0].keys(): + for item in res["transcripts"]: + toret += item["text"] + " " + if toret == "": + toret = "No good search result found" + + return toret + +class YoutubeTranscriptsTool(BuiltinTool): + def _invoke(self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + """ + Invoke the SearchApi tool. + """ + video_id = tool_parameters['video_id'] + language = tool_parameters.get('language', "en") + + api_key = self.runtime.credentials['searchapi_api_key'] + result = SearchAPI(api_key).run(video_id, language=language) + + return self.create_text_message(text=result) diff --git a/api/core/tools/provider/builtin/searchapi/tools/youtube_transcripts.yaml b/api/core/tools/provider/builtin/searchapi/tools/youtube_transcripts.yaml new file mode 100644 index 0000000000..8bdcd6bb93 --- /dev/null +++ b/api/core/tools/provider/builtin/searchapi/tools/youtube_transcripts.yaml @@ -0,0 +1,34 @@ +identity: + name: youtube_transcripts_api + author: SearchApi + label: + en_US: YouTube Transcripts API + zh_Hans: YouTube 转录 API +description: + human: + en_US: A tool to retrieve transcripts from the specific YouTube video. + zh_Hans: 一种从特定 YouTube 视频检索文字记录的工具。 + llm: A tool to retrieve transcripts from the specific YouTube video. +parameters: + - name: video_id + type: string + required: true + label: + en_US: video_id + zh_Hans: 视频ID + human_description: + en_US: Used to define the video you want to search. You can find the video ID in the YouTube page URL. For example - https://www.youtube.com/watch?v=video_id.
+      zh_Hans: 用于定义要搜索的视频。您可以在 URL 中显示的 YouTube 页面中找到视频 ID。例如 - https://www.youtube.com/watch?v=video_id。
+    llm_description: Used to define the video you want to search.
+    form: llm
+  - name: language
+    type: string
+    required: false
+    label:
+      en_US: language
+      zh_Hans: 语言
+    human_description:
+      en_US: Used to set the language for transcripts. The default value is "en". You can find all supported languages in SearchApi documentation.
+      zh_Hans: 用于设置文字记录的语言。默认值为“en”。您可以在 SearchApi 文档中找到所有支持的语言。
+    llm_description: Used to set the language for transcripts.
+    form: llm
diff --git a/api/core/tools/provider/builtin/searxng/searxng.py b/api/core/tools/provider/builtin/searxng/searxng.py
index 8046056093..24b94b5ca4 100644
--- a/api/core/tools/provider/builtin/searxng/searxng.py
+++ b/api/core/tools/provider/builtin/searxng/searxng.py
@@ -9,7 +9,7 @@ class SearXNGProvider(BuiltinToolProviderController):
     def _validate_credentials(self, credentials: dict[str, Any]) -> None:
         try:
             SearXNGSearchTool().fork_tool_runtime(
-                meta={
+                runtime={
                     "credentials": credentials,
                 }
             ).invoke(
@@ -22,4 +22,4 @@ class SearXNGProvider(BuiltinToolProviderController):
                 },
             )
         except Exception as e:
-            raise ToolProviderCredentialValidationError(str(e))
\ No newline at end of file
+            raise ToolProviderCredentialValidationError(str(e))
diff --git a/api/core/tools/provider/builtin/searxng/searxng.yaml b/api/core/tools/provider/builtin/searxng/searxng.yaml
index c8c713cf04..64bd428280 100644
--- a/api/core/tools/provider/builtin/searxng/searxng.yaml
+++ b/api/core/tools/provider/builtin/searxng/searxng.yaml
@@ -8,6 +8,9 @@ identity:
     en_US: A free internet metasearch engine.
     zh_Hans: 开源互联网元搜索引擎
   icon: icon.svg
+  tags:
+    - search
+    - productivity
 credentials_for_provider:
   searxng_base_url:
     type: secret-input
diff --git a/api/core/tools/provider/builtin/searxng/tools/searxng_search.py b/api/core/tools/provider/builtin/searxng/tools/searxng_search.py
index cbc5ab435a..50f04760a7 100644
--- a/api/core/tools/provider/builtin/searxng/tools/searxng_search.py
+++ b/api/core/tools/provider/builtin/searxng/tools/searxng_search.py
@@ -121,4 +121,5 @@ class SearXNGSearchTool(BuiltinTool):
                                query=query,
                                search_type=search_type,
                                result_type=result_type,
-                               topK=num_results)
\ No newline at end of file
+                               topK=num_results
+                               )
diff --git a/api/core/tools/provider/builtin/slack/slack.yaml b/api/core/tools/provider/builtin/slack/slack.yaml
index 7278793900..1070ffbf03 100644
--- a/api/core/tools/provider/builtin/slack/slack.yaml
+++ b/api/core/tools/provider/builtin/slack/slack.yaml
@@ -10,4 +10,7 @@ identity:
     zh_Hans: Slack Webhook
     pt_BR: Slack Webhook
   icon: icon.svg
+  tags:
+    - social
+    - productivity
 credentials_for_provider:
diff --git a/api/core/tools/provider/builtin/spark/spark.yaml b/api/core/tools/provider/builtin/spark/spark.yaml
index f2b9c89e96..fa1543443a 100644
--- a/api/core/tools/provider/builtin/spark/spark.yaml
+++ b/api/core/tools/provider/builtin/spark/spark.yaml
@@ -10,6 +10,8 @@ identity:
     zh_Hans: 讯飞星火平台工具
     pt_BR: Pacote de Ferramentas da Plataforma Spark
   icon: icon.svg
+  tags:
+    - image
 credentials_for_provider:
   APPID:
     type: secret-input
diff --git a/api/core/tools/provider/builtin/stability/stability.py b/api/core/tools/provider/builtin/stability/stability.py
index d00c3ecf00..b31d786178 100644
--- a/api/core/tools/provider/builtin/stability/stability.py
+++ b/api/core/tools/provider/builtin/stability/stability.py
@@ -12,4 +12,4 @@ class StabilityToolProvider(BuiltinToolProviderController, BaseStabilityAuthoriz
         """
         This method is
responsible for validating the credentials. """ - self.sd_validate_credentials(credentials) \ No newline at end of file + self.sd_validate_credentials(credentials) diff --git a/api/core/tools/provider/builtin/stability/stability.yaml b/api/core/tools/provider/builtin/stability/stability.yaml index d8369a4c03..c3e01c1e31 100644 --- a/api/core/tools/provider/builtin/stability/stability.yaml +++ b/api/core/tools/provider/builtin/stability/stability.yaml @@ -10,6 +10,8 @@ identity: zh_Hans: 通过生成式 AI 激活人类的潜力 pt_BR: Activating humanity's potential through generative AI icon: icon.svg + tags: + - image credentials_for_provider: api_key: type: secret-input diff --git a/api/core/tools/provider/builtin/stablediffusion/stablediffusion.py b/api/core/tools/provider/builtin/stablediffusion/stablediffusion.py index 5748e8d4e2..317d705f7c 100644 --- a/api/core/tools/provider/builtin/stablediffusion/stablediffusion.py +++ b/api/core/tools/provider/builtin/stablediffusion/stablediffusion.py @@ -9,9 +9,10 @@ class StableDiffusionProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: try: StableDiffusionTool().fork_tool_runtime( - meta={ + runtime={ "credentials": credentials, } ).validate_models() except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) + \ No newline at end of file diff --git a/api/core/tools/provider/builtin/stablediffusion/stablediffusion.yaml b/api/core/tools/provider/builtin/stablediffusion/stablediffusion.yaml index e1161da5bb..9b3c804f72 100644 --- a/api/core/tools/provider/builtin/stablediffusion/stablediffusion.yaml +++ b/api/core/tools/provider/builtin/stablediffusion/stablediffusion.yaml @@ -10,6 +10,8 @@ identity: zh_Hans: Stable Diffusion 是一个可以在本地部署的图片生成的工具。 pt_BR: Stable Diffusion is a tool for generating images which can be deployed locally. 
icon: icon.png + tags: + - image credentials_for_provider: base_url: type: secret-input diff --git a/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py b/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py index 25206da0bc..0c5ebc23ac 100644 --- a/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py +++ b/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py @@ -13,49 +13,76 @@ from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, from core.tools.errors import ToolProviderCredentialValidationError from core.tools.tool.builtin_tool import BuiltinTool +# All commented out parameters default to null DRAW_TEXT_OPTIONS = { + # Prompts "prompt": "", "negative_prompt": "", + # "styles": [], + # Seeds "seed": -1, "subseed": -1, "subseed_strength": 0, "seed_resize_from_h": -1, - 'sampler_index': 'DPM++ SDE Karras', "seed_resize_from_w": -1, + + # Samplers + # "sampler_name": "DPM++ 2M", + # "scheduler": "", + # "sampler_index": "Automatic", + + # Latent Space Options "batch_size": 1, "n_iter": 1, "steps": 10, "cfg_scale": 7, - "width": 1024, - "height": 1024, - "restore_faces": False, + "width": 512, + "height": 512, + # "restore_faces": True, + # "tiling": True, "do_not_save_samples": False, "do_not_save_grid": False, - "eta": 0, - "denoising_strength": 0, - "s_min_uncond": 0, - "s_churn": 0, - "s_tmax": 0, - "s_tmin": 0, - "s_noise": 0, + # "eta": 0, + # "denoising_strength": 0.75, + # "s_min_uncond": 0, + # "s_churn": 0, + # "s_tmax": 0, + # "s_tmin": 0, + # "s_noise": 0, "override_settings": {}, "override_settings_restore_afterwards": True, + # Refinement Options + "refiner_checkpoint": "", "refiner_switch_at": 0, "disable_extra_networks": False, - "comments": {}, + # "firstpass_image": "", + # "comments": "", + # High-Resolution Options "enable_hr": False, "firstphase_width": 0, "firstphase_height": 0, "hr_scale": 2, + # "hr_upscaler": "", "hr_second_pass_steps": 0, "hr_resize_x": 0, "hr_resize_y": 0, + # "hr_checkpoint_name": "", + # "hr_sampler_name": "", + # "hr_scheduler": "", "hr_prompt": "", "hr_negative_prompt": "", + # Task Options + # "force_task_id": "", + + # Script Options + # "script_name": "", "script_args": [], + # Output Options "send_images": True, "save_images": False, - "alwayson_scripts": {} + "alwayson_scripts": {}, + # "infotext": "", + } @@ -70,7 +97,7 @@ class StableDiffusionTool(BuiltinTool): if not base_url: return self.create_text_message('Please input base_url') - if 'model' in tool_parameters and tool_parameters['model']: + if tool_parameters.get('model'): self.runtime.credentials['model'] = tool_parameters['model'] model = self.runtime.credentials.get('model', None) @@ -88,60 +115,15 @@ class StableDiffusionTool(BuiltinTool): except Exception as e: raise ToolProviderCredentialValidationError('Failed to set model, please tell user to set model') - - # prompt - prompt = tool_parameters.get('prompt', '') - if not prompt: - return self.create_text_message('Please input prompt') - - # get negative prompt - negative_prompt = tool_parameters.get('negative_prompt', '') - - # get size - width = tool_parameters.get('width', 1024) - height = tool_parameters.get('height', 1024) - - # get steps - steps = tool_parameters.get('steps', 1) - - # get lora - lora = tool_parameters.get('lora', '') - - # get image id + # get image id and image variable image_id = tool_parameters.get('image_id', '') - if image_id.strip(): - image_variable = self.get_default_image_variable() - 
if image_variable: - image_binary = self.get_variable_file(image_variable.name) - if not image_binary: - return self.create_text_message('Image not found, please request user to generate image firstly.') - - # convert image to RGB - image = Image.open(io.BytesIO(image_binary)) - image = image.convert("RGB") - buffer = io.BytesIO() - image.save(buffer, format="PNG") - image_binary = buffer.getvalue() - image.close() + image_variable = self.get_default_image_variable() + # Return text2img if there's no image ID or no image variable + if not image_id or not image_variable: + return self.text2img(base_url=base_url,tool_parameters=tool_parameters) - return self.img2img(base_url=base_url, - lora=lora, - image_binary=image_binary, - prompt=prompt, - negative_prompt=negative_prompt, - width=width, - height=height, - steps=steps, - model=model) - - return self.text2img(base_url=base_url, - lora=lora, - prompt=prompt, - negative_prompt=negative_prompt, - width=width, - height=height, - steps=steps, - model=model) + # Proceed with image-to-image generation + return self.img2img(base_url=base_url,tool_parameters=tool_parameters) def validate_models(self) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ @@ -197,35 +179,67 @@ class StableDiffusionTool(BuiltinTool): except Exception as e: return [] - def img2img(self, base_url: str, lora: str, image_binary: bytes, - prompt: str, negative_prompt: str, - width: int, height: int, steps: int, model: str) \ + def img2img(self, base_url: str, tool_parameters: dict[str, Any]) \ -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ generate image """ - draw_options = { - "init_images": [b64encode(image_binary).decode('utf-8')], - "prompt": "", - "negative_prompt": negative_prompt, - "denoising_strength": 0.9, - "width": width, - "height": height, - "cfg_scale": 7, - "sampler_name": "Euler a", - "restore_faces": False, - "steps": steps, - "script_args": ["outpainting mk2"], - "override_settings": {"sd_model_checkpoint": model} - } + # Fetch the binary data of the image + image_variable = self.get_default_image_variable() + image_binary = self.get_variable_file(image_variable.name) + if not image_binary: + return self.create_text_message('Image not found, please request user to generate image firstly.') + + # Convert image to RGB and save as PNG + try: + with Image.open(io.BytesIO(image_binary)) as image: + with io.BytesIO() as buffer: + image.convert("RGB").save(buffer, format="PNG") + image_binary = buffer.getvalue() + except Exception as e: + return self.create_text_message(f"Failed to process the image: {str(e)}") + + # copy draw options + draw_options = deepcopy(DRAW_TEXT_OPTIONS) + # set image options + model = tool_parameters.get('model', '') + draw_options_image = { + "init_images": [b64encode(image_binary).decode('utf-8')], + "denoising_strength": 0.9, + "restore_faces": False, + "script_args": [], + "override_settings": {"sd_model_checkpoint": model}, + "resize_mode":0, + "image_cfg_scale": 0, + # "mask": None, + "mask_blur_x": 4, + "mask_blur_y": 4, + "mask_blur": 0, + "mask_round": True, + "inpainting_fill": 0, + "inpaint_full_res": True, + "inpaint_full_res_padding": 0, + "inpainting_mask_invert": 0, + "initial_noise_multiplier": 0, + # "latent_mask": None, + "include_init_images": True, + } + # update key and values + draw_options.update(draw_options_image) + draw_options.update(tool_parameters) + + # get prompt lora model + prompt = tool_parameters.get('prompt', '') + lora = tool_parameters.get('lora', '') + model = 
tool_parameters.get('model', '') if lora: draw_options['prompt'] = f'{lora},{prompt}' else: draw_options['prompt'] = prompt try: - url = str(URL(base_url) / 'sdapi' / 'v1' / 'img2img') + url = str(URL(base_url) / 'sdapi' / 'v1' / 'img2img') response = post(url, data=json.dumps(draw_options), timeout=120) if response.status_code != 200: return self.create_text_message('Failed to generate image') @@ -239,24 +253,24 @@ class StableDiffusionTool(BuiltinTool): except Exception as e: return self.create_text_message('Failed to generate image') - def text2img(self, base_url: str, lora: str, prompt: str, negative_prompt: str, width: int, height: int, steps: int, model: str) \ + def text2img(self, base_url: str, tool_parameters: dict[str, Any]) \ -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ generate image """ # copy draw options draw_options = deepcopy(DRAW_TEXT_OPTIONS) - + draw_options.update(tool_parameters) + # get prompt lora model + prompt = tool_parameters.get('prompt', '') + lora = tool_parameters.get('lora', '') + model = tool_parameters.get('model', '') if lora: draw_options['prompt'] = f'{lora},{prompt}' else: draw_options['prompt'] = prompt - - draw_options['width'] = width - draw_options['height'] = height - draw_options['steps'] = steps - draw_options['negative_prompt'] = negative_prompt draw_options['override_settings']['sd_model_checkpoint'] = model + try: url = str(URL(base_url) / 'sdapi' / 'v1' / 'txt2img') diff --git a/api/core/tools/provider/builtin/stackexchange/stackexchange.py b/api/core/tools/provider/builtin/stackexchange/stackexchange.py index fab543c580..de64c84997 100644 --- a/api/core/tools/provider/builtin/stackexchange/stackexchange.py +++ b/api/core/tools/provider/builtin/stackexchange/stackexchange.py @@ -7,7 +7,7 @@ class StackExchangeProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict) -> None: try: SearchStackExQuestionsTool().fork_tool_runtime( - meta={ + runtime={ "credentials": credentials, } ).invoke( @@ -22,4 +22,5 @@ class StackExchangeProvider(BuiltinToolProviderController): }, ) except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) + \ No newline at end of file diff --git a/api/core/tools/provider/builtin/stackexchange/stackexchange.yaml b/api/core/tools/provider/builtin/stackexchange/stackexchange.yaml index b2c8fe296f..ad64b9595a 100644 --- a/api/core/tools/provider/builtin/stackexchange/stackexchange.yaml +++ b/api/core/tools/provider/builtin/stackexchange/stackexchange.yaml @@ -8,3 +8,6 @@ identity: en_US: Access questions and answers from the Stack Exchange and its sub-sites. 
zh_Hans: 从Stack Exchange和其子论坛获取问题和答案。 icon: icon.svg + tags: + - search + - utilities diff --git a/api/core/tools/provider/builtin/tavily/tavily.py b/api/core/tools/provider/builtin/tavily/tavily.py index 575d9268b9..e376d99d6b 100644 --- a/api/core/tools/provider/builtin/tavily/tavily.py +++ b/api/core/tools/provider/builtin/tavily/tavily.py @@ -9,7 +9,7 @@ class TavilyProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: try: TavilySearchTool().fork_tool_runtime( - meta={ + runtime={ "credentials": credentials, } ).invoke( @@ -26,4 +26,5 @@ class TavilyProvider(BuiltinToolProviderController): }, ) except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) + \ No newline at end of file diff --git a/api/core/tools/provider/builtin/tavily/tavily.yaml b/api/core/tools/provider/builtin/tavily/tavily.yaml index 50826e37b3..7b25a81848 100644 --- a/api/core/tools/provider/builtin/tavily/tavily.yaml +++ b/api/core/tools/provider/builtin/tavily/tavily.yaml @@ -10,6 +10,8 @@ identity: zh_Hans: Tavily pt_BR: Tavily icon: icon.png + tags: + - search credentials_for_provider: tavily_api_key: type: secret-input diff --git a/api/core/tools/provider/builtin/time/time.py b/api/core/tools/provider/builtin/time/time.py index 0d3285f495..833ae194ef 100644 --- a/api/core/tools/provider/builtin/time/time.py +++ b/api/core/tools/provider/builtin/time/time.py @@ -13,4 +13,5 @@ class WikiPediaProvider(BuiltinToolProviderController): tool_parameters={}, ) except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) + \ No newline at end of file diff --git a/api/core/tools/provider/builtin/time/time.yaml b/api/core/tools/provider/builtin/time/time.yaml index 8ea67cded5..1278939df5 100644 --- a/api/core/tools/provider/builtin/time/time.yaml +++ b/api/core/tools/provider/builtin/time/time.yaml @@ -10,4 +10,6 @@ identity: zh_Hans: 一个用于获取当前时间的工具。 pt_BR: A tool for getting the current time. 
icon: icon.svg + tags: + - utilities credentials_for_provider: diff --git a/api/core/tools/provider/builtin/time/tools/current_time.py b/api/core/tools/provider/builtin/time/tools/current_time.py index 8722274565..90c01665e6 100644 --- a/api/core/tools/provider/builtin/time/tools/current_time.py +++ b/api/core/tools/provider/builtin/time/tools/current_time.py @@ -17,11 +17,12 @@ class CurrentTimeTool(BuiltinTool): """ # get timezone tz = tool_parameters.get('timezone', 'UTC') + fm = tool_parameters.get('format') or '%Y-%m-%d %H:%M:%S %Z' if tz == 'UTC': - return self.create_text_message(f'{datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S %Z")}') - + return self.create_text_message(f'{datetime.now(timezone.utc).strftime(fm)}') + try: tz = pytz_timezone(tz) except: return self.create_text_message(f'Invalid timezone: {tz}') - return self.create_text_message(f'{datetime.now(tz).strftime("%Y-%m-%d %H:%M:%S %Z")}') \ No newline at end of file + return self.create_text_message(f'{datetime.now(tz).strftime(fm)}') \ No newline at end of file diff --git a/api/core/tools/provider/builtin/time/tools/current_time.yaml b/api/core/tools/provider/builtin/time/tools/current_time.yaml index f0b5f53bd8..dbdd39e223 100644 --- a/api/core/tools/provider/builtin/time/tools/current_time.yaml +++ b/api/core/tools/provider/builtin/time/tools/current_time.yaml @@ -12,6 +12,19 @@ description: pt_BR: A tool for getting the current time. llm: A tool for getting the current time. parameters: + - name: format + type: string + required: false + label: + en_US: Format + zh_Hans: 格式 + pt_BR: Format + human_description: + en_US: Time format in strftime standard. + zh_Hans: strftime 标准的时间格式。 + pt_BR: Time format in strftime standard. + form: form + default: "%Y-%m-%d %H:%M:%S" - name: timezone type: select required: false @@ -46,6 +59,11 @@ parameters: en_US: America/Chicago zh_Hans: 美洲/芝加哥 pt_BR: America/Chicago + - value: America/Sao_Paulo + label: + en_US: America/Sao_Paulo + zh_Hans: 美洲/圣保罗 + pt_BR: América/São Paulo - value: Asia/Shanghai label: en_US: Asia/Shanghai diff --git a/api/core/tools/provider/builtin/trello/trello.py b/api/core/tools/provider/builtin/trello/trello.py index d27115d246..84ecd20803 100644 --- a/api/core/tools/provider/builtin/trello/trello.py +++ b/api/core/tools/provider/builtin/trello/trello.py @@ -31,4 +31,5 @@ class TrelloProvider(BuiltinToolProviderController): raise ToolProviderCredentialValidationError("Error validating Trello credentials") except requests.exceptions.RequestException as e: # Handle other exceptions, such as connection errors - raise ToolProviderCredentialValidationError("Error validating Trello credentials") \ No newline at end of file + raise ToolProviderCredentialValidationError("Error validating Trello credentials") + \ No newline at end of file diff --git a/api/core/tools/provider/builtin/trello/trello.yaml b/api/core/tools/provider/builtin/trello/trello.yaml index e1228c16be..49c9f4f9a1 100644 --- a/api/core/tools/provider/builtin/trello/trello.yaml +++ b/api/core/tools/provider/builtin/trello/trello.yaml @@ -10,6 +10,8 @@ identity: zh_Hans: "Trello: 一个用于组织工作和生活的视觉工具。" pt_BR: "Trello: Uma ferramenta visual para organizar seu trabalho e vida." 
icon: icon.svg + tags: + - productivity credentials_for_provider: trello_api_key: type: secret-input diff --git a/api/core/tools/provider/builtin/twilio/tools/send_message.py b/api/core/tools/provider/builtin/twilio/tools/send_message.py index 24502a3565..48e3876a4a 100644 --- a/api/core/tools/provider/builtin/twilio/tools/send_message.py +++ b/api/core/tools/provider/builtin/twilio/tools/send_message.py @@ -30,7 +30,7 @@ class TwilioAPIWrapper(BaseModel): Twilio also work here. You cannot, for example, spoof messages from a private cell phone number. If you are using `messaging_service_sid`, this parameter must be empty. - """ # noqa: E501 + """ @validator("client", pre=True, always=True) def set_validator(cls, values: dict) -> dict: @@ -60,7 +60,7 @@ class TwilioAPIWrapper(BaseModel): SMS/MMS or [Channel user address](https://www.twilio.com/docs/sms/channels#channel-addresses) for other 3rd-party channels. - """ # noqa: E501 + """ message = self.client.messages.create(to, from_=self.from_number, body=body) return message.sid diff --git a/api/core/tools/provider/builtin/twilio/twilio.py b/api/core/tools/provider/builtin/twilio/twilio.py index 7984d7b3b1..06f276053a 100644 --- a/api/core/tools/provider/builtin/twilio/twilio.py +++ b/api/core/tools/provider/builtin/twilio/twilio.py @@ -26,4 +26,5 @@ class TwilioProvider(BuiltinToolProviderController): except KeyError as e: raise ToolProviderCredentialValidationError(f"Missing required credential: {e}") from e except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) + \ No newline at end of file diff --git a/api/core/tools/provider/builtin/twilio/twilio.yaml b/api/core/tools/provider/builtin/twilio/twilio.yaml index b5143c8736..21867c1da5 100644 --- a/api/core/tools/provider/builtin/twilio/twilio.yaml +++ b/api/core/tools/provider/builtin/twilio/twilio.yaml @@ -10,6 +10,8 @@ identity: zh_Hans: 通过SMS或Twilio消息通道发送消息。 pt_BR: Send messages through SMS or Twilio Messaging Channels. 
icon: icon.svg + tags: + - social credentials_for_provider: account_sid: type: secret-input diff --git a/api/core/tools/provider/builtin/vanna/_assets/icon.png b/api/core/tools/provider/builtin/vanna/_assets/icon.png new file mode 100644 index 0000000000..3a9011b54d Binary files /dev/null and b/api/core/tools/provider/builtin/vanna/_assets/icon.png differ diff --git a/api/core/tools/provider/builtin/vanna/tools/vanna.py b/api/core/tools/provider/builtin/vanna/tools/vanna.py new file mode 100644 index 0000000000..a6efb0f79a --- /dev/null +++ b/api/core/tools/provider/builtin/vanna/tools/vanna.py @@ -0,0 +1,129 @@ +from typing import Any, Union + +from vanna.remote import VannaDefault + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.errors import ToolProviderCredentialValidationError +from core.tools.tool.builtin_tool import BuiltinTool + + +class VannaTool(BuiltinTool): + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + """ + invoke tools + """ + api_key = self.runtime.credentials.get("api_key", None) + if not api_key: + raise ToolProviderCredentialValidationError("Please input api key") + + model = tool_parameters.get("model", "") + if not model: + return self.create_text_message("Please input RAG model") + + prompt = tool_parameters.get("prompt", "") + if not prompt: + return self.create_text_message("Please input prompt") + + url = tool_parameters.get("url", "") + if not url: + return self.create_text_message("Please input URL/Host/DSN") + + db_name = tool_parameters.get("db_name", "") + username = tool_parameters.get("username", "") + password = tool_parameters.get("password", "") + port = tool_parameters.get("port", 0) + + vn = VannaDefault(model=model, api_key=api_key) + + db_type = tool_parameters.get("db_type", "") + if db_type in ["Postgres", "MySQL", "Hive", "ClickHouse"]: + if not db_name: + return self.create_text_message("Please input database name") + if not username: + return self.create_text_message("Please input username") + if port < 1: + return self.create_text_message("Please input port") + + schema_sql = "SELECT * FROM INFORMATION_SCHEMA.COLUMNS" + match db_type: + case "SQLite": + schema_sql = "SELECT type, sql FROM sqlite_master WHERE sql is not null" + vn.connect_to_sqlite(url) + case "Postgres": + vn.connect_to_postgres(host=url, dbname=db_name, user=username, password=password, port=port) + case "DuckDB": + vn.connect_to_duckdb(url=url) + case "SQLServer": + vn.connect_to_mssql(url) + case "MySQL": + vn.connect_to_mysql(host=url, dbname=db_name, user=username, password=password, port=port) + case "Oracle": + vn.connect_to_oracle(user=username, password=password, dsn=url) + case "Hive": + vn.connect_to_hive(host=url, dbname=db_name, user=username, password=password, port=port) + case "ClickHouse": + vn.connect_to_clickhouse(host=url, dbname=db_name, user=username, password=password, port=port) + + enable_training = tool_parameters.get("enable_training", False) + reset_training_data = tool_parameters.get("reset_training_data", False) + if enable_training: + if reset_training_data: + existing_training_data = vn.get_training_data() + if len(existing_training_data) > 0: + for _, training_data in existing_training_data.iterrows(): + vn.remove_training_data(training_data["id"]) + + ddl = tool_parameters.get("ddl", "") + question = tool_parameters.get("question", "") + sql = tool_parameters.get("sql", "") + memos = tool_parameters.get("memos", "") + 
training_metadata = tool_parameters.get("training_metadata", False)
+
+            if training_metadata:
+                if db_type == "SQLite":
+                    df_ddl = vn.run_sql(schema_sql)
+                    for ddl_statement in df_ddl["sql"].to_list():
+                        vn.train(ddl=ddl_statement)
+                else:
+                    df_information_schema = vn.run_sql(schema_sql)
+                    plan = vn.get_training_plan_generic(df_information_schema)
+                    vn.train(plan=plan)
+
+            if ddl:
+                vn.train(ddl=ddl)
+
+            if sql:
+                if question:
+                    vn.train(question=question, sql=sql)
+                else:
+                    vn.train(sql=sql)
+            if memos:
+                vn.train(documentation=memos)
+
+        #########################################################################################
+        # Due to CVE-2024-5565, we have to disable the chart generation feature.
+        # The Vanna library uses a prompt function to present the user with visualized results;
+        # it is possible to alter the prompt using prompt injection and run arbitrary Python code
+        # instead of the intended visualization code.
+        # Specifically, allowing external input to the library's "ask" method
+        # with "visualize" set to True (default behavior) leads to remote code execution.
+        # Affected versions: <= 0.5.5
+        #########################################################################################
+        generate_chart = False
+        # generate_chart = tool_parameters.get("generate_chart", True)
+        res = vn.ask(prompt, False, True, generate_chart)
+
+        result = []
+
+        if res is not None:
+            result.append(self.create_text_message(res[0]))
+            if len(res) > 1 and res[1] is not None:
+                result.append(self.create_text_message(res[1].to_markdown()))
+            if len(res) > 2 and res[2] is not None:
+                result.append(
+                    self.create_blob_message(blob=res[2].to_image(format="svg"), meta={"mime_type": "image/svg+xml"})
+                )
+
+        return result
diff --git a/api/core/tools/provider/builtin/vanna/tools/vanna.yaml b/api/core/tools/provider/builtin/vanna/tools/vanna.yaml
new file mode 100644
index 0000000000..ae2eae94c4
--- /dev/null
+++ b/api/core/tools/provider/builtin/vanna/tools/vanna.yaml
@@ -0,0 +1,213 @@
+identity:
+  name: vanna
+  author: QCTC
+  label:
+    en_US: Vanna.AI
+    zh_Hans: Vanna.AI
+description:
+  human:
+    en_US: The fastest way to get actionable insights from your database just by asking questions.
+    zh_Hans: 一个基于大模型和RAG的Text2SQL工具。
+  llm: A tool for converting text to SQL.
+parameters:
+  - name: prompt
+    type: string
+    required: true
+    label:
+      en_US: Prompt
+      zh_Hans: 提示词
+      pt_BR: Prompt
+    human_description:
+      en_US: Used for generating SQL
+      zh_Hans: 用于生成SQL
+    llm_description: Keywords for generating SQL
+    form: llm
+  - name: model
+    type: string
+    required: true
+    label:
+      en_US: RAG Model
+      zh_Hans: RAG模型
+    human_description:
+      en_US: RAG Model for your database DDL
+      zh_Hans: 存储数据库训练数据的RAG模型
+    llm_description: RAG Model for generating SQL
+    form: form
+  - name: db_type
+    type: select
+    required: true
+    options:
+      - value: SQLite
+        label:
+          en_US: SQLite
+          zh_Hans: SQLite
+      - value: Postgres
+        label:
+          en_US: Postgres
+          zh_Hans: Postgres
+      - value: DuckDB
+        label:
+          en_US: DuckDB
+          zh_Hans: DuckDB
+      - value: SQLServer
+        label:
+          en_US: Microsoft SQL Server
+          zh_Hans: 微软 SQL Server
+      - value: MySQL
+        label:
+          en_US: MySQL
+          zh_Hans: MySQL
+      - value: Oracle
+        label:
+          en_US: Oracle
+          zh_Hans: Oracle
+      - value: Hive
+        label:
+          en_US: Hive
+          zh_Hans: Hive
+      - value: ClickHouse
+        label:
+          en_US: ClickHouse
+          zh_Hans: ClickHouse
+    default: SQLite
+    label:
+      en_US: DB Type
+      zh_Hans: 数据库类型
+    human_description:
+      en_US: Database type.
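+      # Per the connection checks in tools/vanna.py, the Postgres, MySQL, Hive
+      # and ClickHouse types also require db_name, username and port to be set.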
+ zh_Hans: 选择要链接的数据库类型。 + form: form + - name: url + type: string + required: true + label: + en_US: URL/Host/DSN + zh_Hans: URL/Host/DSN + human_description: + en_US: Please input depending on DB type, visit https://vanna.ai/docs/ for more specification + zh_Hans: 请根据数据库类型,填入对应值,详情参考https://vanna.ai/docs/ + form: form + - name: db_name + type: string + required: false + label: + en_US: DB name + zh_Hans: 数据库名 + human_description: + en_US: Database name + zh_Hans: 数据库名 + form: form + - name: username + type: string + required: false + label: + en_US: Username + zh_Hans: 用户名 + human_description: + en_US: Username + zh_Hans: 用户名 + form: form + - name: password + type: secret-input + required: false + label: + en_US: Password + zh_Hans: 密码 + human_description: + en_US: Password + zh_Hans: 密码 + form: form + - name: port + type: number + required: false + label: + en_US: Port + zh_Hans: 端口 + human_description: + en_US: Port + zh_Hans: 端口 + form: form + - name: ddl + type: string + required: false + label: + en_US: Training DDL + zh_Hans: 训练DDL + human_description: + en_US: DDL statements for training data + zh_Hans: 用于训练RAG Model的建表语句 + form: form + - name: question + type: string + required: false + label: + en_US: Training Question + zh_Hans: 训练问题 + human_description: + en_US: Question-SQL Pairs + zh_Hans: Question-SQL中的问题 + form: form + - name: sql + type: string + required: false + label: + en_US: Training SQL + zh_Hans: 训练SQL + human_description: + en_US: SQL queries to your training data + zh_Hans: 用于训练RAG Model的SQL语句 + form: form + - name: memos + type: string + required: false + label: + en_US: Training Memos + zh_Hans: 训练说明 + human_description: + en_US: Sometimes you may want to add documentation about your business terminology or definitions + zh_Hans: 添加更多关于数据库的业务说明 + form: form + - name: enable_training + type: boolean + required: false + default: false + label: + en_US: Training Data + zh_Hans: 训练数据 + human_description: + en_US: You only need to train once. 
Do not train again unless you want to add more training data
+      zh_Hans: 训练数据无更新时,训练一次即可
+    form: form
+  - name: reset_training_data
+    type: boolean
+    required: false
+    default: false
+    label:
+      en_US: Reset Training Data
+      zh_Hans: 重置训练数据
+    human_description:
+      en_US: Remove all training data in the current RAG Model
+      zh_Hans: 删除当前RAG Model中的所有训练数据
+    form: form
+  - name: training_metadata
+    type: boolean
+    required: false
+    default: false
+    label:
+      en_US: Training Metadata
+      zh_Hans: 训练元数据
+    human_description:
+      en_US: If enabled, it will attempt to train on the metadata of that database
+      zh_Hans: 是否自动从数据库获取元数据来训练
+    form: form
+  - name: generate_chart
+    type: boolean
+    required: false
+    default: true
+    label:
+      en_US: Generate Charts
+      zh_Hans: 生成图表
+    human_description:
+      en_US: Whether to generate charts
+      zh_Hans: 是否生成图表
+    form: form
diff --git a/api/core/tools/provider/builtin/vanna/vanna.py b/api/core/tools/provider/builtin/vanna/vanna.py
new file mode 100644
index 0000000000..ab1fd71df5
--- /dev/null
+++ b/api/core/tools/provider/builtin/vanna/vanna.py
@@ -0,0 +1,25 @@
+from typing import Any
+
+from core.tools.errors import ToolProviderCredentialValidationError
+from core.tools.provider.builtin.vanna.tools.vanna import VannaTool
+from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
+
+
+class VannaProvider(BuiltinToolProviderController):
+    def _validate_credentials(self, credentials: dict[str, Any]) -> None:
+        try:
+            VannaTool().fork_tool_runtime(
+                runtime={
+                    "credentials": credentials,
+                }
+            ).invoke(
+                user_id='',
+                tool_parameters={
+                    "model": "chinook",
+                    "db_type": "SQLite",
+                    "url": "https://vanna.ai/Chinook.sqlite",
+                    "prompt": "What are the top 10 customers by sales?"
+                },
+            )
+        except Exception as e:
+            raise ToolProviderCredentialValidationError(str(e))
\ No newline at end of file
diff --git a/api/core/tools/provider/builtin/vanna/vanna.yaml b/api/core/tools/provider/builtin/vanna/vanna.yaml
new file mode 100644
index 0000000000..b29fa103e1
--- /dev/null
+++ b/api/core/tools/provider/builtin/vanna/vanna.yaml
@@ -0,0 +1,25 @@
+identity:
+  author: QCTC
+  name: vanna
+  label:
+    en_US: Vanna.AI
+    zh_Hans: Vanna.AI
+  description:
+    en_US: The fastest way to get actionable insights from your database just by asking questions.
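+    # Note: _validate_credentials in vanna.py exercises the API key against the
+    # public Chinook sample database linked above.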
+ zh_Hans: 一个基于大模型和RAG的Text2SQL工具。 + icon: icon.png +credentials_for_provider: + api_key: + type: secret-input + required: true + label: + en_US: API key + zh_Hans: API key + placeholder: + en_US: Please input your API key + zh_Hans: 请输入你的 API key + pt_BR: Please input your API key + help: + en_US: Get your API key from Vanna.AI + zh_Hans: 从 Vanna.AI 获取你的 API key + url: https://vanna.ai/account/profile diff --git a/api/core/tools/provider/builtin/vectorizer/tools/vectorizer.py b/api/core/tools/provider/builtin/vectorizer/tools/vectorizer.py index df996b5283..c6ec198034 100644 --- a/api/core/tools/provider/builtin/vectorizer/tools/vectorizer.py +++ b/api/core/tools/provider/builtin/vectorizer/tools/vectorizer.py @@ -71,6 +71,4 @@ class VectorizerTool(BuiltinTool): options=[i.name for i in self.list_default_image_variables()] ) ] - - def is_tool_available(self) -> bool: - return len(self.list_default_image_variables()) > 0 \ No newline at end of file + \ No newline at end of file diff --git a/api/core/tools/provider/builtin/vectorizer/vectorizer.py b/api/core/tools/provider/builtin/vectorizer/vectorizer.py index 2b4d71e058..3f89a83500 100644 --- a/api/core/tools/provider/builtin/vectorizer/vectorizer.py +++ b/api/core/tools/provider/builtin/vectorizer/vectorizer.py @@ -9,7 +9,7 @@ class VectorizerProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: try: VectorizerTool().fork_tool_runtime( - meta={ + runtime={ "credentials": credentials, } ).invoke( @@ -20,4 +20,5 @@ class VectorizerProvider(BuiltinToolProviderController): }, ) except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) + \ No newline at end of file diff --git a/api/core/tools/provider/builtin/vectorizer/vectorizer.yaml b/api/core/tools/provider/builtin/vectorizer/vectorizer.yaml index 07a20380e9..1257f8d285 100644 --- a/api/core/tools/provider/builtin/vectorizer/vectorizer.yaml +++ b/api/core/tools/provider/builtin/vectorizer/vectorizer.yaml @@ -10,6 +10,9 @@ identity: zh_Hans: 一个将 PNG 和 JPG 图像快速轻松地转换为 SVG 矢量图的工具。 pt_BR: Convert your PNG and JPG images to SVG vectors quickly and easily. Fully automatically. Using AI. 
icon: icon.png + tags: + - productivity + - image credentials_for_provider: api_key_name: type: secret-input diff --git a/api/core/tools/provider/builtin/webscraper/tools/webscraper.py b/api/core/tools/provider/builtin/webscraper/tools/webscraper.py index 5e8c405b47..3d098e6768 100644 --- a/api/core/tools/provider/builtin/webscraper/tools/webscraper.py +++ b/api/core/tools/provider/builtin/webscraper/tools/webscraper.py @@ -7,9 +7,9 @@ from core.tools.tool.builtin_tool import BuiltinTool class WebscraperTool(BuiltinTool): def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ invoke tools """ @@ -18,12 +18,15 @@ class WebscraperTool(BuiltinTool): user_agent = tool_parameters.get('user_agent', '') if not url: return self.create_text_message('Please input url') - + # get webpage result = self.get_url(url, user_agent=user_agent) - # summarize and return - return self.create_text_message(self.summary(user_id=user_id, content=result)) + if tool_parameters.get('generate_summary'): + # summarize and return + return self.create_text_message(self.summary(user_id=user_id, content=result)) + else: + # return full webpage + return self.create_text_message(result) except Exception as e: raise ToolInvokeError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/webscraper/tools/webscraper.yaml b/api/core/tools/provider/builtin/webscraper/tools/webscraper.yaml index 5782dbb0c7..180cfec6fc 100644 --- a/api/core/tools/provider/builtin/webscraper/tools/webscraper.yaml +++ b/api/core/tools/provider/builtin/webscraper/tools/webscraper.yaml @@ -38,3 +38,23 @@ parameters: pt_BR: used for identifying the browser. form: form default: Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/100.0.1000.0 Safari/537.36 + - name: generate_summary + type: boolean + required: false + label: + en_US: Whether to generate summary + zh_Hans: 是否生成摘要 + human_description: + en_US: If true, the crawler will only return the page summary content. 
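+      # When false or unset, the tool returns the full scraped page content
+      # instead (see the generate_summary branch in tools/webscraper.py).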
+ zh_Hans: 如果启用,爬虫将仅返回页面摘要内容。 + form: form + options: + - value: true + label: + en_US: Yes + zh_Hans: 是 + - value: false + label: + en_US: No + zh_Hans: 否 + default: false diff --git a/api/core/tools/provider/builtin/webscraper/webscraper.py b/api/core/tools/provider/builtin/webscraper/webscraper.py index 8761493e3b..1e60fdb293 100644 --- a/api/core/tools/provider/builtin/webscraper/webscraper.py +++ b/api/core/tools/provider/builtin/webscraper/webscraper.py @@ -9,7 +9,7 @@ class WebscraperProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: try: WebscraperTool().fork_tool_runtime( - meta={ + runtime={ "credentials": credentials, } ).invoke( @@ -20,4 +20,5 @@ class WebscraperProvider(BuiltinToolProviderController): }, ) except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) + \ No newline at end of file diff --git a/api/core/tools/provider/builtin/webscraper/webscraper.yaml b/api/core/tools/provider/builtin/webscraper/webscraper.yaml index 056108f036..6c2eb97784 100644 --- a/api/core/tools/provider/builtin/webscraper/webscraper.yaml +++ b/api/core/tools/provider/builtin/webscraper/webscraper.yaml @@ -10,4 +10,6 @@ identity: zh_Hans: 一个用于抓取网页的工具。 pt_BR: Web Scrapper tool kit is used to scrape web icon: icon.svg + tags: + - productivity credentials_for_provider: diff --git a/api/core/tools/provider/builtin/wecom/wecom.py b/api/core/tools/provider/builtin/wecom/wecom.py index 7a2576b668..573f76ee56 100644 --- a/api/core/tools/provider/builtin/wecom/wecom.py +++ b/api/core/tools/provider/builtin/wecom/wecom.py @@ -5,4 +5,3 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl class WecomProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict) -> None: WecomGroupBotTool() - pass diff --git a/api/core/tools/provider/builtin/wecom/wecom.yaml b/api/core/tools/provider/builtin/wecom/wecom.yaml index 39d00032a0..a544055ba4 100644 --- a/api/core/tools/provider/builtin/wecom/wecom.yaml +++ b/api/core/tools/provider/builtin/wecom/wecom.yaml @@ -10,4 +10,6 @@ identity: zh_Hans: 企业微信群机器人 pt_BR: Wecom group bot icon: icon.png + tags: + - social credentials_for_provider: diff --git a/api/core/tools/provider/builtin/wikipedia/wikipedia.py b/api/core/tools/provider/builtin/wikipedia/wikipedia.py index 8d53852255..f8038714a5 100644 --- a/api/core/tools/provider/builtin/wikipedia/wikipedia.py +++ b/api/core/tools/provider/builtin/wikipedia/wikipedia.py @@ -7,7 +7,7 @@ class WikiPediaProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict) -> None: try: WikiPediaSearchTool().fork_tool_runtime( - meta={ + runtime={ "credentials": credentials, } ).invoke( @@ -17,4 +17,5 @@ class WikiPediaProvider(BuiltinToolProviderController): }, ) except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) + \ No newline at end of file diff --git a/api/core/tools/provider/builtin/wikipedia/wikipedia.yaml b/api/core/tools/provider/builtin/wikipedia/wikipedia.yaml index f4b5da8947..c582824022 100644 --- a/api/core/tools/provider/builtin/wikipedia/wikipedia.yaml +++ b/api/core/tools/provider/builtin/wikipedia/wikipedia.yaml @@ -10,4 +10,6 @@ identity: zh_Hans: 维基百科是一个由全世界的志愿者创建和编辑的免费在线百科全书。 pt_BR: Wikipedia is a free online encyclopedia, created and edited by volunteers 
around the world. icon: icon.svg + tags: + - social credentials_for_provider: diff --git a/api/core/tools/provider/builtin/wolframalpha/tools/wolframalpha.py b/api/core/tools/provider/builtin/wolframalpha/tools/wolframalpha.py index 7512710515..8cb9c10ddf 100644 --- a/api/core/tools/provider/builtin/wolframalpha/tools/wolframalpha.py +++ b/api/core/tools/provider/builtin/wolframalpha/tools/wolframalpha.py @@ -48,7 +48,7 @@ class WolframAlphaTool(BuiltinTool): if 'success' not in response_data['queryresult'] or response_data['queryresult']['success'] != True: query_result = response_data.get('queryresult', {}) - if 'error' in query_result and query_result['error']: + if query_result.get('error'): if 'msg' in query_result['error']: if query_result['error']['msg'] == 'Invalid appid': raise ToolProviderCredentialValidationError('Invalid appid') diff --git a/api/core/tools/provider/builtin/wolframalpha/wolframalpha.py b/api/core/tools/provider/builtin/wolframalpha/wolframalpha.py index 4e8213d90c..ef1aac7ff2 100644 --- a/api/core/tools/provider/builtin/wolframalpha/wolframalpha.py +++ b/api/core/tools/provider/builtin/wolframalpha/wolframalpha.py @@ -9,7 +9,7 @@ class GoogleProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: try: WolframAlphaTool().fork_tool_runtime( - meta={ + runtime={ "credentials": credentials, } ).invoke( @@ -19,4 +19,5 @@ class GoogleProvider(BuiltinToolProviderController): }, ) except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) + \ No newline at end of file diff --git a/api/core/tools/provider/builtin/wolframalpha/wolframalpha.yaml b/api/core/tools/provider/builtin/wolframalpha/wolframalpha.yaml index e4cc465712..91265eb3c0 100644 --- a/api/core/tools/provider/builtin/wolframalpha/wolframalpha.yaml +++ b/api/core/tools/provider/builtin/wolframalpha/wolframalpha.yaml @@ -10,6 +10,9 @@ identity: zh_Hans: WolframAlpha 是一个强大的计算知识引擎。 pt_BR: WolframAlpha is a powerful computational knowledge engine. icon: icon.svg + tags: + - productivity + - utilities credentials_for_provider: appid: type: secret-input diff --git a/api/core/tools/provider/builtin/yahoo/yahoo.py b/api/core/tools/provider/builtin/yahoo/yahoo.py index ade33ffb63..96dbc6c3d0 100644 --- a/api/core/tools/provider/builtin/yahoo/yahoo.py +++ b/api/core/tools/provider/builtin/yahoo/yahoo.py @@ -7,7 +7,7 @@ class YahooFinanceProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict) -> None: try: YahooFinanceSearchTickerTool().fork_tool_runtime( - meta={ + runtime={ "credentials": credentials, } ).invoke( @@ -17,4 +17,5 @@ class YahooFinanceProvider(BuiltinToolProviderController): }, ) except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) + \ No newline at end of file diff --git a/api/core/tools/provider/builtin/yahoo/yahoo.yaml b/api/core/tools/provider/builtin/yahoo/yahoo.yaml index b16517eaf9..f1e82952c0 100644 --- a/api/core/tools/provider/builtin/yahoo/yahoo.yaml +++ b/api/core/tools/provider/builtin/yahoo/yahoo.yaml @@ -10,4 +10,7 @@ identity: zh_Hans: 雅虎财经,获取并整理出最新的新闻、股票报价等一切你想要的财经信息。 pt_BR: Finance, and Yahoo! get the latest news, stock quotes, and interactive chart with Yahoo! 
icon: icon.png + tags: + - business + - finance credentials_for_provider: diff --git a/api/core/tools/provider/builtin/youtube/tools/videos.py b/api/core/tools/provider/builtin/youtube/tools/videos.py index 86160dfa6c..7a9b9fce4a 100644 --- a/api/core/tools/provider/builtin/youtube/tools/videos.py +++ b/api/core/tools/provider/builtin/youtube/tools/videos.py @@ -36,7 +36,7 @@ class YoutubeVideosAnalyticsTool(BuiltinTool): youtube = build('youtube', 'v3', developerKey=self.runtime.credentials['google_api_key']) # try to get channel id - search_results = youtube.search().list(q='mrbeast', type='channel', order='relevance', part='id').execute() + search_results = youtube.search().list(q=channel, type='channel', order='relevance', part='id').execute() channel_id = search_results['items'][0]['id']['channelId'] start_date, end_date = time_range diff --git a/api/core/tools/provider/builtin/youtube/youtube.py b/api/core/tools/provider/builtin/youtube/youtube.py index 8cca578c46..83a4fccb32 100644 --- a/api/core/tools/provider/builtin/youtube/youtube.py +++ b/api/core/tools/provider/builtin/youtube/youtube.py @@ -7,7 +7,7 @@ class YahooFinanceProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict) -> None: try: YoutubeVideosAnalyticsTool().fork_tool_runtime( - meta={ + runtime={ "credentials": credentials, } ).invoke( @@ -19,4 +19,5 @@ class YahooFinanceProvider(BuiltinToolProviderController): }, ) except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) + \ No newline at end of file diff --git a/api/core/tools/provider/builtin/youtube/youtube.yaml b/api/core/tools/provider/builtin/youtube/youtube.yaml index 2f83ae43ee..d6915b9a32 100644 --- a/api/core/tools/provider/builtin/youtube/youtube.yaml +++ b/api/core/tools/provider/builtin/youtube/youtube.yaml @@ -10,6 +10,8 @@ identity: zh_Hans: YouTube(油管)是全球最大的视频分享网站,用户可以在上面上传、观看和分享视频。 pt_BR: YouTube é o maior site de compartilhamento de vídeos do mundo, onde os usuários podem fazer upload, assistir e compartilhar vídeos. 
icon: icon.svg + tags: + - videos credentials_for_provider: google_api_key: type: secret-input diff --git a/api/core/tools/provider/builtin_tool_provider.py b/api/core/tools/provider/builtin_tool_provider.py index c2178cdd40..d076cb384f 100644 --- a/api/core/tools/provider/builtin_tool_provider.py +++ b/api/core/tools/provider/builtin_tool_provider.py @@ -2,25 +2,24 @@ from abc import abstractmethod from os import listdir, path from typing import Any -from yaml import FullLoader, load - +from core.helper.module_import_helper import load_single_subclass_from_source from core.tools.entities.tool_entities import ToolParameter, ToolProviderCredentials, ToolProviderType -from core.tools.entities.user_entities import UserToolProviderCredentials +from core.tools.entities.values import ToolLabelEnum, default_tool_label_dict from core.tools.errors import ( ToolNotFoundError, ToolParameterValidationError, - ToolProviderCredentialValidationError, ToolProviderNotFoundError, ) from core.tools.provider.tool_provider import ToolProviderController from core.tools.tool.builtin_tool import BuiltinTool from core.tools.tool.tool import Tool -from core.utils.module_import_helper import load_single_subclass_from_source +from core.tools.utils.tool_parameter_converter import ToolParameterConverter +from core.tools.utils.yaml_utils import load_yaml_file class BuiltinToolProviderController(ToolProviderController): def __init__(self, **data: Any) -> None: - if self.app_type == ToolProviderType.API_BASED or self.app_type == ToolProviderType.APP_BASED: + if self.provider_type == ToolProviderType.API or self.provider_type == ToolProviderType.APP: super().__init__(**data) return @@ -28,10 +27,9 @@ class BuiltinToolProviderController(ToolProviderController): provider = self.__class__.__module__.split('.')[-1] yaml_path = path.join(path.dirname(path.realpath(__file__)), 'builtin', provider, f'{provider}.yaml') try: - with open(yaml_path, 'rb') as f: - provider_yaml = load(f.read(), FullLoader) - except: - raise ToolProviderNotFoundError(f'can not load provider yaml for {provider}') + provider_yaml = load_yaml_file(yaml_path) + except Exception as e: + raise ToolProviderNotFoundError(f'can not load provider yaml for {provider}: {e}') if 'credentials_for_provider' in provider_yaml and provider_yaml['credentials_for_provider'] is not None: # set credentials name @@ -58,18 +56,18 @@ class BuiltinToolProviderController(ToolProviderController): tool_files = list(filter(lambda x: x.endswith(".yaml") and not x.startswith("__"), listdir(tool_path))) tools = [] for tool_file in tool_files: - with open(path.join(tool_path, tool_file), encoding='utf-8') as f: - # get tool name - tool_name = tool_file.split(".")[0] - tool = load(f.read(), FullLoader) - # get tool class, import the module - assistant_tool_class = load_single_subclass_from_source( - module_name=f'core.tools.provider.builtin.{provider}.tools.{tool_name}', - script_path=path.join(path.dirname(path.realpath(__file__)), - 'builtin', provider, 'tools', f'{tool_name}.py'), - parent_type=BuiltinTool) - tool["identity"]["provider"] = provider - tools.append(assistant_tool_class(**tool)) + # get tool name + tool_name = tool_file.split(".")[0] + tool = load_yaml_file(path.join(tool_path, tool_file)) + + # get tool class, import the module + assistant_tool_class = load_single_subclass_from_source( + module_name=f'core.tools.provider.builtin.{provider}.tools.{tool_name}', + script_path=path.join(path.dirname(path.realpath(__file__)), + 'builtin', provider, 'tools', 
f'{tool_name}.py'), + parent_type=BuiltinTool) + tool["identity"]["provider"] = provider + tools.append(assistant_tool_class(**tool)) self.tools = tools return tools @@ -84,15 +82,6 @@ class BuiltinToolProviderController(ToolProviderController): return {} return self.credentials_schema.copy() - - def user_get_credentials_schema(self) -> UserToolProviderCredentials: - """ - returns the credentials schema of the provider, this method is used for user - - :return: the credentials schema - """ - credentials = self.credentials_schema.copy() - return UserToolProviderCredentials(credentials=credentials) def get_tools(self) -> list[Tool]: """ @@ -131,7 +120,7 @@ class BuiltinToolProviderController(ToolProviderController): len(self.credentials_schema) != 0 @property - def app_type(self) -> ToolProviderType: + def provider_type(self) -> ToolProviderType: """ returns the type of the provider @@ -139,6 +128,22 @@ class BuiltinToolProviderController(ToolProviderController): """ return ToolProviderType.BUILT_IN + @property + def tool_labels(self) -> list[str]: + """ + returns the labels of the provider + + :return: labels of the provider + """ + label_enums = self._get_tool_labels() + return [default_tool_label_dict[label].name for label in label_enums] + + def _get_tool_labels(self) -> list[ToolLabelEnum]: + """ + returns the labels of the provider + """ + return self.identity.tags or [] + def validate_parameters(self, tool_id: int, tool_name: str, tool_parameters: dict[str, Any]) -> None: """ validate the parameters of the tool and set the default value if needed @@ -196,91 +201,9 @@ class BuiltinToolProviderController(ToolProviderController): # the parameter is not set currently, set the default value if needed if parameter_schema.default is not None: - default_value = parameter_schema.default - # parse default value into the correct type - if parameter_schema.type == ToolParameter.ToolParameterType.STRING or \ - parameter_schema.type == ToolParameter.ToolParameterType.SELECT: - default_value = str(default_value) - elif parameter_schema.type == ToolParameter.ToolParameterType.NUMBER: - default_value = float(default_value) - elif parameter_schema.type == ToolParameter.ToolParameterType.BOOLEAN: - default_value = bool(default_value) - + default_value = ToolParameterConverter.cast_parameter_by_type(parameter_schema.default, + parameter_schema.type) tool_parameters[parameter] = default_value - - def validate_credentials_format(self, credentials: dict[str, Any]) -> None: - """ - validate the format of the credentials of the provider and set the default value if needed - - :param credentials: the credentials of the tool - """ - credentials_schema = self.credentials_schema - if credentials_schema is None: - return - - credentials_need_to_validate: dict[str, ToolProviderCredentials] = {} - for credential_name in credentials_schema: - credentials_need_to_validate[credential_name] = credentials_schema[credential_name] - - for credential_name in credentials: - if credential_name not in credentials_need_to_validate: - raise ToolProviderCredentialValidationError(f'credential {credential_name} not found in provider {self.identity.name}') - - # check type - credential_schema = credentials_need_to_validate[credential_name] - if credential_schema == ToolProviderCredentials.CredentialsType.SECRET_INPUT or \ - credential_schema == ToolProviderCredentials.CredentialsType.TEXT_INPUT: - if not isinstance(credentials[credential_name], str): - raise ToolProviderCredentialValidationError(f'credential 
{credential_schema.label.en_US} should be string') - - elif credential_schema.type == ToolProviderCredentials.CredentialsType.SELECT: - if not isinstance(credentials[credential_name], str): - raise ToolProviderCredentialValidationError(f'credential {credential_schema.label.en_US} should be string') - - options = credential_schema.options - if not isinstance(options, list): - raise ToolProviderCredentialValidationError(f'credential {credential_schema.label.en_US} options should be list') - - if credentials[credential_name] not in [x.value for x in options]: - raise ToolProviderCredentialValidationError(f'credential {credential_schema.label.en_US} should be one of {options}') - elif credential_schema.type == ToolProviderCredentials.CredentialsType.BOOLEAN: - if isinstance(credentials[credential_name], bool): - pass - elif isinstance(credentials[credential_name], str): - if credentials[credential_name].lower() == 'true': - credentials[credential_name] = True - elif credentials[credential_name].lower() == 'false': - credentials[credential_name] = False - else: - raise ToolProviderCredentialValidationError(f'credential {credential_schema.label.en_US} should be boolean') - elif isinstance(credentials[credential_name], int): - if credentials[credential_name] == 1: - credentials[credential_name] = True - elif credentials[credential_name] == 0: - credentials[credential_name] = False - else: - raise ToolProviderCredentialValidationError(f'credential {credential_schema.label.en_US} should be boolean') - else: - raise ToolProviderCredentialValidationError(f'credential {credential_schema.label.en_US} should be boolean') - - if credentials[credential_name] or credentials[credential_name] == False: - credentials_need_to_validate.pop(credential_name) - - for credential_name in credentials_need_to_validate: - credential_schema = credentials_need_to_validate[credential_name] - if credential_schema.required: - raise ToolProviderCredentialValidationError(f'credential {credential_schema.label.en_US} is required') - - # the credential is not set currently, set the default value if needed - if credential_schema.default is not None: - default_value = credential_schema.default - # parse default value into the correct type - if credential_schema.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT or \ - credential_schema.type == ToolProviderCredentials.CredentialsType.TEXT_INPUT or \ - credential_schema.type == ToolProviderCredentials.CredentialsType.SELECT: - default_value = str(default_value) - - credentials[credential_name] = default_value def validate_credentials(self, credentials: dict[str, Any]) -> None: """ diff --git a/api/core/tools/provider/model_tool_provider.py b/api/core/tools/provider/model_tool_provider.py deleted file mode 100644 index ef47e9aae9..0000000000 --- a/api/core/tools/provider/model_tool_provider.py +++ /dev/null @@ -1,244 +0,0 @@ -from copy import deepcopy -from typing import Any - -from core.entities.model_entities import ModelStatus -from core.errors.error import ProviderTokenNotInitError -from core.model_manager import ModelInstance -from core.model_runtime.entities.model_entities import ModelFeature, ModelType -from core.provider_manager import ProviderConfiguration, ProviderManager, ProviderModelBundle -from core.tools.entities.common_entities import I18nObject -from core.tools.entities.tool_entities import ( - ModelToolPropertyKey, - ToolDescription, - ToolIdentity, - ToolParameter, - ToolProviderCredentials, - ToolProviderIdentity, - ToolProviderType, -) -from 
core.tools.errors import ToolNotFoundError -from core.tools.provider.tool_provider import ToolProviderController -from core.tools.tool.model_tool import ModelTool -from core.tools.tool.tool import Tool -from core.tools.utils.configuration import ModelToolConfigurationManager - - -class ModelToolProviderController(ToolProviderController): - configuration: ProviderConfiguration = None - is_active: bool = False - - def __init__(self, configuration: ProviderConfiguration = None, **kwargs): - """ - init the provider - - :param data: the data of the provider - """ - super().__init__(**kwargs) - self.configuration = configuration - - @staticmethod - def from_db(configuration: ProviderConfiguration = None) -> 'ModelToolProviderController': - """ - init the provider from db - - :param configuration: the configuration of the provider - """ - # check if all models are active - if configuration is None: - return None - is_active = True - models = configuration.get_provider_models() - for model in models: - if model.status != ModelStatus.ACTIVE: - is_active = False - break - - # get the provider configuration - model_tool_configuration = ModelToolConfigurationManager.get_configuration(configuration.provider.provider) - if model_tool_configuration is None: - raise RuntimeError(f'no configuration found for provider {configuration.provider.provider}') - - # override the configuration - if model_tool_configuration.label: - label = deepcopy(model_tool_configuration.label) - if label.en_US: - label.en_US = model_tool_configuration.label.en_US - if label.zh_Hans: - label.zh_Hans = model_tool_configuration.label.zh_Hans - else: - label = I18nObject( - en_US=configuration.provider.label.en_US, - zh_Hans=configuration.provider.label.zh_Hans - ) - - return ModelToolProviderController( - is_active=is_active, - identity=ToolProviderIdentity( - author='Dify', - name=configuration.provider.provider, - description=I18nObject( - zh_Hans=f'{label.zh_Hans} 模型能力提供商', - en_US=f'{label.en_US} model capability provider' - ), - label=I18nObject( - zh_Hans=label.zh_Hans, - en_US=label.en_US - ), - icon=configuration.provider.icon_small.en_US, - ), - configuration=configuration, - credentials_schema={}, - ) - - @staticmethod - def is_configuration_valid(configuration: ProviderConfiguration) -> bool: - """ - check if the configuration has a model can be used as a tool - """ - models = configuration.get_provider_models() - for model in models: - if model.model_type == ModelType.LLM and ModelFeature.VISION in (model.features or []): - return True - return False - - def _get_model_tools(self, tenant_id: str = None) -> list[ModelTool]: - """ - returns a list of tools that the provider can provide - - :return: list of tools - """ - tenant_id = tenant_id or 'ffffffff-ffff-ffff-ffff-ffffffffffff' - provider_manager = ProviderManager() - if self.configuration is None: - configurations = provider_manager.get_configurations(tenant_id=tenant_id).values() - self.configuration = next(filter(lambda x: x.provider == self.identity.name, configurations), None) - # get all tools - tools: list[ModelTool] = [] - # get all models - if not self.configuration: - return tools - configuration = self.configuration - - provider_configuration = ModelToolConfigurationManager.get_configuration(configuration.provider.provider) - if provider_configuration is None: - raise RuntimeError(f'no configuration found for provider {configuration.provider.provider}') - - for model in configuration.get_provider_models(): - model_configuration = 
ModelToolConfigurationManager.get_model_configuration(self.configuration.provider.provider, model.model) - if model_configuration is None: - continue - - if model.model_type == ModelType.LLM and ModelFeature.VISION in (model.features or []): - provider_instance = configuration.get_provider_instance() - model_type_instance = provider_instance.get_model_instance(model.model_type) - provider_model_bundle = ProviderModelBundle( - configuration=configuration, - provider_instance=provider_instance, - model_type_instance=model_type_instance - ) - - try: - model_instance = ModelInstance(provider_model_bundle, model.model) - except ProviderTokenNotInitError: - model_instance = None - - tools.append(ModelTool( - identity=ToolIdentity( - author='Dify', - name=model.model, - label=model_configuration.label, - ), - parameters=[ - ToolParameter( - name=ModelToolPropertyKey.IMAGE_PARAMETER_NAME.value, - label=I18nObject(zh_Hans='图片ID', en_US='Image ID'), - human_description=I18nObject(zh_Hans='图片ID', en_US='Image ID'), - type=ToolParameter.ToolParameterType.STRING, - form=ToolParameter.ToolParameterForm.LLM, - required=True, - default=Tool.VARIABLE_KEY.IMAGE.value - ) - ], - description=ToolDescription( - human=I18nObject(zh_Hans='图生文工具', en_US='Convert image to text'), - llm='Vision tool used to extract text and other visual information from images, can be used for OCR, image captioning, etc.', - ), - is_team_authorization=model.status == ModelStatus.ACTIVE, - tool_type=ModelTool.ModelToolType.VISION, - model_instance=model_instance, - model=model.model, - )) - - self.tools = tools - return tools - - def get_credentials_schema(self) -> dict[str, ToolProviderCredentials]: - """ - returns the credentials schema of the provider - - :return: the credentials schema - """ - return {} - - def get_tools(self, user_id: str, tenant_id: str) -> list[ModelTool]: - """ - returns a list of tools that the provider can provide - - :return: list of tools - """ - return self._get_model_tools(tenant_id=tenant_id) - - def get_tool(self, tool_name: str) -> ModelTool: - """ - get tool by name - - :param tool_name: the name of the tool - :return: the tool - """ - if self.tools is None: - self.get_tools(user_id='', tenant_id=self.configuration.tenant_id) - - for tool in self.tools: - if tool.identity.name == tool_name: - return tool - - raise ValueError(f'tool {tool_name} not found') - - def get_parameters(self, tool_name: str) -> list[ToolParameter]: - """ - returns the parameters of the tool - - :param tool_name: the name of the tool, defined in `get_tools` - :return: list of parameters - """ - tool = next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None) - if tool is None: - raise ToolNotFoundError(f'tool {tool_name} not found') - return tool.parameters - - @property - def app_type(self) -> ToolProviderType: - """ - returns the type of the provider - - :return: type of the provider - """ - return ToolProviderType.MODEL - - def validate_credentials(self, credentials: dict[str, Any]) -> None: - """ - validate the credentials of the provider - - :param tool_name: the name of the tool, defined in `get_tools` - :param credentials: the credentials of the tool - """ - pass - - def _validate_credentials(self, credentials: dict[str, Any]) -> None: - """ - validate the credentials of the provider - - :param tool_name: the name of the tool, defined in `get_tools` - :param credentials: the credentials of the tool - """ - pass \ No newline at end of file diff --git a/api/core/tools/provider/tool_provider.py 
b/api/core/tools/provider/tool_provider.py index b527f2b274..ef1ace9c7c 100644 --- a/api/core/tools/provider/tool_provider.py +++ b/api/core/tools/provider/tool_provider.py @@ -9,9 +9,9 @@ from core.tools.entities.tool_entities import ( ToolProviderIdentity, ToolProviderType, ) -from core.tools.entities.user_entities import UserToolProviderCredentials from core.tools.errors import ToolNotFoundError, ToolParameterValidationError, ToolProviderCredentialValidationError from core.tools.tool.tool import Tool +from core.tools.utils.tool_parameter_converter import ToolParameterConverter class ToolProviderController(BaseModel, ABC): @@ -27,15 +27,6 @@ class ToolProviderController(BaseModel, ABC): """ return self.credentials_schema.copy() - def user_get_credentials_schema(self) -> UserToolProviderCredentials: - """ - returns the credentials schema of the provider, this method is used for user - - :return: the credentials schema - """ - credentials = self.credentials_schema.copy() - return UserToolProviderCredentials(credentials=credentials) - @abstractmethod def get_tools(self) -> list[Tool]: """ @@ -67,7 +58,7 @@ class ToolProviderController(BaseModel, ABC): return tool.parameters @property - def app_type(self) -> ToolProviderType: + def provider_type(self) -> ToolProviderType: """ returns the type of the provider @@ -132,17 +123,8 @@ class ToolProviderController(BaseModel, ABC): # the parameter is not set currently, set the default value if needed if parameter_schema.default is not None: - default_value = parameter_schema.default - # parse default value into the correct type - if parameter_schema.type == ToolParameter.ToolParameterType.STRING or \ - parameter_schema.type == ToolParameter.ToolParameterType.SELECT: - default_value = str(default_value) - elif parameter_schema.type == ToolParameter.ToolParameterType.NUMBER: - default_value = float(default_value) - elif parameter_schema.type == ToolParameter.ToolParameterType.BOOLEAN: - default_value = bool(default_value) - - tool_parameters[parameter] = default_value + tool_parameters[parameter] = ToolParameterConverter.cast_parameter_by_type(parameter_schema.default, + parameter_schema.type) def validate_credentials_format(self, credentials: dict[str, Any]) -> None: """ @@ -197,26 +179,4 @@ class ToolProviderController(BaseModel, ABC): default_value = str(default_value) credentials[credential_name] = default_value - - def validate_credentials(self, credentials: dict[str, Any]) -> None: - """ - validate the credentials of the provider - - :param tool_name: the name of the tool, defined in `get_tools` - :param credentials: the credentials of the tool - """ - # validate credentials format - self.validate_credentials_format(credentials) - - # validate credentials - self._validate_credentials(credentials) - - @abstractmethod - def _validate_credentials(self, credentials: dict[str, Any]) -> None: - """ - validate the credentials of the provider - - :param tool_name: the name of the tool, defined in `get_tools` - :param credentials: the credentials of the tool - """ - pass \ No newline at end of file + \ No newline at end of file diff --git a/api/core/tools/provider/workflow_tool_provider.py b/api/core/tools/provider/workflow_tool_provider.py new file mode 100644 index 0000000000..f98ad0f26a --- /dev/null +++ b/api/core/tools/provider/workflow_tool_provider.py @@ -0,0 +1,230 @@ +from typing import Optional + +from core.app.app_config.entities import VariableEntity +from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager +from 
core.model_runtime.entities.common_entities import I18nObject +from core.tools.entities.tool_entities import ( + ToolDescription, + ToolIdentity, + ToolParameter, + ToolParameterOption, + ToolProviderType, +) +from core.tools.provider.tool_provider import ToolProviderController +from core.tools.tool.workflow_tool import WorkflowTool +from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils +from extensions.ext_database import db +from models.model import App, AppMode +from models.tools import WorkflowToolProvider +from models.workflow import Workflow + + +class WorkflowToolProviderController(ToolProviderController): + provider_id: str + + @classmethod + def from_db(cls, db_provider: WorkflowToolProvider) -> 'WorkflowToolProviderController': + app = db_provider.app + + if not app: + raise ValueError('app not found') + + controller = WorkflowToolProviderController(**{ + 'identity': { + 'author': db_provider.user.name if db_provider.user_id and db_provider.user else '', + 'name': db_provider.label, + 'label': { + 'en_US': db_provider.label, + 'zh_Hans': db_provider.label + }, + 'description': { + 'en_US': db_provider.description, + 'zh_Hans': db_provider.description + }, + 'icon': db_provider.icon, + }, + 'credentials_schema': {}, + 'provider_id': db_provider.id or '', + }) + + # init tools + + controller.tools = [controller._get_db_provider_tool(db_provider, app)] + + return controller + + @property + def provider_type(self) -> ToolProviderType: + return ToolProviderType.WORKFLOW + + def _get_db_provider_tool(self, db_provider: WorkflowToolProvider, app: App) -> WorkflowTool: + """ + get db provider tool + :param db_provider: the db provider + :param app: the app + :return: the tool + """ + workflow: Workflow = db.session.query(Workflow).filter( + Workflow.app_id == db_provider.app_id, + Workflow.version == db_provider.version + ).first() + if not workflow: + raise ValueError('workflow not found') + + # fetch start node + graph: dict = workflow.graph_dict + features_dict: dict = workflow.features_dict + features = WorkflowAppConfigManager.convert_features( + config_dict=features_dict, + app_mode=AppMode.WORKFLOW + ) + + parameters = db_provider.parameter_configurations + variables = WorkflowToolConfigurationUtils.get_workflow_graph_variables(graph) + + def fetch_workflow_variable(variable_name: str) -> VariableEntity: + return next(filter(lambda x: x.variable == variable_name, variables), None) + + user = db_provider.user + + workflow_tool_parameters = [] + for parameter in parameters: + variable = fetch_workflow_variable(parameter.name) + if variable: + parameter_type = None + options = None + if variable.type in [ + VariableEntity.Type.TEXT_INPUT, + VariableEntity.Type.PARAGRAPH, + ]: + parameter_type = ToolParameter.ToolParameterType.STRING + elif variable.type in [ + VariableEntity.Type.SELECT + ]: + parameter_type = ToolParameter.ToolParameterType.SELECT + elif variable.type in [ + VariableEntity.Type.NUMBER + ]: + parameter_type = ToolParameter.ToolParameterType.NUMBER + else: + raise ValueError(f'unsupported variable type {variable.type}') + + if variable.type == VariableEntity.Type.SELECT and variable.options: + options = [ + ToolParameterOption( + value=option, + label=I18nObject( + en_US=option, + zh_Hans=option + ) + ) for option in variable.options + ] + + workflow_tool_parameters.append( + ToolParameter( + name=parameter.name, + label=I18nObject( + en_US=variable.label, + zh_Hans=variable.label + ), + human_description=I18nObject( + 
en_US=parameter.description, + zh_Hans=parameter.description + ), + type=parameter_type, + form=parameter.form, + llm_description=parameter.description, + required=variable.required, + options=options, + default=variable.default + ) + ) + elif features.file_upload: + workflow_tool_parameters.append( + ToolParameter( + name=parameter.name, + label=I18nObject( + en_US=parameter.name, + zh_Hans=parameter.name + ), + human_description=I18nObject( + en_US=parameter.description, + zh_Hans=parameter.description + ), + type=ToolParameter.ToolParameterType.FILE, + llm_description=parameter.description, + required=False, + form=parameter.form, + ) + ) + else: + raise ValueError('variable not found') + + return WorkflowTool( + identity=ToolIdentity( + author=user.name if user else '', + name=db_provider.name, + label=I18nObject( + en_US=db_provider.label, + zh_Hans=db_provider.label + ), + provider=self.provider_id, + icon=db_provider.icon, + ), + description=ToolDescription( + human=I18nObject( + en_US=db_provider.description, + zh_Hans=db_provider.description + ), + llm=db_provider.description, + ), + parameters=workflow_tool_parameters, + is_team_authorization=True, + workflow_app_id=app.id, + workflow_entities={ + 'app': app, + 'workflow': workflow, + }, + version=db_provider.version, + workflow_call_depth=0, + label=db_provider.label + ) + + def get_tools(self, user_id: str, tenant_id: str) -> list[WorkflowTool]: + """ + fetch tools from database + + :param user_id: the user id + :param tenant_id: the tenant id + :return: the tools + """ + if self.tools is not None: + return self.tools + + db_providers: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter( + WorkflowToolProvider.tenant_id == tenant_id, + WorkflowToolProvider.app_id == self.provider_id, + ).first() + + if not db_providers: + return [] + + self.tools = [self._get_db_provider_tool(db_providers, db_providers.app)] + + return self.tools + + def get_tool(self, tool_name: str) -> Optional[WorkflowTool]: + """ + get tool by name + + :param tool_name: the name of the tool + :return: the tool + """ + if self.tools is None: + return None + + for tool in self.tools: + if tool.identity.name == tool_name: + return tool + + return None diff --git a/api/core/tools/tool/api_tool.py b/api/core/tools/tool/api_tool.py index f7b963a92e..ff7d4015ab 100644 --- a/api/core/tools/tool/api_tool.py +++ b/api/core/tools/tool/api_tool.py @@ -8,9 +8,8 @@ import httpx import requests import core.helper.ssrf_proxy as ssrf_proxy -from core.tools.entities.tool_bundle import ApiBasedToolBundle +from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType -from core.tools.entities.user_entities import UserToolProvider from core.tools.errors import ToolInvokeError, ToolParameterValidationError, ToolProviderCredentialValidationError from core.tools.tool.tool import Tool @@ -20,12 +19,12 @@ API_TOOL_DEFAULT_TIMEOUT = ( ) class ApiTool(Tool): - api_bundle: ApiBasedToolBundle + api_bundle: ApiToolBundle """ Api tool """ - def fork_tool_runtime(self, meta: dict[str, Any]) -> 'Tool': + def fork_tool_runtime(self, runtime: dict[str, Any]) -> 'Tool': """ fork a new tool with meta data @@ -37,7 +36,7 @@ class ApiTool(Tool): parameters=self.parameters.copy() if self.parameters else None, description=self.description.copy() if self.description else None, api_bundle=self.api_bundle.copy() if self.api_bundle else None, - runtime=Tool.Runtime(**meta) + runtime=Tool.Runtime(**runtime) ) 
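Throughout this diff, `fork_tool_runtime` takes a `runtime` dict instead of `meta`, and `Tool.Runtime` gains `invoke_from` and `tool_invoke_from` fields (see the `tool.py` hunk below). A minimal caller sketch, modeled on the `tool_manager.get_tool_runtime` changes later in this diff; `api_tool`, `tenant_id`, and `decrypted_credentials` are assumed to already exist:

```python
# Sketch only: how a caller passes the richer runtime dict after the rename.
# `api_tool`, `tenant_id`, and `decrypted_credentials` are assumed; the keys
# mirror the Tool.Runtime fields, including the invoke_from /
# tool_invoke_from pair added in the tool.py hunk below.
from core.app.entities.app_invoke_entities import InvokeFrom
from core.tools.entities.tool_entities import ToolInvokeFrom

forked = api_tool.fork_tool_runtime(runtime={
    'tenant_id': tenant_id,                    # required by most tools
    'credentials': decrypted_credentials,      # already-decrypted secrets
    'invoke_from': InvokeFrom.DEBUGGER,        # which app surface is calling
    'tool_invoke_from': ToolInvokeFrom.AGENT,  # agent vs. workflow invocation
})
```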
def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any], format_only: bool = False) -> str: @@ -55,7 +54,7 @@ class ApiTool(Tool): return self.validate_and_parse_response(response) def tool_provider_type(self) -> ToolProviderType: - return UserToolProvider.ProviderType.API + return ToolProviderType.API def assembling_request(self, parameters: dict[str, Any]) -> dict[str, Any]: headers = {} diff --git a/api/core/tools/tool/builtin_tool.py b/api/core/tools/tool/builtin_tool.py index 68193e5f69..1ed05c112f 100644 --- a/api/core/tools/tool/builtin_tool.py +++ b/api/core/tools/tool/builtin_tool.py @@ -2,9 +2,8 @@ from core.model_runtime.entities.llm_entities import LLMResult from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage from core.tools.entities.tool_entities import ToolProviderType -from core.tools.entities.user_entities import UserToolProvider -from core.tools.model.tool_model_manager import ToolModelManager from core.tools.tool.tool import Tool +from core.tools.utils.model_invocation_utils import ModelInvocationUtils from core.tools.utils.web_reader_tool import get_url _SUMMARY_PROMPT = """You are a professional language researcher, you are interested in the language @@ -34,7 +33,7 @@ class BuiltinTool(Tool): :return: the model result """ # invoke model - return ToolModelManager.invoke( + return ModelInvocationUtils.invoke( user_id=user_id, tenant_id=self.runtime.tenant_id, tool_type='builtin', @@ -43,7 +42,7 @@ class BuiltinTool(Tool): ) def tool_provider_type(self) -> ToolProviderType: - return UserToolProvider.ProviderType.BUILTIN + return ToolProviderType.BUILT_IN def get_max_tokens(self) -> int: """ @@ -52,7 +51,7 @@ class BuiltinTool(Tool): :param model_config: the model config :return: the max tokens """ - return ToolModelManager.get_max_llm_context_tokens( + return ModelInvocationUtils.get_max_llm_context_tokens( tenant_id=self.runtime.tenant_id, ) @@ -63,7 +62,7 @@ class BuiltinTool(Tool): :param prompt_messages: the prompt messages :return: the tokens """ - return ToolModelManager.calculate_tokens( + return ModelInvocationUtils.calculate_tokens( tenant_id=self.runtime.tenant_id, prompt_messages=prompt_messages ) diff --git a/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py b/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py index 6e11427d58..18cf780668 100644 --- a/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py +++ b/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py @@ -7,7 +7,7 @@ from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCa from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType from core.rag.datasource.retrieval_service import RetrievalService -from core.rerank.rerank import RerankRunner +from core.rag.rerank.rerank import RerankRunner from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool from extensions.ext_database import db from models.dataset import Dataset, Document, DocumentSegment @@ -79,7 +79,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): document_score_list = {} for item in all_documents: - if 'score' in item.metadata and item.metadata['score']: + if item.metadata.get('score'): document_score_list[item.metadata['doc_id']] = item.metadata['score'] document_context_list = [] @@ -99,9 +99,9 @@ class 
DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): float('inf'))) for segment in sorted_segments: if segment.answer: - document_context_list.append(f'question:{segment.content} answer:{segment.answer}') + document_context_list.append(f'question:{segment.get_sign_content()} answer:{segment.answer}') else: - document_context_list.append(segment.content) + document_context_list.append(segment.get_sign_content()) if self.return_resource: context_list = [] resource_number = 1 diff --git a/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py index 552174e0ba..af45fc66f2 100644 --- a/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py +++ b/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py @@ -87,7 +87,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): document_score_list = {} if dataset.indexing_technique != "economy": for item in documents: - if 'score' in item.metadata and item.metadata['score']: + if item.metadata.get('score'): document_score_list[item.metadata['doc_id']] = item.metadata['score'] document_context_list = [] index_node_ids = [document.metadata['doc_id'] for document in documents] @@ -105,9 +105,9 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): float('inf'))) for segment in sorted_segments: if segment.answer: - document_context_list.append(f'question:{segment.content} answer:{segment.answer}') + document_context_list.append(f'question:{segment.get_sign_content()} answer:{segment.answer}') else: - document_context_list.append(segment.content) + document_context_list.append(segment.get_sign_content()) if self.return_resource: context_list = [] resource_number = 1 diff --git a/api/core/tools/tool/model_tool.py b/api/core/tools/tool/model_tool.py deleted file mode 100644 index b87e85f89c..0000000000 --- a/api/core/tools/tool/model_tool.py +++ /dev/null @@ -1,159 +0,0 @@ -from base64 import b64encode -from enum import Enum -from typing import Any, cast - -from core.model_manager import ModelInstance -from core.model_runtime.entities.llm_entities import LLMResult -from core.model_runtime.entities.message_entities import ( - PromptMessageContent, - PromptMessageContentType, - SystemPromptMessage, - UserPromptMessage, -) -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.tools.entities.tool_entities import ModelToolPropertyKey, ToolInvokeMessage, ToolProviderType -from core.tools.tool.tool import Tool - -VISION_PROMPT = """## Image Recognition Task -### Task Description -I require a powerful vision language model for an image recognition task. The model should be capable of extracting various details from the images, including but not limited to text content, layout distribution, color distribution, main subjects, and emotional expressions. -### Specific Requirements -1. **Text Content Extraction:** Ensure that the model accurately recognizes and extracts text content from the images, regardless of text size, font, or color. -2. **Layout Distribution Analysis:** The model should analyze the layout structure of the images, capturing the relationships between various elements and providing detailed information about the image layout. -3. **Color Distribution Analysis:** Extract information about color distribution in the images, including primary colors, color combinations, and other relevant details. -4. 
**Main Subject Recognition:** The model should accurately identify the main subjects in the images and provide detailed descriptions of these subjects. -5. **Emotional Expression Analysis:** Analyze and describe the emotions or expressions conveyed in the images based on facial expressions, postures, and other relevant features. -### Additional Considerations -- Ensure that the extracted information is as comprehensive and accurate as possible. -- For each task, provide confidence scores or relevance scores for the model outputs to assess the reliability of the results. -- If necessary, pose specific questions for different tasks to guide the model in better understanding the images and providing relevant information.""" - -class ModelTool(Tool): - class ModelToolType(Enum): - """ - the type of the model tool - """ - VISION = 'vision' - - model_configuration: dict[str, Any] = None - tool_type: ModelToolType - - def __init__(self, model_instance: ModelInstance = None, model: str = None, - tool_type: ModelToolType = ModelToolType.VISION, - properties: dict[ModelToolPropertyKey, Any] = None, - **kwargs): - """ - init the tool - """ - kwargs['model_configuration'] = { - 'model_instance': model_instance, - 'model': model, - 'properties': properties - } - kwargs['tool_type'] = tool_type - super().__init__(**kwargs) - - """ - Model tool - """ - def fork_tool_runtime(self, meta: dict[str, Any]) -> 'Tool': - """ - fork a new tool with meta data - - :param meta: the meta data of a tool call processing, tenant_id is required - :return: the new tool - """ - return self.__class__( - identity=self.identity.copy() if self.identity else None, - parameters=self.parameters.copy() if self.parameters else None, - description=self.description.copy() if self.description else None, - model_instance=self.model_configuration['model_instance'], - model=self.model_configuration['model'], - tool_type=self.tool_type, - runtime=Tool.Runtime(**meta) - ) - - def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any], format_only: bool = False) -> None: - """ - validate the credentials for Model tool - """ - pass - - def tool_provider_type(self) -> ToolProviderType: - return ToolProviderType.BUILT_IN - - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]: - """ - """ - model_instance = self.model_configuration['model_instance'] - if not model_instance: - return self.create_text_message('the tool is not configured correctly') - - if self.tool_type == ModelTool.ModelToolType.VISION: - return self._invoke_llm_vision(user_id, tool_parameters) - else: - return self.create_text_message('the tool is not configured correctly') - - def _invoke_llm_vision(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]: - # get image - image_parameter_name = self.model_configuration['properties'].get(ModelToolPropertyKey.IMAGE_PARAMETER_NAME, 'image_id') - image_id = tool_parameters.pop(image_parameter_name, '') - if not image_id: - image = self.get_default_image_variable() - if not image: - return self.create_text_message('Please upload an image or input image_id') - else: - image = self.get_variable(image_id) - if not image: - image = self.get_default_image_variable() - if not image: - return self.create_text_message('Please upload an image or input image_id') - - if not image: - return self.create_text_message('Please upload an image or input image_id') - - # get image - image = 
self.get_variable_file(image.name) - if not image: - return self.create_text_message('Failed to get image') - - # organize prompt messages - prompt_messages = [ - SystemPromptMessage( - content=VISION_PROMPT - ), - UserPromptMessage( - content=[ - PromptMessageContent( - type=PromptMessageContentType.TEXT, - data='Recognize the image and extract the information from the image.' - ), - PromptMessageContent( - type=PromptMessageContentType.IMAGE, - data=f'data:image/png;base64,{b64encode(image).decode("utf-8")}' - ) - ] - ) - ] - - llm_instance = cast(LargeLanguageModel, self.model_configuration['model_instance']) - result: LLMResult = llm_instance.invoke( - model=self.model_configuration['model'], - credentials=self.runtime.credentials, - prompt_messages=prompt_messages, - model_parameters=tool_parameters, - tools=[], - stop=[], - stream=False, - user=user_id, - ) - - if not result: - return self.create_text_message('Failed to extract information from the image') - - # get result - content = result.message.content - if not content: - return self.create_text_message('Failed to extract information from the image') - - return self.create_text_message(content) \ No newline at end of file diff --git a/api/core/tools/tool/tool.py b/api/core/tools/tool/tool.py index 03aa0623fe..5ae54e4728 100644 --- a/api/core/tools/tool/tool.py +++ b/api/core/tools/tool/tool.py @@ -1,12 +1,16 @@ from abc import ABC, abstractmethod +from copy import deepcopy from enum import Enum from typing import Any, Optional, Union from pydantic import BaseModel, validator +from core.app.entities.app_invoke_entities import InvokeFrom +from core.file.file_obj import FileVar from core.tools.entities.tool_entities import ( ToolDescription, ToolIdentity, + ToolInvokeFrom, ToolInvokeMessage, ToolParameter, ToolProviderType, @@ -15,6 +19,7 @@ from core.tools.entities.tool_entities import ( ToolRuntimeVariablePool, ) from core.tools.tool_file_manager import ToolFileManager +from core.tools.utils.tool_parameter_converter import ToolParameterConverter class Tool(BaseModel, ABC): @@ -25,10 +30,7 @@ class Tool(BaseModel, ABC): @validator('parameters', pre=True, always=True) def set_parameters(cls, v, values): - if not v: - return [] - - return v + return v or [] class Runtime(BaseModel): """ @@ -41,6 +43,8 @@ class Tool(BaseModel, ABC): tenant_id: str = None tool_id: str = None + invoke_from: InvokeFrom = None + tool_invoke_from: ToolInvokeFrom = None credentials: dict[str, Any] = None runtime_parameters: dict[str, Any] = None @@ -53,7 +57,7 @@ class Tool(BaseModel, ABC): class VARIABLE_KEY(Enum): IMAGE = 'image' - def fork_tool_runtime(self, meta: dict[str, Any]) -> 'Tool': + def fork_tool_runtime(self, runtime: dict[str, Any]) -> 'Tool': """ fork a new tool with meta data @@ -64,7 +68,7 @@ class Tool(BaseModel, ABC): identity=self.identity.copy() if self.identity else None, parameters=self.parameters.copy() if self.parameters else None, description=self.description.copy() if self.description else None, - runtime=Tool.Runtime(**meta), + runtime=Tool.Runtime(**runtime), ) @abstractmethod @@ -208,17 +212,17 @@ class Tool(BaseModel, ABC): if response.type == ToolInvokeMessage.MessageType.TEXT: result += response.message elif response.type == ToolInvokeMessage.MessageType.LINK: - result += f"result link: {response.message}. please tell user to check it." + result += f"result link: {response.message}. please tell user to check it. 
\n" elif response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \ response.type == ToolInvokeMessage.MessageType.IMAGE: - result += "image has been created and sent to user already, you do not need to create it, just tell the user to check it now." + result += "image has been created and sent to user already, you do not need to create it, just tell the user to check it now. \n" elif response.type == ToolInvokeMessage.MessageType.BLOB: if len(response.message) > 114: result += str(response.message[:114]) + '...' else: result += str(response.message) else: - result += f"tool response: {response.message}." + result += f"tool response: {response.message}. \n" return result @@ -226,46 +230,13 @@ class Tool(BaseModel, ABC): """ Transform tool parameters type """ + # Temp fix for the issue that the tool parameters will be converted to empty while validating the credentials + result = deepcopy(tool_parameters) for parameter in self.parameters: if parameter.name in tool_parameters: - if parameter.type in [ - ToolParameter.ToolParameterType.SECRET_INPUT, - ToolParameter.ToolParameterType.STRING, - ToolParameter.ToolParameterType.SELECT, - ] and not isinstance(tool_parameters[parameter.name], str): - if tool_parameters[parameter.name] is None: - tool_parameters[parameter.name] = '' - else: - tool_parameters[parameter.name] = str(tool_parameters[parameter.name]) - elif parameter.type == ToolParameter.ToolParameterType.NUMBER \ - and not isinstance(tool_parameters[parameter.name], int | float): - if isinstance(tool_parameters[parameter.name], str): - try: - tool_parameters[parameter.name] = int(tool_parameters[parameter.name]) - except ValueError: - tool_parameters[parameter.name] = float(tool_parameters[parameter.name]) - elif isinstance(tool_parameters[parameter.name], bool): - tool_parameters[parameter.name] = int(tool_parameters[parameter.name]) - elif tool_parameters[parameter.name] is None: - tool_parameters[parameter.name] = 0 - elif parameter.type == ToolParameter.ToolParameterType.BOOLEAN: - if not isinstance(tool_parameters[parameter.name], bool): - # check if it is a string - if isinstance(tool_parameters[parameter.name], str): - # check true false - if tool_parameters[parameter.name].lower() in ['true', 'false']: - tool_parameters[parameter.name] = tool_parameters[parameter.name].lower() == 'true' - # check 1 0 - elif tool_parameters[parameter.name] in ['1', '0']: - tool_parameters[parameter.name] = tool_parameters[parameter.name] == '1' - else: - tool_parameters[parameter.name] = bool(tool_parameters[parameter.name]) - elif isinstance(tool_parameters[parameter.name], int | float): - tool_parameters[parameter.name] = tool_parameters[parameter.name] != 0 - else: - tool_parameters[parameter.name] = bool(tool_parameters[parameter.name]) - - return tool_parameters + result[parameter.name] = ToolParameterConverter.cast_parameter_by_type(tool_parameters[parameter.name], parameter.type) + + return result @abstractmethod def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: @@ -324,14 +295,6 @@ class Tool(BaseModel, ABC): return parameters - def is_tool_available(self) -> bool: - """ - check if the tool is available - - :return: if the tool is available - """ - return True - def create_image_message(self, image: str, save_as: str = '') -> ToolInvokeMessage: """ create an image message @@ -343,6 +306,14 @@ class Tool(BaseModel, ABC): message=image, save_as=save_as) + def create_file_var_message(self, file_var: FileVar) -> 
ToolInvokeMessage: + return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.FILE_VAR, + message='', + meta={ + 'file_var': file_var + }, + save_as='') + def create_link_message(self, link: str, save_as: str = '') -> ToolInvokeMessage: """ create a link message @@ -361,10 +332,11 @@ class Tool(BaseModel, ABC): :param text: the text :return: the text message """ - return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.TEXT, - message=text, - save_as=save_as - ) + return ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.TEXT, + message=text, + save_as=save_as + ) def create_blob_message(self, blob: bytes, meta: dict = None, save_as: str = '') -> ToolInvokeMessage: """ @@ -373,7 +345,8 @@ class Tool(BaseModel, ABC): :param blob: the blob :return: the blob message """ - return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.BLOB, - message=blob, meta=meta, - save_as=save_as - ) \ No newline at end of file + return ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.BLOB, + message=blob, meta=meta, + save_as=save_as + ) diff --git a/api/core/tools/tool/workflow_tool.py b/api/core/tools/tool/workflow_tool.py new file mode 100644 index 0000000000..122b663f94 --- /dev/null +++ b/api/core/tools/tool/workflow_tool.py @@ -0,0 +1,200 @@ +import json +import logging +from copy import deepcopy +from typing import Any, Union + +from core.file.file_obj import FileTransferMethod, FileVar +from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolProviderType +from core.tools.tool.tool import Tool +from extensions.ext_database import db +from models.account import Account +from models.model import App, EndUser +from models.workflow import Workflow + +logger = logging.getLogger(__name__) + +class WorkflowTool(Tool): + workflow_app_id: str + version: str + workflow_entities: dict[str, Any] + workflow_call_depth: int + + label: str + + """ + Workflow tool. 
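+    Wraps a published workflow app so that agents and other workflows can
+    invoke it like any other tool: tool parameters are mapped onto the
+    workflow's input variables, and the run's outputs are returned as a
+    text message plus any extracted files.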
+ """ + def tool_provider_type(self) -> ToolProviderType: + """ + get the tool provider type + + :return: the tool provider type + """ + return ToolProviderType.WORKFLOW + + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) \ + -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + """ + invoke the tool + """ + app = self._get_app(app_id=self.workflow_app_id) + workflow = self._get_workflow(app_id=self.workflow_app_id, version=self.version) + + # transform the tool parameters + tool_parameters, files = self._transform_args(tool_parameters) + + from core.app.apps.workflow.app_generator import WorkflowAppGenerator + generator = WorkflowAppGenerator() + result = generator.generate( + app_model=app, + workflow=workflow, + user=self._get_user(user_id), + args={ + 'inputs': tool_parameters, + 'files': files + }, + invoke_from=self.runtime.invoke_from, + stream=False, + call_depth=self.workflow_call_depth + 1, + ) + + data = result.get('data', {}) + + if data.get('error'): + raise Exception(data.get('error')) + + result = [] + + outputs = data.get('outputs', {}) + outputs, files = self._extract_files(outputs) + for file in files: + result.append(self.create_file_var_message(file)) + + result.append(self.create_text_message(json.dumps(outputs, ensure_ascii=False))) + + return result + + def _get_user(self, user_id: str) -> Union[EndUser, Account]: + """ + get the user by user id + """ + + user = db.session.query(EndUser).filter(EndUser.id == user_id).first() + if not user: + user = db.session.query(Account).filter(Account.id == user_id).first() + + if not user: + raise ValueError('user not found') + + return user + + def fork_tool_runtime(self, runtime: dict[str, Any]) -> 'WorkflowTool': + """ + fork a new tool with meta data + + :param meta: the meta data of a tool call processing, tenant_id is required + :return: the new tool + """ + return self.__class__( + identity=deepcopy(self.identity), + parameters=deepcopy(self.parameters), + description=deepcopy(self.description), + runtime=Tool.Runtime(**runtime), + workflow_app_id=self.workflow_app_id, + workflow_entities=self.workflow_entities, + workflow_call_depth=self.workflow_call_depth, + version=self.version, + label=self.label + ) + + def _get_workflow(self, app_id: str, version: str) -> Workflow: + """ + get the workflow by app id and version + """ + if not version: + workflow = db.session.query(Workflow).filter( + Workflow.app_id == app_id, + Workflow.version != 'draft' + ).order_by(Workflow.created_at.desc()).first() + else: + workflow = db.session.query(Workflow).filter( + Workflow.app_id == app_id, + Workflow.version == version + ).first() + + if not workflow: + raise ValueError('workflow not found or not published') + + return workflow + + def _get_app(self, app_id: str) -> App: + """ + get the app by app id + """ + app = db.session.query(App).filter(App.id == app_id).first() + if not app: + raise ValueError('app not found') + + return app + + def _transform_args(self, tool_parameters: dict) -> tuple[dict, list[dict]]: + """ + transform the tool parameters + + :param tool_parameters: the tool parameters + :return: tool_parameters, files + """ + parameter_rules = self.get_all_runtime_parameters() + parameters_result = {} + files = [] + for parameter in parameter_rules: + if parameter.type == ToolParameter.ToolParameterType.FILE: + file = tool_parameters.get(parameter.name) + if file: + try: + file_var_list = [FileVar(**f) for f in file] + for file_var in file_var_list: + file_dict = { + 'transfer_method': 
file_var.transfer_method.value, + 'type': file_var.type.value, + } + if file_var.transfer_method == FileTransferMethod.TOOL_FILE: + file_dict['tool_file_id'] = file_var.related_id + elif file_var.transfer_method == FileTransferMethod.LOCAL_FILE: + file_dict['upload_file_id'] = file_var.related_id + elif file_var.transfer_method == FileTransferMethod.REMOTE_URL: + file_dict['url'] = file_var.preview_url + + files.append(file_dict) + except Exception as e: + logger.exception(e) + else: + parameters_result[parameter.name] = tool_parameters.get(parameter.name) + + return parameters_result, files + + def _extract_files(self, outputs: dict) -> tuple[dict, list[FileVar]]: + """ + extract files from the result + + :param result: the result + :return: the result, files + """ + files = [] + result = {} + for key, value in outputs.items(): + if isinstance(value, list): + has_file = False + for item in value: + if isinstance(item, dict) and item.get('__variant') == 'FileVar': + try: + files.append(FileVar(**item)) + has_file = True + except Exception as e: + pass + if has_file: + continue + + result[key] = value + + return result, files \ No newline at end of file diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py index f96d7940bd..16fe9051e3 100644 --- a/api/core/tools/tool_engine.py +++ b/api/core/tools/tool_engine.py @@ -1,7 +1,10 @@ from copy import deepcopy from datetime import datetime, timezone +from mimetypes import guess_type from typing import Union +from yarl import URL + from core.app.entities.app_invoke_entities import InvokeFrom from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler @@ -17,6 +20,7 @@ from core.tools.errors import ( ToolProviderNotFoundError, ) from core.tools.tool.tool import Tool +from core.tools.tool.workflow_tool import WorkflowTool from core.tools.utils.message_transformer import ToolFileMessageTransformer from extensions.ext_database import db from models.model import Message, MessageFile @@ -115,7 +119,8 @@ class ToolEngine: @staticmethod def workflow_invoke(tool: Tool, tool_parameters: dict, user_id: str, workflow_id: str, - workflow_tool_callback: DifyWorkflowCallbackHandler) \ + workflow_tool_callback: DifyWorkflowCallbackHandler, + workflow_call_depth: int) \ -> list[ToolInvokeMessage]: """ Workflow invokes the tool with the given arguments. 
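The `_extract_files` helper in `workflow_tool.py` above relies on a serialization convention: any output list containing dicts tagged `'__variant': 'FileVar'` is consumed as files and dropped from the plain outputs. A self-contained sketch of that pattern; `FileStub` is an illustrative stand-in for the real `FileVar` model, carrying only the fields the sketch needs:

```python
from dataclasses import dataclass
from typing import Any

@dataclass
class FileStub:
    # stand-in for core.file.file_obj.FileVar
    transfer_method: str
    related_id: str

def extract_files(outputs: dict[str, Any]) -> tuple[dict[str, Any], list[FileStub]]:
    """Split workflow outputs into plain values and serialized files."""
    files: list[FileStub] = []
    result: dict[str, Any] = {}
    for key, value in outputs.items():
        if isinstance(value, list):
            found = [item for item in value
                     if isinstance(item, dict) and item.get('__variant') == 'FileVar']
            if found:
                # the whole key is consumed as files, mirroring the diff's
                # `if has_file: continue`
                files.extend(FileStub(item['transfer_method'], item['related_id'])
                             for item in found)
                continue
        result[key] = value
    return result, files

outputs, files = extract_files({
    'answer': 'done',
    'attachments': [{'__variant': 'FileVar',
                     'transfer_method': 'tool_file',
                     'related_id': 'abc123'}],
})
assert outputs == {'answer': 'done'} and len(files) == 1
```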
@@ -127,6 +132,9 @@ class ToolEngine:
             tool_inputs=tool_parameters
         )

+        if isinstance(tool, WorkflowTool):
+            tool.workflow_call_depth = workflow_call_depth + 1
+
         response = tool.invoke(user_id, tool_parameters)

         # hit the callback handler
@@ -195,8 +203,24 @@ class ToolEngine:
         for response in tool_response:
             if response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
                 response.type == ToolInvokeMessage.MessageType.IMAGE:
+                mimetype = None
+                if response.meta.get('mime_type'):
+                    mimetype = response.meta.get('mime_type')
+                else:
+                    try:
+                        url = URL(response.message)
+                        extension = url.suffix
+                        guess_type_result, _ = guess_type(f'a{extension}')
+                        if guess_type_result:
+                            mimetype = guess_type_result
+                    except Exception:
+                        pass
+
+                if not mimetype:
+                    mimetype = 'image/jpeg'
+
                 result.append(ToolInvokeMessageBinary(
-                    mimetype=response.meta.get('mime_type', 'octet/stream'),
+                    mimetype=mimetype,
                     url=response.message,
                     save_as=response.save_as,
                 ))
diff --git a/api/core/tools/tool_file_manager.py b/api/core/tools/tool_file_manager.py
index e21a2efedd..207f009eed 100644
--- a/api/core/tools/tool_file_manager.py
+++ b/api/core/tools/tool_file_manager.py
@@ -53,7 +53,7 @@ class ToolFileManager:
             return False

         current_time = int(time.time())
-        return current_time - int(timestamp) <= 300 # expired after 5 minutes
+        return current_time - int(timestamp) <= current_app.config.get('FILES_ACCESS_TIMEOUT')

     @staticmethod
     def create_file_by_raw(user_id: str, tenant_id: str,
@@ -65,7 +65,7 @@ class ToolFileManager:
         """
         extension = guess_extension(mimetype) or '.bin'
         unique_name = uuid4().hex
-        filename = f"/tools/{tenant_id}/{unique_name}{extension}"
+        filename = f"tools/{tenant_id}/{unique_name}{extension}"
         storage.save(filename, file_binary)

         tool_file = ToolFile(user_id=user_id, tenant_id=tenant_id,
@@ -90,7 +90,7 @@ class ToolFileManager:
         mimetype = guess_type(file_url)[0] or 'octet/stream'
         extension = guess_extension(mimetype) or '.bin'
         unique_name = uuid4().hex
-        filename = f"/tools/{tenant_id}/{unique_name}{extension}"
+        filename = f"tools/{tenant_id}/{unique_name}{extension}"
         storage.save(filename, blob)

         tool_file = ToolFile(user_id=user_id, tenant_id=tenant_id,
diff --git a/api/core/tools/tool_label_manager.py b/api/core/tools/tool_label_manager.py
new file mode 100644
index 0000000000..97788a7a07
--- /dev/null
+++ b/api/core/tools/tool_label_manager.py
@@ -0,0 +1,96 @@
+from core.tools.entities.values import default_tool_label_name_list
+from core.tools.provider.api_tool_provider import ApiToolProviderController
+from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
+from core.tools.provider.tool_provider import ToolProviderController
+from core.tools.provider.workflow_tool_provider import WorkflowToolProviderController
+from extensions.ext_database import db
+from models.tools import ToolLabelBinding
+
+
+class ToolLabelManager:
+    @classmethod
+    def filter_tool_labels(cls, tool_labels: list[str]) -> list[str]:
+        """
+        Filter tool labels
+        """
+        tool_labels = [label for label in tool_labels if label in default_tool_label_name_list]
+        return list(set(tool_labels))
+
+    @classmethod
+    def update_tool_labels(cls, controller: ToolProviderController, labels: list[str]):
+        """
+        Update tool labels
+        """
+        labels = cls.filter_tool_labels(labels)
+
+        if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
+            provider_id = controller.provider_id
+        else:
+            raise ValueError('Unsupported tool type')
+
+        # delete old labels
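+        # sync strategy: rather than diffing the old label set against the
+        # new one, drop every existing binding for this provider and
+        # re-insert the filtered set; the single commit below keeps the
+        # swap atomic
+        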
db.session.query(ToolLabelBinding).filter( + ToolLabelBinding.tool_id == provider_id + ).delete() + + # insert new labels + for label in labels: + db.session.add(ToolLabelBinding( + tool_id=provider_id, + tool_type=controller.provider_type.value, + label_name=label, + )) + + db.session.commit() + + @classmethod + def get_tool_labels(cls, controller: ToolProviderController) -> list[str]: + """ + Get tool labels + """ + if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController): + provider_id = controller.provider_id + elif isinstance(controller, BuiltinToolProviderController): + return controller.tool_labels + else: + raise ValueError('Unsupported tool type') + + labels: list[ToolLabelBinding] = db.session.query(ToolLabelBinding.label_name).filter( + ToolLabelBinding.tool_id == provider_id, + ToolLabelBinding.tool_type == controller.provider_type.value, + ).all() + + return [label.label_name for label in labels] + + @classmethod + def get_tools_labels(cls, tool_providers: list[ToolProviderController]) -> dict[str, list[str]]: + """ + Get tools labels + + :param tool_providers: list of tool providers + + :return: dict of tool labels + :key: tool id + :value: list of tool labels + """ + if not tool_providers: + return {} + + for controller in tool_providers: + if not isinstance(controller, ApiToolProviderController | WorkflowToolProviderController): + raise ValueError('Unsupported tool type') + + provider_ids = [controller.provider_id for controller in tool_providers] + + labels: list[ToolLabelBinding] = db.session.query(ToolLabelBinding).filter( + ToolLabelBinding.tool_id.in_(provider_ids) + ).all() + + tool_labels = { + label.tool_id: [] for label in labels + } + + for label in labels: + tool_labels[label.tool_id].append(label.label_name) + + return tool_labels \ No newline at end of file diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index a29bdfcd11..a0ca9f692a 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -9,32 +9,33 @@ from typing import Any, Union from flask import current_app from core.agent.entities import AgentToolEntity +from core.app.entities.app_invoke_entities import InvokeFrom +from core.helper.module_import_helper import load_single_subclass_from_source from core.model_runtime.utils.encoders import jsonable_encoder -from core.provider_manager import ProviderManager -from core.tools import * +from core.tools.entities.api_entities import UserToolProvider, UserToolProviderTypeLiteral from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ( ApiProviderAuthType, + ToolInvokeFrom, ToolParameter, ) -from core.tools.entities.user_entities import UserToolProvider from core.tools.errors import ToolProviderNotFoundError -from core.tools.provider.api_tool_provider import ApiBasedToolProviderController +from core.tools.provider.api_tool_provider import ApiToolProviderController from core.tools.provider.builtin._positions import BuiltinToolProviderSort from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController -from core.tools.provider.model_tool_provider import ModelToolProviderController from core.tools.tool.api_tool import ApiTool from core.tools.tool.builtin_tool import BuiltinTool from core.tools.tool.tool import Tool +from core.tools.tool_label_manager import ToolLabelManager from core.tools.utils.configuration import ( ToolConfigurationManager, ToolParameterConfigurationManager, ) -from 
core.utils.module_import_helper import load_single_subclass_from_source +from core.tools.utils.tool_parameter_converter import ToolParameterConverter from core.workflow.nodes.tool.entities import ToolEntity from extensions.ext_database import db -from models.tools import ApiToolProvider, BuiltinToolProvider -from services.tools_transform_service import ToolTransformService +from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider +from services.tools.tools_transform_service import ToolTransformService logger = logging.getLogger(__name__) @@ -101,7 +102,12 @@ class ToolManager: raise ToolProviderNotFoundError(f'provider type {provider_type} not found') @classmethod - def get_tool_runtime(cls, provider_type: str, provider_name: str, tool_name: str, tenant_id: str) \ + def get_tool_runtime(cls, provider_type: str, + provider_id: str, + tool_name: str, + tenant_id: str, + invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, + tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT) \ -> Union[BuiltinTool, ApiTool]: """ get the tool runtime @@ -113,64 +119,76 @@ class ToolManager: :return: the tool """ if provider_type == 'builtin': - builtin_tool = cls.get_builtin_tool(provider_name, tool_name) + builtin_tool = cls.get_builtin_tool(provider_id, tool_name) # check if the builtin tool need credentials - provider_controller = cls.get_builtin_provider(provider_name) + provider_controller = cls.get_builtin_provider(provider_id) if not provider_controller.need_credentials: - return builtin_tool.fork_tool_runtime(meta={ + return builtin_tool.fork_tool_runtime(runtime={ 'tenant_id': tenant_id, 'credentials': {}, + 'invoke_from': invoke_from, + 'tool_invoke_from': tool_invoke_from, }) # get credentials builtin_provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter( BuiltinToolProvider.tenant_id == tenant_id, - BuiltinToolProvider.provider == provider_name, + BuiltinToolProvider.provider == provider_id, ).first() if builtin_provider is None: - raise ToolProviderNotFoundError(f'builtin provider {provider_name} not found') + raise ToolProviderNotFoundError(f'builtin provider {provider_id} not found') # decrypt the credentials credentials = builtin_provider.credentials - controller = cls.get_builtin_provider(provider_name) + controller = cls.get_builtin_provider(provider_id) tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller) decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials) - return builtin_tool.fork_tool_runtime(meta={ + return builtin_tool.fork_tool_runtime(runtime={ 'tenant_id': tenant_id, 'credentials': decrypted_credentials, - 'runtime_parameters': {} + 'runtime_parameters': {}, + 'invoke_from': invoke_from, + 'tool_invoke_from': tool_invoke_from, }) elif provider_type == 'api': if tenant_id is None: raise ValueError('tenant id is required for api provider') - api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_name) + api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id) # decrypt the credentials tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=api_provider) decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials) - return api_provider.get_tool(tool_name).fork_tool_runtime(meta={ + return api_provider.get_tool(tool_name).fork_tool_runtime(runtime={ 'tenant_id': tenant_id, 'credentials': decrypted_credentials, + 'invoke_from': invoke_from, + 'tool_invoke_from': 
tool_invoke_from, }) - elif provider_type == 'model': - if tenant_id is None: - raise ValueError('tenant id is required for model provider') - # get model provider - model_provider = cls.get_model_provider(tenant_id, provider_name) + elif provider_type == 'workflow': + workflow_provider = db.session.query(WorkflowToolProvider).filter( + WorkflowToolProvider.tenant_id == tenant_id, + WorkflowToolProvider.id == provider_id + ).first() - # get tool - model_tool = model_provider.get_tool(tool_name) + if workflow_provider is None: + raise ToolProviderNotFoundError(f'workflow provider {provider_id} not found') - return model_tool.fork_tool_runtime(meta={ + controller = ToolTransformService.workflow_provider_to_controller( + db_provider=workflow_provider + ) + + return controller.get_tools(user_id=None, tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(runtime={ 'tenant_id': tenant_id, - 'credentials': model_tool.model_configuration['model_instance'].credentials + 'credentials': {}, + 'invoke_from': invoke_from, + 'tool_invoke_from': tool_invoke_from, }) elif provider_type == 'app': raise NotImplementedError('app provider not implemented') @@ -196,44 +214,28 @@ class ToolManager: raise ValueError( f"tool parameter {parameter_rule.name} value {parameter_value} not in options {options}") - # convert tool parameter config to correct type - try: - if parameter_rule.type == ToolParameter.ToolParameterType.NUMBER: - # check if tool parameter is integer - if isinstance(parameter_value, int): - parameter_value = parameter_value - elif isinstance(parameter_value, float): - parameter_value = parameter_value - elif isinstance(parameter_value, str): - if '.' in parameter_value: - parameter_value = float(parameter_value) - else: - parameter_value = int(parameter_value) - elif parameter_rule.type == ToolParameter.ToolParameterType.BOOLEAN: - parameter_value = bool(parameter_value) - elif parameter_rule.type not in [ToolParameter.ToolParameterType.SELECT, - ToolParameter.ToolParameterType.STRING]: - parameter_value = str(parameter_value) - elif parameter_rule.type == ToolParameter.ToolParameterType: - parameter_value = str(parameter_value) - except Exception as e: - raise ValueError(f"tool parameter {parameter_rule.name} value {parameter_value} is not correct type") - - return parameter_value + return ToolParameterConverter.cast_parameter_by_type(parameter_value, parameter_rule.type) @classmethod - def get_agent_tool_runtime(cls, tenant_id: str, app_id: str, agent_tool: AgentToolEntity) -> Tool: + def get_agent_tool_runtime(cls, tenant_id: str, app_id: str, agent_tool: AgentToolEntity, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER) -> Tool: """ get the agent tool runtime """ tool_entity = cls.get_tool_runtime( - provider_type=agent_tool.provider_type, provider_name=agent_tool.provider_id, + provider_type=agent_tool.provider_type, + provider_id=agent_tool.provider_id, tool_name=agent_tool.tool_name, tenant_id=tenant_id, + invoke_from=invoke_from, + tool_invoke_from=ToolInvokeFrom.AGENT ) runtime_parameters = {} parameters = tool_entity.get_all_runtime_parameters() for parameter in parameters: + # check file types + if parameter.type == ToolParameter.ToolParameterType.FILE: + raise ValueError(f"file type parameter {parameter.name} not supported in agent") + if parameter.form == ToolParameter.ToolParameterForm.FORM: # save tool parameter to tool entity memory value = cls._init_runtime_parameter(parameter, agent_tool.tool_parameters) @@ -253,15 +255,17 @@ class ToolManager: return tool_entity @classmethod 
- def get_workflow_tool_runtime(cls, tenant_id: str, app_id: str, node_id: str, workflow_tool: ToolEntity): + def get_workflow_tool_runtime(cls, tenant_id: str, app_id: str, node_id: str, workflow_tool: ToolEntity, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER) -> Tool: """ get the workflow tool runtime """ tool_entity = cls.get_tool_runtime( provider_type=workflow_tool.provider_type, - provider_name=workflow_tool.provider_id, + provider_id=workflow_tool.provider_id, tool_name=workflow_tool.tool_name, tenant_id=tenant_id, + invoke_from=invoke_from, + tool_invoke_from=ToolInvokeFrom.WORKFLOW ) runtime_parameters = {} parameters = tool_entity.get_all_runtime_parameters() @@ -367,49 +371,6 @@ class ToolManager: cls._builtin_providers = {} cls._builtin_providers_loaded = False - # @classmethod - # def list_model_providers(cls, tenant_id: str = None) -> list[ModelToolProviderController]: - # """ - # list all the model providers - - # :return: the list of the model providers - # """ - # tenant_id = tenant_id or 'ffffffff-ffff-ffff-ffff-ffffffffffff' - # # get configurations - # model_configurations = ModelToolConfigurationManager.get_all_configuration() - # # get all providers - # provider_manager = ProviderManager() - # configurations = provider_manager.get_configurations(tenant_id).values() - # # get model providers - # model_providers: list[ModelToolProviderController] = [] - # for configuration in configurations: - # # all the model tool should be configurated - # if configuration.provider.provider not in model_configurations: - # continue - # if not ModelToolProviderController.is_configuration_valid(configuration): - # continue - # model_providers.append(ModelToolProviderController.from_db(configuration)) - - # return model_providers - - @classmethod - def get_model_provider(cls, tenant_id: str, provider_name: str) -> ModelToolProviderController: - """ - get the model provider - - :param provider_name: the name of the provider - - :return: the provider - """ - # get configurations - provider_manager = ProviderManager() - configurations = provider_manager.get_configurations(tenant_id) - configuration = configurations.get(provider_name) - if configuration is None: - raise ToolProviderNotFoundError(f'model provider {provider_name} not found') - - return ModelToolProviderController.from_db(configuration) - @classmethod def get_tool_label(cls, tool_name: str) -> Union[I18nObject, None]: """ @@ -419,7 +380,6 @@ class ToolManager: :return: the label of the tool """ - cls._builtin_tools_labels if len(cls._builtin_tools_labels) == 0: # init the builtin providers cls.load_builtin_providers_cache() @@ -430,60 +390,91 @@ class ToolManager: return cls._builtin_tools_labels[tool_name] @classmethod - def user_list_providers(cls, user_id: str, tenant_id: str) -> list[UserToolProvider]: + def user_list_providers(cls, user_id: str, tenant_id: str, typ: UserToolProviderTypeLiteral) -> list[UserToolProvider]: result_providers: dict[str, UserToolProvider] = {} - # get builtin providers - builtin_providers = cls.list_builtin_providers() - - # get db builtin providers - db_builtin_providers: list[BuiltinToolProvider] = db.session.query(BuiltinToolProvider). 
\ - filter(BuiltinToolProvider.tenant_id == tenant_id).all() + filters = [] + if not typ: + filters.extend(['builtin', 'api', 'workflow']) + else: + filters.append(typ) - find_db_builtin_provider = lambda provider: next( - (x for x in db_builtin_providers if x.provider == provider), - None - ) + if 'builtin' in filters: - # append builtin providers - for provider in builtin_providers: - user_provider = ToolTransformService.builtin_provider_to_user_provider( - provider_controller=provider, - db_provider=find_db_builtin_provider(provider.identity.name), - decrypt_credentials=False + # get builtin providers + builtin_providers = cls.list_builtin_providers() + + # get db builtin providers + db_builtin_providers: list[BuiltinToolProvider] = db.session.query(BuiltinToolProvider). \ + filter(BuiltinToolProvider.tenant_id == tenant_id).all() + + find_db_builtin_provider = lambda provider: next( + (x for x in db_builtin_providers if x.provider == provider), + None ) - result_providers[provider.identity.name] = user_provider + # append builtin providers + for provider in builtin_providers: + user_provider = ToolTransformService.builtin_provider_to_user_provider( + provider_controller=provider, + db_provider=find_db_builtin_provider(provider.identity.name), + decrypt_credentials=False + ) - # # get model tool providers - # model_providers = cls.list_model_providers(tenant_id=tenant_id) - # # append model providers - # for provider in model_providers: - # user_provider = ToolTransformService.model_provider_to_user_provider( - # db_provider=provider, - # ) - # result_providers[f'model_provider.{provider.identity.name}'] = user_provider + result_providers[provider.identity.name] = user_provider # get db api providers - db_api_providers: list[ApiToolProvider] = db.session.query(ApiToolProvider). \ - filter(ApiToolProvider.tenant_id == tenant_id).all() - for db_api_provider in db_api_providers: - provider_controller = ToolTransformService.api_provider_to_controller( - db_provider=db_api_provider, - ) - user_provider = ToolTransformService.api_provider_to_user_provider( - provider_controller=provider_controller, - db_provider=db_api_provider, - decrypt_credentials=False - ) - result_providers[db_api_provider.name] = user_provider + if 'api' in filters: + db_api_providers: list[ApiToolProvider] = db.session.query(ApiToolProvider). \ + filter(ApiToolProvider.tenant_id == tenant_id).all() + + api_provider_controllers = [{ + 'provider': provider, + 'controller': ToolTransformService.api_provider_to_controller(provider) + } for provider in db_api_providers] + + # get labels + labels = ToolLabelManager.get_tools_labels([x['controller'] for x in api_provider_controllers]) + + for api_provider_controller in api_provider_controllers: + user_provider = ToolTransformService.api_provider_to_user_provider( + provider_controller=api_provider_controller['controller'], + db_provider=api_provider_controller['provider'], + decrypt_credentials=False, + labels=labels.get(api_provider_controller['controller'].provider_id, []) + ) + result_providers[f'api_provider.{user_provider.name}'] = user_provider + + if 'workflow' in filters: + # get workflow providers + workflow_providers: list[WorkflowToolProvider] = db.session.query(WorkflowToolProvider). 
\ + filter(WorkflowToolProvider.tenant_id == tenant_id).all() + + workflow_provider_controllers = [] + for provider in workflow_providers: + try: + workflow_provider_controllers.append( + ToolTransformService.workflow_provider_to_controller(db_provider=provider) + ) + except Exception as e: + # app has been deleted + pass + + labels = ToolLabelManager.get_tools_labels(workflow_provider_controllers) + + for provider_controller in workflow_provider_controllers: + user_provider = ToolTransformService.workflow_provider_to_user_provider( + provider_controller=provider_controller, + labels=labels.get(provider_controller.provider_id, []), + ) + result_providers[f'workflow_provider.{user_provider.name}'] = user_provider return BuiltinToolProviderSort.sort(list(result_providers.values())) @classmethod def get_api_provider_controller(cls, tenant_id: str, provider_id: str) -> tuple[ - ApiBasedToolProviderController, dict[str, Any]]: + ApiToolProviderController, dict[str, Any]]: """ get the api provider @@ -499,7 +490,7 @@ class ToolManager: if provider is None: raise ToolProviderNotFoundError(f'api provider {provider_id} not found') - controller = ApiBasedToolProviderController.from_db( + controller = ApiToolProviderController.from_db( provider, ApiProviderAuthType.API_KEY if provider.credentials['auth_type'] == 'api_key' else ApiProviderAuthType.NONE @@ -530,7 +521,7 @@ class ToolManager: credentials = {} # package tool provider controller - controller = ApiBasedToolProviderController.from_db( + controller = ApiToolProviderController.from_db( provider, ApiProviderAuthType.API_KEY if credentials['auth_type'] == 'api_key' else ApiProviderAuthType.NONE ) # init tool configuration @@ -547,6 +538,9 @@ class ToolManager: "content": "\ud83d\ude01" } + # add tool labels + labels = ToolLabelManager.get_tool_labels(controller) + return jsonable_encoder({ 'schema_type': provider.schema_type, 'schema': provider.schema, @@ -554,7 +548,9 @@ class ToolManager: 'icon': icon, 'description': provider.description, 'credentials': masked_credentials, - 'privacy_policy': provider.privacy_policy + 'privacy_policy': provider.privacy_policy, + 'custom_disclaimer': provider.custom_disclaimer, + 'labels': labels, }) @classmethod @@ -586,6 +582,15 @@ class ToolManager: "background": "#252525", "content": "\ud83d\ude01" } + elif provider_type == 'workflow': + provider: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter( + WorkflowToolProvider.tenant_id == tenant_id, + WorkflowToolProvider.id == provider_id + ).first() + if provider is None: + raise ToolProviderNotFoundError(f'workflow provider {provider_id} not found') + + return json.loads(provider.icon) else: raise ValueError(f"provider type {provider_type} not found") diff --git a/api/core/tools/utils/__init__.py b/api/core/tools/utils/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/tools/utils/configuration.py b/api/core/tools/utils/configuration.py index 917f8411c4..b213879e96 100644 --- a/api/core/tools/utils/configuration.py +++ b/api/core/tools/utils/configuration.py @@ -1,16 +1,12 @@ -import os from copy import deepcopy -from typing import Any, Union +from typing import Any from pydantic import BaseModel -from yaml import FullLoader, load from core.helper import encrypter from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType from core.helper.tool_provider_cache import ToolProviderCredentialsCache, ToolProviderCredentialsCacheType from core.tools.entities.tool_entities import ( - 
ModelToolConfiguration, - ModelToolProviderConfiguration, ToolParameter, ToolProviderCredentials, ) @@ -27,7 +23,7 @@ class ToolConfigurationManager(BaseModel): deep copy credentials """ return deepcopy(credentials) - + def encrypt_tool_credentials(self, credentials: dict[str, str]) -> dict[str, str]: """ encrypt tool credentials with tenant id @@ -43,9 +39,9 @@ class ToolConfigurationManager(BaseModel): if field_name in credentials: encrypted = encrypter.encrypt_token(self.tenant_id, credentials[field_name]) credentials[field_name] = encrypted - + return credentials - + def mask_tool_credentials(self, credentials: dict[str, Any]) -> dict[str, Any]: """ mask tool credentials @@ -62,7 +58,7 @@ class ToolConfigurationManager(BaseModel): if len(credentials[field_name]) > 6: credentials[field_name] = \ credentials[field_name][:2] + \ - '*' * (len(credentials[field_name]) - 4) +\ + '*' * (len(credentials[field_name]) - 4) + \ credentials[field_name][-2:] else: credentials[field_name] = '*' * len(credentials[field_name]) @@ -77,7 +73,7 @@ class ToolConfigurationManager(BaseModel): """ cache = ToolProviderCredentialsCache( tenant_id=self.tenant_id, - identity_id=f'{self.provider_controller.app_type.value}.{self.provider_controller.identity.name}', + identity_id=f'{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}', cache_type=ToolProviderCredentialsCacheType.PROVIDER ) cached_credentials = cache.get() @@ -96,11 +92,11 @@ class ToolConfigurationManager(BaseModel): cache.set(credentials) return credentials - + def delete_tool_credentials_cache(self): cache = ToolProviderCredentialsCache( tenant_id=self.tenant_id, - identity_id=f'{self.provider_controller.app_type.value}.{self.provider_controller.identity.name}', + identity_id=f'{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}', cache_type=ToolProviderCredentialsCacheType.PROVIDER ) cache.delete() @@ -120,7 +116,7 @@ class ToolParameterConfigurationManager(BaseModel): deep copy parameters """ return deepcopy(parameters) - + def _merge_parameters(self) -> list[ToolParameter]: """ merge parameters @@ -143,7 +139,7 @@ class ToolParameterConfigurationManager(BaseModel): current_parameters.append(runtime_parameter) return current_parameters - + def mask_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]: """ mask tool parameters @@ -161,13 +157,13 @@ class ToolParameterConfigurationManager(BaseModel): if len(parameters[parameter.name]) > 6: parameters[parameter.name] = \ parameters[parameter.name][:2] + \ - '*' * (len(parameters[parameter.name]) - 4) +\ + '*' * (len(parameters[parameter.name]) - 4) + \ parameters[parameter.name][-2:] else: parameters[parameter.name] = '*' * len(parameters[parameter.name]) return parameters - + def encrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]: """ encrypt tool parameters with tenant id @@ -184,9 +180,9 @@ class ToolParameterConfigurationManager(BaseModel): if parameter.name in parameters: encrypted = encrypter.encrypt_token(self.tenant_id, parameters[parameter.name]) parameters[parameter.name] = encrypted - + return parameters - + def decrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]: """ decrypt tool parameters with tenant id @@ -194,7 +190,7 @@ class ToolParameterConfigurationManager(BaseModel): return a deep copy of parameters with decrypted values """ cache = ToolParameterCache( - tenant_id=self.tenant_id, + tenant_id=self.tenant_id, 
provider=f'{self.provider_type}.{self.provider_name}', tool_name=self.tool_runtime.identity.name, cache_type=ToolParameterCacheType.PARAMETER, @@ -216,80 +212,18 @@ class ToolParameterConfigurationManager(BaseModel): parameters[parameter.name] = encrypter.decrypt_token(self.tenant_id, parameters[parameter.name]) except: pass - + if has_secret_input: cache.set(parameters) return parameters - + def delete_tool_parameters_cache(self): cache = ToolParameterCache( - tenant_id=self.tenant_id, + tenant_id=self.tenant_id, provider=f'{self.provider_type}.{self.provider_name}', tool_name=self.tool_runtime.identity.name, cache_type=ToolParameterCacheType.PARAMETER, identity_id=self.identity_id ) cache.delete() - -class ModelToolConfigurationManager: - """ - Model as tool configuration - """ - _configurations: dict[str, ModelToolProviderConfiguration] = {} - _model_configurations: dict[str, ModelToolConfiguration] = {} - _inited = False - - @classmethod - def _init_configuration(cls): - """ - init configuration - """ - - absolute_path = os.path.abspath(os.path.dirname(__file__)) - model_tools_path = os.path.join(absolute_path, '..', 'model_tools') - - # get all .yaml file - files = [f for f in os.listdir(model_tools_path) if f.endswith('.yaml')] - - for file in files: - provider = file.split('.')[0] - with open(os.path.join(model_tools_path, file), encoding='utf-8') as f: - configurations = ModelToolProviderConfiguration(**load(f, Loader=FullLoader)) - models = configurations.models or [] - for model in models: - model_key = f'{provider}.{model.model}' - cls._model_configurations[model_key] = model - - cls._configurations[provider] = configurations - cls._inited = True - - @classmethod - def get_configuration(cls, provider: str) -> Union[ModelToolProviderConfiguration, None]: - """ - get configuration by provider - """ - if not cls._inited: - cls._init_configuration() - return cls._configurations.get(provider, None) - - @classmethod - def get_all_configuration(cls) -> dict[str, ModelToolProviderConfiguration]: - """ - get all configurations - """ - if not cls._inited: - cls._init_configuration() - return cls._configurations - - @classmethod - def get_model_configuration(cls, provider: str, model: str) -> Union[ModelToolConfiguration, None]: - """ - get model configuration - """ - key = f'{provider}.{model}' - - if not cls._inited: - cls._init_configuration() - - return cls._model_configurations.get(key, None) \ No newline at end of file diff --git a/api/core/tools/utils/message_transformer.py b/api/core/tools/utils/message_transformer.py index 3f456b4eb6..ef9e5b67ae 100644 --- a/api/core/tools/utils/message_transformer.py +++ b/api/core/tools/utils/message_transformer.py @@ -1,14 +1,15 @@ import logging from mimetypes import guess_extension +from core.file.file_obj import FileTransferMethod, FileType, FileVar from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool_file_manager import ToolFileManager logger = logging.getLogger(__name__) class ToolFileMessageTransformer: - @staticmethod - def transform_tool_invoke_messages(messages: list[ToolInvokeMessage], + @classmethod + def transform_tool_invoke_messages(cls, messages: list[ToolInvokeMessage], user_id: str, tenant_id: str, conversation_id: str) -> list[ToolInvokeMessage]: @@ -62,7 +63,7 @@ class ToolFileMessageTransformer: mimetype=mimetype ) - url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".bin"}' + url = cls.get_tool_file_url(file.id, guess_extension(file.mimetype)) # check if file is image if 
'image' in mimetype: @@ -79,7 +80,30 @@ class ToolFileMessageTransformer: save_as=message.save_as, meta=message.meta.copy() if message.meta is not None else {}, )) + elif message.type == ToolInvokeMessage.MessageType.FILE_VAR: + file_var: FileVar = message.meta.get('file_var') + if file_var: + if file_var.transfer_method == FileTransferMethod.TOOL_FILE: + url = cls.get_tool_file_url(file_var.related_id, file_var.extension) + if file_var.type == FileType.IMAGE: + result.append(ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.IMAGE_LINK, + message=url, + save_as=message.save_as, + meta=message.meta.copy() if message.meta is not None else {}, + )) + else: + result.append(ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.LINK, + message=url, + save_as=message.save_as, + meta=message.meta.copy() if message.meta is not None else {}, + )) else: result.append(message) - return result \ No newline at end of file + return result + + @classmethod + def get_tool_file_url(cls, tool_file_id: str, extension: str) -> str: + return f'/files/tools/{tool_file_id}{extension or ".bin"}' \ No newline at end of file diff --git a/api/core/tools/model/tool_model_manager.py b/api/core/tools/utils/model_invocation_utils.py similarity index 88% rename from api/core/tools/model/tool_model_manager.py rename to api/core/tools/utils/model_invocation_utils.py index e97d78d699..9e8ef47823 100644 --- a/api/core/tools/model/tool_model_manager.py +++ b/api/core/tools/utils/model_invocation_utils.py @@ -20,12 +20,14 @@ from core.model_runtime.errors.invoke import ( ) from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel, ModelPropertyKey from core.model_runtime.utils.encoders import jsonable_encoder -from core.tools.model.errors import InvokeModelError from extensions.ext_database import db from models.tools import ToolModelInvoke -class ToolModelManager: +class InvokeModelError(Exception): + pass + +class ModelInvocationUtils: @staticmethod def get_max_llm_context_tokens( tenant_id: str, @@ -71,10 +73,8 @@ class ToolModelManager: if not model_instance: raise InvokeModelError('Model not found') - llm_model = cast(LargeLanguageModel, model_instance.model_type_instance) - # get tokens - tokens = llm_model.get_num_tokens(model_instance.model, model_instance.credentials, prompt_messages) + tokens = model_instance.get_llm_num_tokens(prompt_messages) return tokens @@ -106,13 +106,8 @@ class ToolModelManager: tenant_id=tenant_id, model_type=ModelType.LLM, ) - llm_model = cast(LargeLanguageModel, model_instance.model_type_instance) - - # get model credentials - model_credentials = model_instance.credentials - # get prompt tokens - prompt_tokens = llm_model.get_num_tokens(model_instance.model, model_credentials, prompt_messages) + prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages) model_parameters = { 'temperature': 0.8, @@ -142,9 +137,7 @@ class ToolModelManager: db.session.commit() try: - response: LLMResult = llm_model.invoke( - model=model_instance.model, - credentials=model_credentials, + response: LLMResult = model_instance.invoke_llm( prompt_messages=prompt_messages, model_parameters=model_parameters, tools=[], stop=[], stream=False, user=user_id, callbacks=[] @@ -174,4 +167,4 @@ class ToolModelManager: db.session.commit() - return response \ No newline at end of file + return response diff --git a/api/core/tools/utils/parser.py b/api/core/tools/utils/parser.py index a96d8a6b7c..40ae6c66d5 100644 --- a/api/core/tools/utils/parser.py +++ 
b/api/core/tools/utils/parser.py @@ -9,14 +9,14 @@ from requests import get from yaml import YAMLError, safe_load from core.tools.entities.common_entities import I18nObject -from core.tools.entities.tool_bundle import ApiBasedToolBundle +from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_entities import ApiProviderSchemaType, ToolParameter from core.tools.errors import ToolApiSchemaError, ToolNotSupportedError, ToolProviderNotFoundError class ApiBasedToolSchemaParser: @staticmethod - def parse_openapi_to_tool_bundle(openapi: dict, extra_info: dict = None, warning: dict = None) -> list[ApiBasedToolBundle]: + def parse_openapi_to_tool_bundle(openapi: dict, extra_info: dict = None, warning: dict = None) -> list[ApiToolBundle]: warning = warning if warning is not None else {} extra_info = extra_info if extra_info is not None else {} @@ -145,7 +145,7 @@ class ApiBasedToolSchemaParser: interface['operation']['operationId'] = f'{path}_{interface["method"]}' - bundles.append(ApiBasedToolBundle( + bundles.append(ApiToolBundle( server_url=server_url + interface['path'], method=interface['method'], summary=interface['operation']['description'] if 'description' in interface['operation'] else @@ -176,7 +176,7 @@ class ApiBasedToolSchemaParser: return ToolParameter.ToolParameterType.STRING @staticmethod - def parse_openapi_yaml_to_tool_bundle(yaml: str, extra_info: dict = None, warning: dict = None) -> list[ApiBasedToolBundle]: + def parse_openapi_yaml_to_tool_bundle(yaml: str, extra_info: dict = None, warning: dict = None) -> list[ApiToolBundle]: """ parse openapi yaml to tool bundle @@ -258,7 +258,7 @@ class ApiBasedToolSchemaParser: return openapi @staticmethod - def parse_openai_plugin_json_to_tool_bundle(json: str, extra_info: dict = None, warning: dict = None) -> list[ApiBasedToolBundle]: + def parse_openai_plugin_json_to_tool_bundle(json: str, extra_info: dict = None, warning: dict = None) -> list[ApiToolBundle]: """ parse openapi plugin yaml to tool bundle @@ -290,7 +290,7 @@ class ApiBasedToolSchemaParser: return ApiBasedToolSchemaParser.parse_openapi_yaml_to_tool_bundle(response.text, extra_info=extra_info, warning=warning) @staticmethod - def auto_parse_to_tool_bundle(content: str, extra_info: dict = None, warning: dict = None) -> tuple[list[ApiBasedToolBundle], str]: + def auto_parse_to_tool_bundle(content: str, extra_info: dict = None, warning: dict = None) -> tuple[list[ApiToolBundle], str]: """ auto parse to tool bundle diff --git a/api/core/tools/utils/tool_parameter_converter.py b/api/core/tools/utils/tool_parameter_converter.py new file mode 100644 index 0000000000..55535be930 --- /dev/null +++ b/api/core/tools/utils/tool_parameter_converter.py @@ -0,0 +1,66 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolParameter + + +class ToolParameterConverter: + @staticmethod + def get_parameter_type(parameter_type: str | ToolParameter.ToolParameterType) -> str: + match parameter_type: + case ToolParameter.ToolParameterType.STRING \ + | ToolParameter.ToolParameterType.SECRET_INPUT \ + | ToolParameter.ToolParameterType.SELECT: + return 'string' + + case ToolParameter.ToolParameterType.BOOLEAN: + return 'boolean' + + case ToolParameter.ToolParameterType.NUMBER: + return 'number' + + case _: + raise ValueError(f"Unsupported parameter type {parameter_type}") + + @staticmethod + def cast_parameter_by_type(value: Any, parameter_type: str) -> Any: + # convert tool parameter config to correct type + try: + match parameter_type: + case 
ToolParameter.ToolParameterType.STRING \ + | ToolParameter.ToolParameterType.SECRET_INPUT \ + | ToolParameter.ToolParameterType.SELECT: + if value is None: + return '' + else: + return value if isinstance(value, str) else str(value) + + case ToolParameter.ToolParameterType.BOOLEAN: + if value is None: + return False + elif isinstance(value, str): + # Allowed YAML boolean value strings: https://yaml.org/type/bool.html + # and also '0' for False and '1' for True + match value.lower(): + case 'true' | 'yes' | 'y' | '1': + return True + case 'false' | 'no' | 'n' | '0': + return False + case _: + return bool(value) + else: + return value if isinstance(value, bool) else bool(value) + + case ToolParameter.ToolParameterType.NUMBER: + if isinstance(value, int) | isinstance(value, float): + return value + elif isinstance(value, str): + if '.' in value: + return float(value) + else: + return int(value) + + case _: + return str(value) + + except Exception: + raise ValueError(f"The tool parameter value {value} is not in correct type of {parameter_type}.") diff --git a/api/core/tools/utils/web_reader_tool.py b/api/core/tools/utils/web_reader_tool.py index 4c6fbb2780..4c69c6eddc 100644 --- a/api/core/tools/utils/web_reader_tool.py +++ b/api/core/tools/utils/web_reader_tool.py @@ -1,5 +1,6 @@ import hashlib import json +import mimetypes import os import re import site @@ -7,6 +8,7 @@ import subprocess import tempfile import unicodedata from contextlib import contextmanager +from urllib.parse import unquote import requests from bs4 import BeautifulSoup, CData, Comment, NavigableString @@ -39,22 +41,34 @@ def get_url(url: str, user_agent: str = None) -> str: } if user_agent: headers["User-Agent"] = user_agent - - supported_content_types = extract_processor.SUPPORT_URL_CONTENT_TYPES + ["text/html"] - response = requests.get(url, headers=headers, allow_redirects=True, timeout=(5, 10)) + main_content_type = None + supported_content_types = extract_processor.SUPPORT_URL_CONTENT_TYPES + ["text/html"] + response = requests.head(url, headers=headers, allow_redirects=True, timeout=(5, 10)) if response.status_code != 200: return "URL returned status code {}.".format(response.status_code) # check content-type - main_content_type = response.headers.get('Content-Type').split(';')[0].strip() + content_type = response.headers.get('Content-Type') + if content_type: + main_content_type = response.headers.get('Content-Type').split(';')[0].strip() + else: + content_disposition = response.headers.get('Content-Disposition', '') + filename_match = re.search(r'filename="([^"]+)"', content_disposition) + if filename_match: + filename = unquote(filename_match.group(1)) + extension = re.search(r'\.(\w+)$', filename) + if extension: + main_content_type = mimetypes.guess_type(filename)[0] + if main_content_type not in supported_content_types: return "Unsupported content-type [{}] of URL.".format(main_content_type) if main_content_type in extract_processor.SUPPORT_URL_CONTENT_TYPES: return ExtractProcessor.load_from_url(url, return_text=True) + response = requests.get(url, headers=headers, allow_redirects=True, timeout=(120, 300)) a = extract_using_readabilipy(response.text) if not a['plain_text'] or not a['plain_text'].strip(): @@ -118,17 +132,17 @@ def extract_using_readabilipy(html): } # Populate article fields from readability fields where present if input_json: - if "title" in input_json and input_json["title"]: + if input_json.get("title"): article_json["title"] = input_json["title"] - if "byline" in input_json and 
input_json["byline"]: + if input_json.get("byline"): article_json["byline"] = input_json["byline"] - if "date" in input_json and input_json["date"]: + if input_json.get("date"): article_json["date"] = input_json["date"] - if "content" in input_json and input_json["content"]: + if input_json.get("content"): article_json["content"] = input_json["content"] article_json["plain_content"] = plain_content(article_json["content"], False, False) article_json["plain_text"] = extract_text_blocks_as_plain_text(article_json["plain_content"]) - if "textContent" in input_json and input_json["textContent"]: + if input_json.get("textContent"): article_json["plain_text"] = input_json["textContent"] article_json["plain_text"] = re.sub(r'\n\s*\n', '\n', article_json["plain_text"]) diff --git a/api/core/tools/utils/workflow_configuration_sync.py b/api/core/tools/utils/workflow_configuration_sync.py new file mode 100644 index 0000000000..ff5505bbbf --- /dev/null +++ b/api/core/tools/utils/workflow_configuration_sync.py @@ -0,0 +1,48 @@ +from core.app.app_config.entities import VariableEntity +from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration + + +class WorkflowToolConfigurationUtils: + @classmethod + def check_parameter_configurations(cls, configurations: list[dict]): + """ + check parameter configurations + """ + for configuration in configurations: + if not WorkflowToolParameterConfiguration(**configuration): + raise ValueError('invalid parameter configuration') + + @classmethod + def get_workflow_graph_variables(cls, graph: dict) -> list[VariableEntity]: + """ + get workflow graph variables + """ + nodes = graph.get('nodes', []) + start_node = next(filter(lambda x: x.get('data', {}).get('type') == 'start', nodes), None) + + if not start_node: + return [] + + return [ + VariableEntity(**variable) for variable in start_node.get('data', {}).get('variables', []) + ] + + @classmethod + def check_is_synced(cls, + variables: list[VariableEntity], + tool_configurations: list[WorkflowToolParameterConfiguration]) -> None: + """ + check is synced + + raise ValueError if not synced + """ + variable_names = [variable.variable for variable in variables] + + if len(tool_configurations) != len(variables): + raise ValueError('parameter configuration mismatch, please republish the tool to update') + + for parameter in tool_configurations: + if parameter.name not in variable_names: + raise ValueError('parameter configuration mismatch, please republish the tool to update') + + return True \ No newline at end of file diff --git a/api/core/tools/utils/yaml_utils.py b/api/core/tools/utils/yaml_utils.py new file mode 100644 index 0000000000..22e4d3d128 --- /dev/null +++ b/api/core/tools/utils/yaml_utils.py @@ -0,0 +1,34 @@ +import logging +import os + +import yaml +from yaml import YAMLError + + +def load_yaml_file(file_path: str, ignore_error: bool = False) -> dict: + """ + Safe loading a YAML file to a dict + :param file_path: the path of the YAML file + :param ignore_error: + if True, return empty dict if error occurs and the error will be logged in warning level + if False, raise error if error occurs + :return: a dict of the YAML content + """ + try: + if not file_path or not os.path.exists(file_path): + raise FileNotFoundError(f'Failed to load YAML file {file_path}: file not found') + + with open(file_path, encoding='utf-8') as file: + try: + return yaml.safe_load(file) + except Exception as e: + raise YAMLError(f'Failed to load YAML file {file_path}: {e}') + except FileNotFoundError as e: + 
logging.debug(f'Failed to load YAML file {file_path}: {e}') + return {} + except Exception as e: + if ignore_error: + logging.warning(f'Failed to load YAML file {file_path}: {e}') + return {} + else: + raise e diff --git a/api/core/workflow/callbacks/base_workflow_callback.py b/api/core/workflow/callbacks/base_workflow_callback.py index dd5a30f611..3b0d51d868 100644 --- a/api/core/workflow/callbacks/base_workflow_callback.py +++ b/api/core/workflow/callbacks/base_workflow_callback.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Optional +from typing import Any, Optional from core.app.entities.queue_entities import AppQueueEvent from core.workflow.entities.base_node_data_entities import BaseNodeData @@ -71,6 +71,42 @@ class BaseWorkflowCallback(ABC): Publish text chunk """ raise NotImplementedError + + @abstractmethod + def on_workflow_iteration_started(self, + node_id: str, + node_type: NodeType, + node_run_index: int = 1, + node_data: Optional[BaseNodeData] = None, + inputs: dict = None, + predecessor_node_id: Optional[str] = None, + metadata: Optional[dict] = None) -> None: + """ + Publish iteration started + """ + raise NotImplementedError + + @abstractmethod + def on_workflow_iteration_next(self, node_id: str, + node_type: NodeType, + index: int, + node_run_index: int, + output: Optional[Any], + ) -> None: + """ + Publish iteration next + """ + raise NotImplementedError + + @abstractmethod + def on_workflow_iteration_completed(self, node_id: str, + node_type: NodeType, + node_run_index: int, + outputs: dict) -> None: + """ + Publish iteration completed + """ + raise NotImplementedError @abstractmethod def on_event(self, event: AppQueueEvent) -> None: diff --git a/api/core/workflow/entities/base_node_data_entities.py b/api/core/workflow/entities/base_node_data_entities.py index fc6ee231ff..6bf0c11c7d 100644 --- a/api/core/workflow/entities/base_node_data_entities.py +++ b/api/core/workflow/entities/base_node_data_entities.py @@ -7,3 +7,16 @@ from pydantic import BaseModel class BaseNodeData(ABC, BaseModel): title: str desc: Optional[str] = None + +class BaseIterationNodeData(BaseNodeData): + start_node_id: str + +class BaseIterationState(BaseModel): + iteration_node_id: str + index: int + inputs: dict + + class MetaData(BaseModel): + pass + + metadata: MetaData \ No newline at end of file diff --git a/api/core/workflow/entities/node_entities.py b/api/core/workflow/entities/node_entities.py index 7eb9488792..ae86463407 100644 --- a/api/core/workflow/entities/node_entities.py +++ b/api/core/workflow/entities/node_entities.py @@ -21,7 +21,11 @@ class NodeType(Enum): QUESTION_CLASSIFIER = 'question-classifier' HTTP_REQUEST = 'http-request' TOOL = 'tool' + VARIABLE_AGGREGATOR = 'variable-aggregator' VARIABLE_ASSIGNER = 'variable-assigner' + LOOP = 'loop' + ITERATION = 'iteration' + PARAMETER_EXTRACTOR = 'parameter-extractor' @classmethod def value_of(cls, value: str) -> 'NodeType': @@ -68,6 +72,8 @@ class NodeRunMetadataKey(Enum): TOTAL_PRICE = 'total_price' CURRENCY = 'currency' TOOL_INFO = 'tool_info' + ITERATION_ID = 'iteration_id' + ITERATION_INDEX = 'iteration_index' class NodeRunResult(BaseModel): diff --git a/api/core/workflow/entities/variable_pool.py b/api/core/workflow/entities/variable_pool.py index 690bdddaf6..c04770616c 100644 --- a/api/core/workflow/entities/variable_pool.py +++ b/api/core/workflow/entities/variable_pool.py @@ -90,3 +90,12 @@ class VariablePool: raise ValueError(f'Invalid value type: {target_value_type.value}') return value + + def 
clear_node_variables(self, node_id: str) -> None: + """ + Clear node variables + :param node_id: node id + :return: + """ + if node_id in self.variables_mapping: + self.variables_mapping.pop(node_id) \ No newline at end of file diff --git a/api/core/workflow/entities/workflow_entities.py b/api/core/workflow/entities/workflow_entities.py index e1c5eb6752..9b35b8df8a 100644 --- a/api/core/workflow/entities/workflow_entities.py +++ b/api/core/workflow/entities/workflow_entities.py @@ -1,5 +1,9 @@ from typing import Optional +from pydantic import BaseModel + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.entities.base_node_data_entities import BaseIterationState from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode, UserFrom @@ -22,6 +26,9 @@ class WorkflowRunState: workflow_type: WorkflowType user_id: str user_from: UserFrom + invoke_from: InvokeFrom + + workflow_call_depth: int start_at: float variable_pool: VariablePool @@ -30,20 +37,37 @@ class WorkflowRunState: workflow_nodes_and_results: list[WorkflowNodeAndResult] + class NodeRun(BaseModel): + node_id: str + iteration_node_id: str + + workflow_node_runs: list[NodeRun] + workflow_node_steps: int + + current_iteration_state: Optional[BaseIterationState] + def __init__(self, workflow: Workflow, start_at: float, variable_pool: VariablePool, user_id: str, - user_from: UserFrom): + user_from: UserFrom, + invoke_from: InvokeFrom, + workflow_call_depth: int): self.workflow_id = workflow.id self.tenant_id = workflow.tenant_id self.app_id = workflow.app_id self.workflow_type = WorkflowType.value_of(workflow.type) self.user_id = user_id self.user_from = user_from + self.invoke_from = invoke_from + self.workflow_call_depth = workflow_call_depth self.start_at = start_at self.variable_pool = variable_pool self.total_tokens = 0 self.workflow_nodes_and_results = [] + + self.current_iteration_state = None + self.workflow_node_steps = 1 + self.workflow_node_runs = [] \ No newline at end of file diff --git a/api/core/workflow/nodes/base_node.py b/api/core/workflow/nodes/base_node.py index 7cc9c6ee3d..fa7d6424f1 100644 --- a/api/core/workflow/nodes/base_node.py +++ b/api/core/workflow/nodes/base_node.py @@ -2,8 +2,9 @@ from abc import ABC, abstractmethod from enum import Enum from typing import Optional +from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback -from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.entities.base_node_data_entities import BaseIterationState, BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool @@ -37,6 +38,9 @@ class BaseNode(ABC): workflow_id: str user_id: str user_from: UserFrom + invoke_from: InvokeFrom + + workflow_call_depth: int node_id: str node_data: BaseNodeData @@ -49,13 +53,17 @@ class BaseNode(ABC): workflow_id: str, user_id: str, user_from: UserFrom, + invoke_from: InvokeFrom, config: dict, - callbacks: list[BaseWorkflowCallback] = None) -> None: + callbacks: list[BaseWorkflowCallback] = None, + workflow_call_depth: int = 0) -> None: self.tenant_id = tenant_id self.app_id = app_id self.workflow_id = workflow_id self.user_id = user_id self.user_from = user_from + self.invoke_from = invoke_from + self.workflow_call_depth = workflow_call_depth 
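
The `clear_node_variables` helper added above lets a node's outputs be dropped between runs (useful when the same node executes repeatedly inside an iteration). A minimal sketch of the semantics, assuming variables are grouped by `node_id` as in the `variables_mapping` attribute the hunk touches; the dict-based pool below is a stand-in, not the real `VariablePool`:

```python
# Stand-in for VariablePool.variables_mapping: variables grouped by node id.
variables_mapping: dict[str, dict[str, object]] = {
    'iteration-1': {'item': 'apple', 'index': 0},
    'llm-1': {'text': 'hello'},
}

def clear_node_variables(node_id: str) -> None:
    # Drop every variable the node produced, if any.
    variables_mapping.pop(node_id, None)

clear_node_variables('iteration-1')
assert list(variables_mapping) == ['llm-1']
```
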
self.node_id = config.get("id") if not self.node_id: @@ -140,3 +148,38 @@ class BaseNode(ABC): :return: """ return self._node_type + +class BaseIterationNode(BaseNode): + @abstractmethod + def _run(self, variable_pool: VariablePool) -> BaseIterationState: + """ + Run node + :param variable_pool: variable pool + :return: + """ + raise NotImplementedError + + def run(self, variable_pool: VariablePool) -> BaseIterationState: + """ + Run node entry + :param variable_pool: variable pool + :return: + """ + return self._run(variable_pool=variable_pool) + + def get_next_iteration(self, variable_pool: VariablePool, state: BaseIterationState) -> NodeRunResult | str: + """ + Get next iteration start node id based on the graph. + :param graph: graph + :return: next node id + """ + return self._get_next_iteration(variable_pool, state) + + @abstractmethod + def _get_next_iteration(self, variable_pool: VariablePool, state: BaseIterationState) -> NodeRunResult | str: + """ + Get next iteration start node id based on the graph. + :param graph: graph + :return: next node id + """ + raise NotImplementedError diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index 2c1529f492..610a23e704 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -1,7 +1,10 @@ import os from typing import Optional, Union, cast -from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor +from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor, CodeLanguage +from core.helper.code_executor.code_node_provider import CodeNodeProvider +from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider +from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode @@ -17,16 +20,6 @@ MAX_STRING_ARRAY_LENGTH = int(os.environ.get('CODE_MAX_STRING_ARRAY_LENGTH', '30 MAX_OBJECT_ARRAY_LENGTH = int(os.environ.get('CODE_MAX_OBJECT_ARRAY_LENGTH', '30')) MAX_NUMBER_ARRAY_LENGTH = int(os.environ.get('CODE_MAX_NUMBER_ARRAY_LENGTH', '1000')) -JAVASCRIPT_DEFAULT_CODE = """function main({arg1, arg2}) { - return { - result: arg1 + arg2 - } -}""" - -PYTHON_DEFAULT_CODE = """def main(arg1: int, arg2: int) -> dict: - return { - "result": arg1 + arg2, - }""" class CodeNode(BaseNode): _node_data_cls = CodeNodeData @@ -39,54 +32,15 @@ class CodeNode(BaseNode): :param filters: filter by node config parameters. 
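
The new `BaseIterationNode` above splits iteration into two phases: `run` builds an initial state, then `get_next_iteration` is called repeatedly and returns either the next start node id or a final result. A hedged, self-contained sketch of that contract (the state class, ids, and terminal marker below are stand-ins, not the project's types):

```python
from dataclasses import dataclass

@dataclass
class IterStateSketch:
    index: int
    length: int

def run(length: int) -> IterStateSketch:
    # _run: build the initial iteration state
    return IterStateSketch(index=0, length=length)

def get_next_iteration(state: IterStateSketch, start_node_id: str) -> str:
    # _get_next_iteration: next start node id, or a terminal marker
    state.index += 1
    return start_node_id if state.index < state.length else 'DONE'

state = run(2)
assert get_next_iteration(state, 'iter-start') == 'iter-start'
assert get_next_iteration(state, 'iter-start') == 'DONE'
```
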
:return: """ - if filters and filters.get("code_language") == "javascript": - return { - "type": "code", - "config": { - "variables": [ - { - "variable": "arg1", - "value_selector": [] - }, - { - "variable": "arg2", - "value_selector": [] - } - ], - "code_language": "javascript", - "code": JAVASCRIPT_DEFAULT_CODE, - "outputs": { - "result": { - "type": "string", - "children": None - } - } - } - } + code_language = CodeLanguage.PYTHON3 + if filters: + code_language = (filters.get("code_language", CodeLanguage.PYTHON3)) - return { - "type": "code", - "config": { - "variables": [ - { - "variable": "arg1", - "value_selector": [] - }, - { - "variable": "arg2", - "value_selector": [] - } - ], - "code_language": "python3", - "code": PYTHON_DEFAULT_CODE, - "outputs": { - "result": { - "type": "string", - "children": None - } - } - } - } + providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider] + code_provider: type[CodeNodeProvider] = next(p for p in providers + if p.is_accept_language(code_language)) + + return code_provider.get_default_config() def _run(self, variable_pool: VariablePool) -> NodeRunResult: """ @@ -115,7 +69,8 @@ class CodeNode(BaseNode): result = CodeExecutor.execute_workflow_code_template( language=code_language, code=code, - inputs=variables + inputs=variables, + dependencies=node_data.dependencies ) # Transform result diff --git a/api/core/workflow/nodes/code/entities.py b/api/core/workflow/nodes/code/entities.py index 555bb3918e..03044268ab 100644 --- a/api/core/workflow/nodes/code/entities.py +++ b/api/core/workflow/nodes/code/entities.py @@ -2,6 +2,8 @@ from typing import Literal, Optional from pydantic import BaseModel +from core.helper.code_executor.code_executor import CodeLanguage +from core.helper.code_executor.entities import CodeDependency from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.variable_entities import VariableSelector @@ -15,6 +17,7 @@ class CodeNodeData(BaseNodeData): children: Optional[dict[str, 'Output']] variables: list[VariableSelector] - code_language: Literal['python3', 'javascript'] + code_language: Literal[CodeLanguage.PYTHON3, CodeLanguage.JAVASCRIPT] code: str - outputs: dict[str, Output] \ No newline at end of file + outputs: dict[str, Output] + dependencies: Optional[list[CodeDependency]] = None \ No newline at end of file diff --git a/api/core/workflow/nodes/http_request/entities.py b/api/core/workflow/nodes/http_request/entities.py index d88ad999b7..4a81a4176d 100644 --- a/api/core/workflow/nodes/http_request/entities.py +++ b/api/core/workflow/nodes/http_request/entities.py @@ -1,9 +1,13 @@ +import os from typing import Literal, Optional, Union from pydantic import BaseModel, validator from core.workflow.entities.base_node_data_entities import BaseNodeData +MAX_CONNECT_TIMEOUT = int(os.environ.get('HTTP_REQUEST_MAX_CONNECT_TIMEOUT', '300')) +MAX_READ_TIMEOUT = int(os.environ.get('HTTP_REQUEST_MAX_READ_TIMEOUT', '600')) +MAX_WRITE_TIMEOUT = int(os.environ.get('HTTP_REQUEST_MAX_WRITE_TIMEOUT', '600')) class HttpRequestNodeData(BaseNodeData): """ @@ -36,9 +40,9 @@ class HttpRequestNodeData(BaseNodeData): data: Union[None, str] class Timeout(BaseModel): - connect: int - read: int - write: int + connect: Optional[int] = MAX_CONNECT_TIMEOUT + read: Optional[int] = MAX_READ_TIMEOUT + write: Optional[int] = MAX_WRITE_TIMEOUT method: Literal['get', 'post', 'put', 'patch', 'delete', 'head'] url: str @@ -46,4 +50,5 @@ class HttpRequestNodeData(BaseNodeData): headers: 
str
     params: str
     body: Optional[Body]
-    timeout: Optional[Timeout]
\ No newline at end of file
+    timeout: Optional[Timeout]
+    mask_authorization_header: Optional[bool] = True
diff --git a/api/core/workflow/nodes/http_request/http_executor.py b/api/core/workflow/nodes/http_request/http_executor.py
index c2beb7a383..10002216f1 100644
--- a/api/core/workflow/nodes/http_request/http_executor.py
+++ b/api/core/workflow/nodes/http_request/http_executor.py
@@ -1,4 +1,5 @@
 import json
+import os
 from copy import deepcopy
 from random import randint
 from typing import Any, Optional, Union
@@ -13,10 +14,10 @@ from core.workflow.entities.variable_pool import ValueType, VariablePool
 from core.workflow.nodes.http_request.entities import HttpRequestNodeData
 from core.workflow.utils.variable_template_parser import VariableTemplateParser

-MAX_BINARY_SIZE = 1024 * 1024 * 10 # 10MB
-READABLE_MAX_BINARY_SIZE = '10MB'
-MAX_TEXT_SIZE = 1024 * 1024 // 10 # 0.1MB
-READABLE_MAX_TEXT_SIZE = '0.1MB'
+MAX_BINARY_SIZE = int(os.environ.get('HTTP_REQUEST_NODE_MAX_BINARY_SIZE', 1024 * 1024 * 10)) # 10MB
+READABLE_MAX_BINARY_SIZE = f'{MAX_BINARY_SIZE / 1024 / 1024:.2f}MB'
+MAX_TEXT_SIZE = int(os.environ.get('HTTP_REQUEST_NODE_MAX_TEXT_SIZE', 1024 * 1024)) # 1MB
+READABLE_MAX_TEXT_SIZE = f'{MAX_TEXT_SIZE / 1024 / 1024:.2f}MB'


 class HttpExecutorResponse:
@@ -24,18 +25,10 @@ class HttpExecutorResponse:
     response: Union[httpx.Response, requests.Response]

     def __init__(self, response: Union[httpx.Response, requests.Response] = None):
-        """
-        init
-        """
-        headers = {}
-        if isinstance(response, httpx.Response):
+        self.headers = {}
+        if isinstance(response, httpx.Response | requests.Response):
             for k, v in response.headers.items():
-                headers[k] = v
-        elif isinstance(response, requests.Response):
-            for k, v in response.headers.items():
-                headers[k] = v
-
-        self.headers = headers
+                self.headers[k] = v
         self.response = response

     @property
@@ -45,21 +38,11 @@ class HttpExecutorResponse:
         """
         content_type = self.get_content_type()
         file_content_types = ['image', 'audio', 'video']
-        for v in file_content_types:
-            if v in content_type:
-                return True
-
-        return False
+
+        return any(v in content_type for v in file_content_types)

     def get_content_type(self) -> str:
-        """
-        get content type
-        """
-        for key, val in self.headers.items():
-            if key.lower() == 'content-type':
-                return val
-
-        return ''
+        return self.headers.get('content-type', '')

     def extract_file(self) -> tuple[str, bytes]:
         """
@@ -67,29 +50,25 @@ class HttpExecutorResponse:
         """
         if self.is_file:
             return self.get_content_type(), self.body
-
+
         return '', b''
-
+
     @property
     def content(self) -> str:
         """
         get content
         """
-        if isinstance(self.response, httpx.Response):
-            return self.response.text
-        elif isinstance(self.response, requests.Response):
+        if isinstance(self.response, httpx.Response | requests.Response):
             return self.response.text
         else:
             raise ValueError(f'Invalid response type {type(self.response)}')
-
+
     @property
     def body(self) -> bytes:
         """
         get body
         """
-        if isinstance(self.response, httpx.Response):
-            return self.response.content
-        elif isinstance(self.response, requests.Response):
+        if isinstance(self.response, httpx.Response | requests.Response):
             return self.response.content
         else:
             raise ValueError(f'Invalid response type {type(self.response)}')
@@ -99,20 +78,18 @@ class HttpExecutorResponse:
         """
         get status code
         """
-        if isinstance(self.response, httpx.Response):
-            return self.response.status_code
-        elif isinstance(self.response, requests.Response):
+        if isinstance(self.response, httpx.Response | requests.Response):
             return self.response.status_code
         else:
             raise ValueError(f'Invalid response type {type(self.response)}')
-
+
     @property
     def size(self) -> int:
         """
         get size
         """
         return len(self.body)
-
+
     @property
     def readable_size(self) -> str:
         """
@@ -138,10 +115,8 @@ class HttpExecutor:
     variable_selectors: list[VariableSelector]
     timeout: HttpRequestNodeData.Timeout

-    def __init__(self, node_data: HttpRequestNodeData, timeout: HttpRequestNodeData.Timeout, variable_pool: Optional[VariablePool] = None):
-        """
-        init
-        """
+    def __init__(self, node_data: HttpRequestNodeData, timeout: HttpRequestNodeData.Timeout,
+                 variable_pool: Optional[VariablePool] = None):
         self.server_url = node_data.url
         self.method = node_data.method
         self.authorization = node_data.authorization
@@ -155,7 +130,8 @@ class HttpExecutor:
         self.variable_selectors = []
         self._init_template(node_data, variable_pool)

-    def _is_json_body(self, body: HttpRequestNodeData.Body):
+    @staticmethod
+    def _is_json_body(body: HttpRequestNodeData.Body):
         """
         check if body is json
         """
@@ -165,55 +141,48 @@ class HttpExecutor:
             return True
         except:
             return False
-
+
         return False

+    @staticmethod
+    def _to_dict(convert_item: str, convert_text: str, maxsplit: int = -1):
+        """
+        Convert a string like `aa:bb\n cc:dd` to the dict `{aa:bb, cc:dd}`
+        :param convert_item: A label for the item being converted: params, headers or body.
+        :param convert_text: The string containing key-value pairs separated by '\n'.
+        :param maxsplit: The maximum number of splits allowed for the ':' character in each key-value pair. Default is -1 (no limit).
+        :return: A dictionary containing the key-value pairs from the input string.
+        """
+        kv_pairs = convert_text.split('\n')
+        result = {}
+        for kv in kv_pairs:
+            if not kv.strip():
+                continue
+
+            kv = kv.split(':', maxsplit=maxsplit)
+            if len(kv) >= 3:
+                k, v = kv[0], ":".join(kv[1:])
+            elif len(kv) == 2:
+                k, v = kv
+            elif len(kv) == 1:
+                k, v = kv[0], ''
+            else:
+                raise ValueError(f'Invalid {convert_item} {kv}')
+            result[k.strip()] = v
+        return result
+
     def _init_template(self, node_data: HttpRequestNodeData, variable_pool: Optional[VariablePool] = None):
-        """
-        init template
-        """
-        variable_selectors = []

         # extract all template in url
         self.server_url, server_url_variable_selectors = self._format_template(node_data.url, variable_pool)

         # extract all template in params
         params, params_variable_selectors = self._format_template(node_data.params, variable_pool)
-
-        # fill in params
-        kv_paris = params.split('\n')
-        for kv in kv_paris:
-            if not kv.strip():
-                continue
-
-            kv = kv.split(':')
-            if len(kv) == 2:
-                k, v = kv
-            elif len(kv) == 1:
-                k, v = kv[0], ''
-            else:
-                raise ValueError(f'Invalid params {kv}')
-
-            self.params[k.strip()] = v
+        self.params = self._to_dict("params", params)

         # extract all template in headers
         headers, headers_variable_selectors = self._format_template(node_data.headers, variable_pool)
-
-        # fill in headers
-        kv_paris = headers.split('\n')
-        for kv in kv_paris:
-            if not kv.strip():
-                continue
-
-            kv = kv.split(':')
-            if len(kv) == 2:
-                k, v = kv
-            elif len(kv) == 1:
-                k, v = kv[0], ''
-            else:
-                raise ValueError(f'Invalid headers {kv}')
-
-            self.headers[k.strip()] = v.strip()
+        self.headers = self._to_dict("headers", headers)

         # extract all template in body
         body_data_variable_selectors = []
@@ -231,18 +200,7 @@ class HttpExecutor:
             self.headers['Content-Type'] = 'application/x-www-form-urlencoded'

         if node_data.body.type in ['form-data', 'x-www-form-urlencoded']:
-            body = {}
-            kv_paris
= body_data.split('\n') - for kv in kv_paris: - if not kv.strip(): - continue - kv = kv.split(':') - if len(kv) == 2: - body[kv[0].strip()] = kv[1] - elif len(kv) == 1: - body[kv[0].strip()] = '' - else: - raise ValueError(f'Invalid body {kv}') + body = self._to_dict("body", body_data, 1) if node_data.body.type == 'form-data': self.files = { @@ -261,14 +219,14 @@ class HttpExecutor: self.variable_selectors = (server_url_variable_selectors + params_variable_selectors + headers_variable_selectors + body_data_variable_selectors) - + def _assembling_headers(self) -> dict[str, Any]: authorization = deepcopy(self.authorization) headers = deepcopy(self.headers) or {} if self.authorization.type == 'api-key': if self.authorization.config.api_key is None: raise ValueError('api_key is required') - + if not self.authorization.config.header: authorization.config.header = 'Authorization' @@ -278,9 +236,9 @@ class HttpExecutor: headers[authorization.config.header] = f'Basic {authorization.config.api_key}' elif self.authorization.config.type == 'custom': headers[authorization.config.header] = authorization.config.api_key - + return headers - + def _validate_and_parse_response(self, response: Union[httpx.Response, requests.Response]) -> HttpExecutorResponse: """ validate the response @@ -289,21 +247,22 @@ class HttpExecutor: executor_response = HttpExecutorResponse(response) else: raise ValueError(f'Invalid response type {type(response)}') - + if executor_response.is_file: if executor_response.size > MAX_BINARY_SIZE: - raise ValueError(f'File size is too large, max size is {READABLE_MAX_BINARY_SIZE}, but current size is {executor_response.readable_size}.') + raise ValueError( + f'File size is too large, max size is {READABLE_MAX_BINARY_SIZE}, but current size is {executor_response.readable_size}.') else: if executor_response.size > MAX_TEXT_SIZE: - raise ValueError(f'Text size is too large, max size is {READABLE_MAX_TEXT_SIZE}, but current size is {executor_response.readable_size}.') - + raise ValueError( + f'Text size is too large, max size is {READABLE_MAX_TEXT_SIZE}, but current size is {executor_response.readable_size}.') + return executor_response - + def _do_http_request(self, headers: dict[str, Any]) -> httpx.Response: """ do http request depending on api bundle """ - # do http request kwargs = { 'url': self.server_url, 'headers': headers, @@ -312,25 +271,14 @@ class HttpExecutor: 'follow_redirects': True } - if self.method == 'get': - response = ssrf_proxy.get(**kwargs) - elif self.method == 'post': - response = ssrf_proxy.post(data=self.body, files=self.files, **kwargs) - elif self.method == 'put': - response = ssrf_proxy.put(data=self.body, files=self.files, **kwargs) - elif self.method == 'delete': - response = ssrf_proxy.delete(data=self.body, files=self.files, **kwargs) - elif self.method == 'patch': - response = ssrf_proxy.patch(data=self.body, files=self.files, **kwargs) - elif self.method == 'head': - response = ssrf_proxy.head(**kwargs) - elif self.method == 'options': - response = ssrf_proxy.options(**kwargs) + if self.method in ('get', 'head', 'options'): + response = getattr(ssrf_proxy, self.method)(**kwargs) + elif self.method in ('post', 'put', 'delete', 'patch'): + response = getattr(ssrf_proxy, self.method)(data=self.body, files=self.files, **kwargs) else: raise ValueError(f'Invalid http method {self.method}') - return response - + def invoke(self) -> HttpExecutorResponse: """ invoke http request @@ -343,8 +291,8 @@ class HttpExecutor: # validate response return 
self._validate_and_parse_response(response) - - def to_raw_request(self) -> str: + + def to_raw_request(self, mask_authorization_header: Optional[bool] = True) -> str: """ convert to raw request """ @@ -356,6 +304,17 @@ class HttpExecutor: headers = self._assembling_headers() for k, v in headers.items(): + if mask_authorization_header: + # get authorization header + if self.authorization.type == 'api-key': + authorization_header = 'Authorization' + if self.authorization.config and self.authorization.config.header: + authorization_header = self.authorization.config.header + + if k.lower() == authorization_header.lower(): + raw_request += f'{k}: {"*" * len(v)}\n' + continue + raw_request += f'{k}: {v}\n' raw_request += '\n' diff --git a/api/core/workflow/nodes/http_request/http_request_node.py b/api/core/workflow/nodes/http_request/http_request_node.py index cba1a11a8a..d983a30695 100644 --- a/api/core/workflow/nodes/http_request/http_request_node.py +++ b/api/core/workflow/nodes/http_request/http_request_node.py @@ -1,5 +1,4 @@ import logging -import os from mimetypes import guess_extension from os import path from typing import cast @@ -9,14 +8,15 @@ from core.tools.tool_file_manager import ToolFileManager from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode -from core.workflow.nodes.http_request.entities import HttpRequestNodeData +from core.workflow.nodes.http_request.entities import ( + MAX_CONNECT_TIMEOUT, + MAX_READ_TIMEOUT, + MAX_WRITE_TIMEOUT, + HttpRequestNodeData, +) from core.workflow.nodes.http_request.http_executor import HttpExecutor, HttpExecutorResponse from models.workflow import WorkflowNodeExecutionStatus -MAX_CONNECT_TIMEOUT = int(os.environ.get('HTTP_REQUEST_MAX_CONNECT_TIMEOUT', '300')) -MAX_READ_TIMEOUT = int(os.environ.get('HTTP_REQUEST_MAX_READ_TIMEOUT', '600')) -MAX_WRITE_TIMEOUT = int(os.environ.get('HTTP_REQUEST_MAX_WRITE_TIMEOUT', '600')) - HTTP_REQUEST_DEFAULT_TIMEOUT = HttpRequestNodeData.Timeout(connect=min(10, MAX_CONNECT_TIMEOUT), read=min(60, MAX_READ_TIMEOUT), write=min(20, MAX_WRITE_TIMEOUT)) @@ -63,7 +63,9 @@ class HttpRequestNode(BaseNode): process_data = {} if http_executor: process_data = { - 'request': http_executor.to_raw_request(), + 'request': http_executor.to_raw_request( + mask_authorization_header=node_data.mask_authorization_header + ), } return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, @@ -82,7 +84,9 @@ class HttpRequestNode(BaseNode): 'files': files, }, process_data={ - 'request': http_executor.to_raw_request(), + 'request': http_executor.to_raw_request( + mask_authorization_header=node_data.mask_authorization_header + ), } ) @@ -91,8 +95,14 @@ class HttpRequestNode(BaseNode): if timeout is None: return HTTP_REQUEST_DEFAULT_TIMEOUT + if timeout.connect is None: + timeout.connect = HTTP_REQUEST_DEFAULT_TIMEOUT.connect timeout.connect = min(timeout.connect, MAX_CONNECT_TIMEOUT) + if timeout.read is None: + timeout.read = HTTP_REQUEST_DEFAULT_TIMEOUT.read timeout.read = min(timeout.read, MAX_READ_TIMEOUT) + if timeout.write is None: + timeout.write = HTTP_REQUEST_DEFAULT_TIMEOUT.write timeout.write = min(timeout.write, MAX_WRITE_TIMEOUT) return timeout diff --git a/api/core/workflow/nodes/iteration/__init__.py b/api/core/workflow/nodes/iteration/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/workflow/nodes/iteration/entities.py 
b/api/core/workflow/nodes/iteration/entities.py new file mode 100644 index 0000000000..c85aa66c7b --- /dev/null +++ b/api/core/workflow/nodes/iteration/entities.py @@ -0,0 +1,39 @@ +from typing import Any, Optional + +from core.workflow.entities.base_node_data_entities import BaseIterationNodeData, BaseIterationState + + +class IterationNodeData(BaseIterationNodeData): + """ + Iteration Node Data. + """ + parent_loop_id: Optional[str] # redundant field, not used currently + iterator_selector: list[str] # variable selector + output_selector: list[str] # output selector + +class IterationState(BaseIterationState): + """ + Iteration State. + """ + outputs: list[Any] = None + current_output: Optional[Any] = None + + class MetaData(BaseIterationState.MetaData): + """ + Data. + """ + iterator_length: int + + def get_last_output(self) -> Optional[Any]: + """ + Get last output. + """ + if self.outputs: + return self.outputs[-1] + return None + + def get_current_output(self) -> Optional[Any]: + """ + Get current output. + """ + return self.current_output \ No newline at end of file diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py new file mode 100644 index 0000000000..12d792f297 --- /dev/null +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -0,0 +1,119 @@ +from typing import cast + +from core.model_runtime.utils.encoders import jsonable_encoder +from core.workflow.entities.base_node_data_entities import BaseIterationState +from core.workflow.entities.node_entities import NodeRunResult, NodeType +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.nodes.base_node import BaseIterationNode +from core.workflow.nodes.iteration.entities import IterationNodeData, IterationState +from models.workflow import WorkflowNodeExecutionStatus + + +class IterationNode(BaseIterationNode): + """ + Iteration Node. + """ + _node_data_cls = IterationNodeData + _node_type = NodeType.ITERATION + + def _run(self, variable_pool: VariablePool) -> BaseIterationState: + """ + Run the node. + """ + iterator = variable_pool.get_variable_value(cast(IterationNodeData, self.node_data).iterator_selector) + + if not isinstance(iterator, list): + raise ValueError(f"Invalid iterator value: {iterator}, please provide a list.") + + state = IterationState(iteration_node_id=self.node_id, index=-1, inputs={ + 'iterator_selector': iterator + }, outputs=[], metadata=IterationState.MetaData( + iterator_length=len(iterator) if iterator is not None else 0 + )) + + self._set_current_iteration_variable(variable_pool, state) + return state + + def _get_next_iteration(self, variable_pool: VariablePool, state: IterationState) -> NodeRunResult | str: + """ + Get next iteration start node id based on the graph. + :param graph: graph + :return: next node id + """ + # resolve current output + self._resolve_current_output(variable_pool, state) + # move to next iteration + self._next_iteration(variable_pool, state) + + node_data = cast(IterationNodeData, self.node_data) + if self._reached_iteration_limit(variable_pool, state): + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + outputs={ + 'output': jsonable_encoder(state.outputs) + } + ) + + return node_data.start_node_id + + def _set_current_iteration_variable(self, variable_pool: VariablePool, state: IterationState): + """ + Set current iteration variable. 
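# --- Illustrative sketch (editorial note, not part of this patch): the iteration
# contract implemented above, in miniature. The node tracks an index into the
# iterator, exposes index/item style variables each round, and collects one output
# per round until the iterator is exhausted. All names here are local to this
# example, not Dify's actual classes.
from typing import Any


class MiniIterationState:
    def __init__(self, iterator: list[Any]):
        self.iterator = iterator
        self.index = -1                 # -1 means "not started", as in IterationState
        self.outputs: list[Any] = []

    def next_item(self) -> Any | None:
        """Advance one round; None signals the iteration limit is reached."""
        self.index += 1
        if self.index >= len(self.iterator):
            return None
        return self.iterator[self.index]


state = MiniIterationState(['a', 'b'])
while (item := state.next_item()) is not None:
    state.outputs.append(item.upper())  # stand-in for running the iteration sub-graph
assert state.outputs == ['A', 'B']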
+ :variable_pool: variable pool + """ + node_data = cast(IterationNodeData, self.node_data) + + variable_pool.append_variable(self.node_id, ['index'], state.index) + # get the iterator value + iterator = variable_pool.get_variable_value(node_data.iterator_selector) + + if iterator is None or not isinstance(iterator, list): + return + + if state.index < len(iterator): + variable_pool.append_variable(self.node_id, ['item'], iterator[state.index]) + + def _next_iteration(self, variable_pool: VariablePool, state: IterationState): + """ + Move to next iteration. + :param variable_pool: variable pool + """ + state.index += 1 + self._set_current_iteration_variable(variable_pool, state) + + def _reached_iteration_limit(self, variable_pool: VariablePool, state: IterationState): + """ + Check if iteration limit is reached. + :return: True if iteration limit is reached, False otherwise + """ + node_data = cast(IterationNodeData, self.node_data) + iterator = variable_pool.get_variable_value(node_data.iterator_selector) + + if iterator is None or not isinstance(iterator, list): + return True + + return state.index >= len(iterator) + + def _resolve_current_output(self, variable_pool: VariablePool, state: IterationState): + """ + Resolve current output. + :param variable_pool: variable pool + """ + output_selector = cast(IterationNodeData, self.node_data).output_selector + output = variable_pool.get_variable_value(output_selector) + # clear the output for this iteration + variable_pool.append_variable(self.node_id, output_selector[1:], None) + state.current_output = output + if output is not None: + state.outputs.append(output) + + @classmethod + def _extract_variable_selector_to_variable_mapping(cls, node_data: IterationNodeData) -> dict[str, list[str]]: + """ + Extract variable selector to variable mapping + :param node_data: node data + :return: + """ + return { + 'input_selector': node_data.iterator_selector, + } \ No newline at end of file diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index be3cec9152..1a0f3b0495 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -1,5 +1,7 @@ from typing import Any, cast +from sqlalchemy import func + from core.app.app_config.entities import DatasetRetrieveConfigEntity from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.entities.agent_entities import PlanningStrategy @@ -73,30 +75,33 @@ class KnowledgeRetrievalNode(BaseNode): def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: str) -> list[ dict[str, Any]]: - """ - A dataset tool is a tool that can be used to retrieve information from a dataset - :param node_data: node data - :param query: query - """ - tools = [] available_datasets = [] dataset_ids = node_data.dataset_ids - for dataset_id in dataset_ids: - # get dataset from dataset id - dataset = db.session.query(Dataset).filter( - Dataset.tenant_id == self.tenant_id, - Dataset.id == dataset_id - ).first() + # Subquery: Count the number of available documents for each dataset + subquery = db.session.query( + Document.dataset_id, + func.count(Document.id).label('available_document_count') + ).filter( + Document.indexing_status == 'completed', + Document.enabled == True, + Document.archived == False, + Document.dataset_id.in_(dataset_ids) + 
).group_by(Document.dataset_id).having( + func.count(Document.id) > 0 + ).subquery() + + results = db.session.query(Dataset).join( + subquery, Dataset.id == subquery.c.dataset_id + ).filter( + Dataset.tenant_id == self.tenant_id, + Dataset.id.in_(dataset_ids) + ).all() + + for dataset in results: # pass if dataset is not available if not dataset: continue - - # pass if dataset is not available - if (dataset and dataset.available_document_count == 0 - and dataset.available_document_count == 0): - continue - available_datasets.append(dataset) all_documents = [] dataset_retrieval = DatasetRetrieval() @@ -143,10 +148,9 @@ class KnowledgeRetrievalNode(BaseNode): if all_documents: document_score_list = {} for item in all_documents: - if 'score' in item.metadata and item.metadata['score']: + if item.metadata.get('score'): document_score_list[item.metadata['doc_id']] = item.metadata['score'] - document_context_list = [] index_node_ids = [document.metadata['doc_id'] for document in all_documents] segments = DocumentSegment.query.filter( DocumentSegment.dataset_id.in_(dataset_ids), @@ -160,11 +164,6 @@ class KnowledgeRetrievalNode(BaseNode): sorted_segments = sorted(segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float('inf'))) - for segment in sorted_segments: - if segment.answer: - document_context_list.append(f'question:{segment.content} answer:{segment.answer}') - else: - document_context_list.append(segment.content) for segment in sorted_segments: dataset = Dataset.query.filter_by( @@ -197,9 +196,9 @@ class KnowledgeRetrievalNode(BaseNode): 'title': document.name } if segment.answer: - source['content'] = f'question:{segment.content} \nanswer:{segment.answer}' + source['content'] = f'question:{segment.get_sign_content()} \nanswer:{segment.answer}' else: - source['content'] = segment.content + source['content'] = segment.get_sign_content() context_list.append(source) resource_number += 1 return context_list diff --git a/api/core/workflow/nodes/llm/entities.py b/api/core/workflow/nodes/llm/entities.py index c390aaf8c9..1e48a10bc7 100644 --- a/api/core/workflow/nodes/llm/entities.py +++ b/api/core/workflow/nodes/llm/entities.py @@ -4,6 +4,7 @@ from pydantic import BaseModel from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.entities.variable_entities import VariableSelector class ModelConfig(BaseModel): @@ -37,13 +38,31 @@ class VisionConfig(BaseModel): enabled: bool configs: Optional[Configs] = None +class PromptConfig(BaseModel): + """ + Prompt Config. + """ + jinja2_variables: Optional[list[VariableSelector]] = None + +class LLMNodeChatModelMessage(ChatModelMessage): + """ + LLM Node Chat Model Message. + """ + jinja2_text: Optional[str] = None + +class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate): + """ + LLM Node Chat Model Prompt Template. + """ + jinja2_text: Optional[str] = None class LLMNodeData(BaseNodeData): """ LLM Node Data. 
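# --- Illustrative sketch (editorial note, not part of this patch): what the new
# jinja2 edition type defined above enables. Variables declared in
# prompt_config.jinja2_variables are fetched from the variable pool, flattened to
# strings (dicts JSON-encoded, lists joined line by line), and substituted into
# jinja2_text. Dify renders templates through its own transform layer, which may
# differ from the plain jinja2 package used here; this only shows the data flow.
import json

from jinja2 import Template


def flatten(value) -> str:
    if isinstance(value, dict):
        return json.dumps(value, ensure_ascii=False)
    if isinstance(value, list):
        return '\n'.join(flatten(item) for item in value).strip()
    return str(value)


jinja2_text = 'Summarize for {{ user }}:\n{{ docs }}'
inputs = {'user': 'alice', 'docs': [{'content': 'doc one'}, 'doc two']}
prompt = Template(jinja2_text).render({k: flatten(v) for k, v in inputs.items()})
print(prompt)
# Summarize for alice:
# {"content": "doc one"}
# doc two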
""" model: ModelConfig - prompt_template: Union[list[ChatModelMessage], CompletionModelPromptTemplate] + prompt_template: Union[list[LLMNodeChatModelMessage], LLMNodeCompletionModelPromptTemplate] + prompt_config: Optional[PromptConfig] = None memory: Optional[MemoryConfig] = None context: ContextConfig vision: VisionConfig diff --git a/api/core/workflow/nodes/llm/llm_node.py b/api/core/workflow/nodes/llm/llm_node.py index c8b7f279ab..fef09c1385 100644 --- a/api/core/workflow/nodes/llm/llm_node.py +++ b/api/core/workflow/nodes/llm/llm_node.py @@ -1,4 +1,6 @@ +import json from collections.abc import Generator +from copy import deepcopy from typing import Optional, cast from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity @@ -17,11 +19,15 @@ from core.model_runtime.utils.encoders import jsonable_encoder from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig from core.prompt.utils.prompt_message_util import PromptMessageUtil -from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode -from core.workflow.nodes.llm.entities import LLMNodeData, ModelConfig +from core.workflow.nodes.llm.entities import ( + LLMNodeChatModelMessage, + LLMNodeCompletionModelPromptTemplate, + LLMNodeData, + ModelConfig, +) from core.workflow.utils.variable_template_parser import VariableTemplateParser from extensions.ext_database import db from models.model import Conversation @@ -39,16 +45,24 @@ class LLMNode(BaseNode): :param variable_pool: variable pool :return: """ - node_data = self.node_data - node_data = cast(self._node_data_cls, node_data) + node_data = cast(LLMNodeData, deepcopy(self.node_data)) node_inputs = None process_data = None try: + # init messages template + node_data.prompt_template = self._transform_chat_messages(node_data.prompt_template) + # fetch variables and fetch values from variable pool inputs = self._fetch_inputs(node_data, variable_pool) + # fetch jinja2 inputs + jinja_inputs = self._fetch_jinja_inputs(node_data, variable_pool) + + # merge inputs + inputs.update(jinja_inputs) + node_inputs = {} # fetch files @@ -183,6 +197,86 @@ class LLMNode(BaseNode): usage = LLMUsage.empty_usage() return full_text, usage + + def _transform_chat_messages(self, + messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate + ) -> list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate: + """ + Transform chat messages + + :param messages: chat messages + :return: + """ + + if isinstance(messages, LLMNodeCompletionModelPromptTemplate): + if messages.edition_type == 'jinja2': + messages.text = messages.jinja2_text + + return messages + + for message in messages: + if message.edition_type == 'jinja2': + message.text = message.jinja2_text + + return messages + + def _fetch_jinja_inputs(self, node_data: LLMNodeData, variable_pool: VariablePool) -> dict[str, str]: + """ + Fetch jinja inputs + :param node_data: node data + :param variable_pool: variable pool + :return: + """ + variables = {} + + if not node_data.prompt_config: + return variables + + for variable_selector in node_data.prompt_config.jinja2_variables or []: + variable = variable_selector.variable + value = 
variable_pool.get_variable_value( + variable_selector=variable_selector.value_selector + ) + + def parse_dict(d: dict) -> str: + """ + Parse dict into string + """ + # check if it's a context structure + if 'metadata' in d and '_source' in d['metadata'] and 'content' in d: + return d['content'] + + # else, parse the dict + try: + return json.dumps(d, ensure_ascii=False) + except Exception: + return str(d) + + if isinstance(value, str): + value = value + elif isinstance(value, list): + result = '' + for item in value: + if isinstance(item, dict): + result += parse_dict(item) + elif isinstance(item, str): + result += item + elif isinstance(item, int | float): + result += str(item) + else: + result += str(item) + result += '\n' + value = result.strip() + elif isinstance(value, dict): + value = parse_dict(value) + elif isinstance(value, int | float): + value = str(value) + else: + value = str(value) + + variables[variable] = value + + return variables def _fetch_inputs(self, node_data: LLMNodeData, variable_pool: VariablePool) -> dict[str, str]: """ @@ -531,25 +625,25 @@ class LLMNode(BaseNode): db.session.commit() @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: + def _extract_variable_selector_to_variable_mapping(cls, node_data: LLMNodeData) -> dict[str, list[str]]: """ Extract variable selector to variable mapping :param node_data: node data :return: """ - node_data = node_data - node_data = cast(cls._node_data_cls, node_data) prompt_template = node_data.prompt_template variable_selectors = [] if isinstance(prompt_template, list): for prompt in prompt_template: - variable_template_parser = VariableTemplateParser(template=prompt.text) - variable_selectors.extend(variable_template_parser.extract_variable_selectors()) + if prompt.edition_type != 'jinja2': + variable_template_parser = VariableTemplateParser(template=prompt.text) + variable_selectors.extend(variable_template_parser.extract_variable_selectors()) else: - variable_template_parser = VariableTemplateParser(template=prompt_template.text) - variable_selectors = variable_template_parser.extract_variable_selectors() + if prompt_template.edition_type != 'jinja2': + variable_template_parser = VariableTemplateParser(template=prompt_template.text) + variable_selectors = variable_template_parser.extract_variable_selectors() variable_mapping = {} for variable_selector in variable_selectors: @@ -571,6 +665,22 @@ class LLMNode(BaseNode): if node_data.memory: variable_mapping['#sys.query#'] = ['sys', SystemVariable.QUERY.value] + if node_data.prompt_config: + enable_jinja = False + + if isinstance(prompt_template, list): + for prompt in prompt_template: + if prompt.edition_type == 'jinja2': + enable_jinja = True + break + else: + if prompt_template.edition_type == 'jinja2': + enable_jinja = True + + if enable_jinja: + for variable_selector in node_data.prompt_config.jinja2_variables or []: + variable_mapping[variable_selector.variable] = variable_selector.value_selector + return variable_mapping @classmethod @@ -588,7 +698,8 @@ class LLMNode(BaseNode): "prompts": [ { "role": "system", - "text": "You are a helpful AI assistant." 
+ "text": "You are a helpful AI assistant.", + "edition_type": "basic" } ] }, @@ -600,7 +711,8 @@ class LLMNode(BaseNode): "prompt": { "text": "Here is the chat histories between human and assistant, inside " " XML tags.\n\n\n{{" - "#histories#}}\n\n\n\nHuman: {{#sys.query#}}\n\nAssistant:" + "#histories#}}\n\n\n\nHuman: {{#sys.query#}}\n\nAssistant:", + "edition_type": "basic" }, "stop": ["Human:"] } diff --git a/api/core/workflow/nodes/loop/__init__.py b/api/core/workflow/nodes/loop/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/workflow/nodes/loop/entities.py b/api/core/workflow/nodes/loop/entities.py new file mode 100644 index 0000000000..8a5684551e --- /dev/null +++ b/api/core/workflow/nodes/loop/entities.py @@ -0,0 +1,13 @@ + +from core.workflow.entities.base_node_data_entities import BaseIterationNodeData, BaseIterationState + + +class LoopNodeData(BaseIterationNodeData): + """ + Loop Node Data. + """ + +class LoopState(BaseIterationState): + """ + Loop State. + """ \ No newline at end of file diff --git a/api/core/workflow/nodes/loop/loop_node.py b/api/core/workflow/nodes/loop/loop_node.py new file mode 100644 index 0000000000..7d53c6f5f2 --- /dev/null +++ b/api/core/workflow/nodes/loop/loop_node.py @@ -0,0 +1,20 @@ +from core.workflow.entities.node_entities import NodeRunResult, NodeType +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.nodes.base_node import BaseIterationNode +from core.workflow.nodes.loop.entities import LoopNodeData, LoopState + + +class LoopNode(BaseIterationNode): + """ + Loop Node. + """ + _node_data_cls = LoopNodeData + _node_type = NodeType.LOOP + + def _run(self, variable_pool: VariablePool) -> LoopState: + return super()._run(variable_pool) + + def _get_next_iteration(self, variable_loop: VariablePool) -> NodeRunResult | str: + """ + Get next iteration start node id based on the graph. + """ diff --git a/api/core/workflow/nodes/parameter_extractor/__init__.py b/api/core/workflow/nodes/parameter_extractor/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/workflow/nodes/parameter_extractor/entities.py b/api/core/workflow/nodes/parameter_extractor/entities.py new file mode 100644 index 0000000000..a89a6903ef --- /dev/null +++ b/api/core/workflow/nodes/parameter_extractor/entities.py @@ -0,0 +1,85 @@ +from typing import Any, Literal, Optional + +from pydantic import BaseModel, validator + +from core.prompt.entities.advanced_prompt_entities import MemoryConfig +from core.workflow.entities.base_node_data_entities import BaseNodeData + + +class ModelConfig(BaseModel): + """ + Model Config. + """ + provider: str + name: str + mode: str + completion_params: dict[str, Any] = {} + +class ParameterConfig(BaseModel): + """ + Parameter Config. + """ + name: str + type: Literal['string', 'number', 'bool', 'select', 'array[string]', 'array[number]', 'array[object]'] + options: Optional[list[str]] + description: str + required: bool + + @validator('name', pre=True, always=True) + def validate_name(cls, value): + if not value: + raise ValueError('Parameter name is required') + if value in ['__reason', '__is_success']: + raise ValueError('Invalid parameter name, __reason and __is_success are reserved') + return value + +class ParameterExtractorNodeData(BaseNodeData): + """ + Parameter Extractor Node Data. 
+ """ + model: ModelConfig + query: list[str] + parameters: list[ParameterConfig] + instruction: Optional[str] + memory: Optional[MemoryConfig] + reasoning_mode: Literal['function_call', 'prompt'] + + @validator('reasoning_mode', pre=True, always=True) + def set_reasoning_mode(cls, v): + return v or 'function_call' + + def get_parameter_json_schema(self) -> dict: + """ + Get parameter json schema. + + :return: parameter json schema + """ + parameters = { + 'type': 'object', + 'properties': {}, + 'required': [] + } + + for parameter in self.parameters: + parameter_schema = { + 'description': parameter.description + } + + if parameter.type in ['string', 'select']: + parameter_schema['type'] = 'string' + elif parameter.type.startswith('array'): + parameter_schema['type'] = 'array' + nested_type = parameter.type[6:-1] + parameter_schema['items'] = {'type': nested_type} + else: + parameter_schema['type'] = parameter.type + + if parameter.type == 'select': + parameter_schema['enum'] = parameter.options + + parameters['properties'][parameter.name] = parameter_schema + + if parameter.required: + parameters['required'].append(parameter.name) + + return parameters \ No newline at end of file diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py new file mode 100644 index 0000000000..6e7dbd2702 --- /dev/null +++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py @@ -0,0 +1,711 @@ +import json +import uuid +from typing import Optional, cast + +from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity +from core.memory.token_buffer_memory import TokenBufferMemory +from core.model_manager import ModelInstance +from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + PromptMessageRole, + PromptMessageTool, + ToolPromptMessage, + UserPromptMessage, +) +from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.model_runtime.utils.encoders import jsonable_encoder +from core.prompt.advanced_prompt_transform import AdvancedPromptTransform +from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate +from core.prompt.simple_prompt_transform import ModelMode +from core.prompt.utils.prompt_message_util import PromptMessageUtil +from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.nodes.llm.entities import ModelConfig +from core.workflow.nodes.llm.llm_node import LLMNode +from core.workflow.nodes.parameter_extractor.entities import ParameterExtractorNodeData +from core.workflow.nodes.parameter_extractor.prompts import ( + CHAT_EXAMPLE, + CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE, + COMPLETION_GENERATE_JSON_PROMPT, + FUNCTION_CALLING_EXTRACTOR_EXAMPLE, + FUNCTION_CALLING_EXTRACTOR_NAME, + FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT, + FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE, +) +from core.workflow.utils.variable_template_parser import VariableTemplateParser +from extensions.ext_database import db +from models.workflow import WorkflowNodeExecutionStatus + + +class ParameterExtractorNode(LLMNode): + """ + Parameter Extractor Node. 
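# --- Illustrative sketch (editorial note, not part of this patch): the JSON schema
# that get_parameter_json_schema() above produces for a small parameter list,
# rebuilt standalone so the mapping is easy to see: 'select' becomes a string with
# an enum, and 'array[...]' becomes an array with typed items.
def build_schema(parameters: list[dict]) -> dict:
    schema = {'type': 'object', 'properties': {}, 'required': []}
    for p in parameters:
        prop: dict = {'description': p['description']}
        if p['type'] in ('string', 'select'):
            prop['type'] = 'string'
        elif p['type'].startswith('array'):
            prop['type'] = 'array'
            prop['items'] = {'type': p['type'][6:-1]}  # 'array[number]' -> 'number'
        else:
            prop['type'] = p['type']
        if p['type'] == 'select':
            prop['enum'] = p['options']
        schema['properties'][p['name']] = prop
        if p['required']:
            schema['required'].append(p['name'])
    return schema


print(build_schema([
    {'name': 'location', 'type': 'string', 'description': 'city name', 'required': True},
    {'name': 'unit', 'type': 'select', 'options': ['C', 'F'], 'description': 'unit', 'required': False},
]))
# {'type': 'object', 'properties': {'location': ..., 'unit': ...}, 'required': ['location']}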
+    """
+    _node_data_cls = ParameterExtractorNodeData
+    _node_type = NodeType.PARAMETER_EXTRACTOR
+
+    _model_instance: Optional[ModelInstance] = None
+    _model_config: Optional[ModelConfigWithCredentialsEntity] = None
+
+    @classmethod
+    def get_default_config(cls, filters: Optional[dict] = None) -> dict:
+        return {
+            "model": {
+                "prompt_templates": {
+                    "completion_model": {
+                        "conversation_histories_role": {
+                            "user_prefix": "Human",
+                            "assistant_prefix": "Assistant"
+                        },
+                        "stop": ["Human:"]
+                    }
+                }
+            }
+        }
+
+    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
+        """
+        Run the node.
+        """
+        node_data = cast(ParameterExtractorNodeData, self.node_data)
+        query = variable_pool.get_variable_value(node_data.query)
+        if not query:
+            raise ValueError("Query not found")
+
+        inputs = {
+            'query': query,
+            'parameters': jsonable_encoder(node_data.parameters),
+            'instruction': jsonable_encoder(node_data.instruction),
+        }
+
+        model_instance, model_config = self._fetch_model_config(node_data.model)
+        if not isinstance(model_instance.model_type_instance, LargeLanguageModel):
+            raise ValueError("Model is not a Large Language Model")
+
+        llm_model = model_instance.model_type_instance
+        model_schema = llm_model.get_model_schema(model_config.model, model_config.credentials)
+        if not model_schema:
+            raise ValueError("Model schema not found")
+
+        # fetch memory
+        memory = self._fetch_memory(node_data.memory, variable_pool, model_instance)
+
+        # deduplicated: the original listed ModelFeature.MULTI_TOOL_CALL twice,
+        # which reads like a typo for the TOOL_CALL / MULTI_TOOL_CALL pair
+        if set(model_schema.features or []) & set([ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL]) \
+                and node_data.reasoning_mode == 'function_call':
+            # use function call
+            prompt_messages, prompt_message_tools = self._generate_function_call_prompt(
+                node_data, query, variable_pool, model_config, memory
+            )
+        else:
+            # use prompt engineering
+            prompt_messages = self._generate_prompt_engineering_prompt(node_data, query, variable_pool, model_config, memory)
+            prompt_message_tools = []
+
+        process_data = {
+            'model_mode': model_config.mode,
+            'prompts': PromptMessageUtil.prompt_messages_to_prompt_for_saving(
+                model_mode=model_config.mode,
+                prompt_messages=prompt_messages
+            ),
+            'usage': None,
+            'function': {} if not prompt_message_tools else jsonable_encoder(prompt_message_tools[0]),
+            'tool_call': None,
+        }
+
+        try:
+            text, usage, tool_call = self._invoke_llm(
+                node_data_model=node_data.model,
+                model_instance=model_instance,
+                prompt_messages=prompt_messages,
+                tools=prompt_message_tools,
+                stop=model_config.stop,
+            )
+            process_data['usage'] = jsonable_encoder(usage)
+            process_data['tool_call'] = jsonable_encoder(tool_call)
+            process_data['llm_text'] = text
+        except Exception as e:
+            return NodeRunResult(
+                status=WorkflowNodeExecutionStatus.FAILED,
+                inputs=inputs,
+                process_data={},
+                outputs={
+                    '__is_success': 0,
+                    '__reason': str(e)
+                },
+                error=str(e),
+                metadata={}
+            )
+
+        error = None
+
+        if tool_call:
+            result = self._extract_json_from_tool_call(tool_call)
+        else:
+            result = self._extract_complete_json_response(text)
+            if not result:
+                result = self._generate_default_result(node_data)
+                error = "Failed to extract result from function call or text response, using empty result."
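# --- Illustrative sketch (editorial note, not part of this patch): the fallback
# chain above. The node prefers structured tool-call arguments, falls back to
# pulling JSON out of the raw completion text, and finally falls back to typed
# defaults so downstream nodes can branch on __is_success instead of crashing.
# json.loads() stands in for the node's more tolerant balanced-bracket scan.
import json
from typing import Optional


def extract(tool_call_args: Optional[str], text: str, defaults: dict) -> tuple[dict, Optional[str]]:
    if tool_call_args:
        return json.loads(tool_call_args), None
    try:
        return json.loads(text), None
    except json.JSONDecodeError:
        return defaults, 'Failed to extract result, using empty result.'


result, error = extract(None, 'not json at all', {'location': ''})
assert result == {'location': ''} and error is not None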
+ + try: + result = self._validate_result(node_data, result) + except Exception as e: + error = str(e) + + # transform result into standard format + result = self._transform_result(node_data, result) + + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=inputs, + process_data=process_data, + outputs={ + '__is_success': 1 if not error else 0, + '__reason': error, + **result + }, + metadata={ + NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens, + NodeRunMetadataKey.TOTAL_PRICE: usage.total_price, + NodeRunMetadataKey.CURRENCY: usage.currency + } + ) + + def _invoke_llm(self, node_data_model: ModelConfig, + model_instance: ModelInstance, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool], + stop: list[str]) -> tuple[str, LLMUsage, Optional[AssistantPromptMessage.ToolCall]]: + """ + Invoke large language model + :param node_data_model: node data model + :param model_instance: model instance + :param prompt_messages: prompt messages + :param stop: stop + :return: + """ + db.session.close() + + invoke_result = model_instance.invoke_llm( + prompt_messages=prompt_messages, + model_parameters=node_data_model.completion_params, + tools=tools, + stop=stop, + stream=False, + user=self.user_id, + ) + + # handle invoke result + if not isinstance(invoke_result, LLMResult): + raise ValueError(f"Invalid invoke result: {invoke_result}") + + text = invoke_result.message.content + usage = invoke_result.usage + tool_call = invoke_result.message.tool_calls[0] if invoke_result.message.tool_calls else None + + # deduct quota + self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage) + + return text, usage, tool_call + + def _generate_function_call_prompt(self, + node_data: ParameterExtractorNodeData, + query: str, + variable_pool: VariablePool, + model_config: ModelConfigWithCredentialsEntity, + memory: Optional[TokenBufferMemory], + ) -> tuple[list[PromptMessage], list[PromptMessageTool]]: + """ + Generate function call prompt. + """ + query = FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE.format(content=query, structure=json.dumps(node_data.get_parameter_json_schema())) + + prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) + rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, '') + prompt_template = self._get_function_calling_prompt_template(node_data, query, variable_pool, memory, rest_token) + prompt_messages = prompt_transform.get_prompt( + prompt_template=prompt_template, + inputs={}, + query='', + files=[], + context='', + memory_config=node_data.memory, + memory=None, + model_config=model_config + ) + + # find last user message + last_user_message_idx = -1 + for i, prompt_message in enumerate(prompt_messages): + if prompt_message.role == PromptMessageRole.USER: + last_user_message_idx = i + + # add function call messages before last user message + example_messages = [] + for example in FUNCTION_CALLING_EXTRACTOR_EXAMPLE: + id = uuid.uuid4().hex + example_messages.extend([ + UserPromptMessage(content=example['user']['query']), + AssistantPromptMessage( + content=example['assistant']['text'], + tool_calls=[ + AssistantPromptMessage.ToolCall( + id=id, + type='function', + function=AssistantPromptMessage.ToolCall.ToolCallFunction( + name=example['assistant']['function_call']['name'], + arguments=json.dumps(example['assistant']['function_call']['parameters'] + ) + )) + ] + ), + ToolPromptMessage( + content='Great! 
You have called the function with the correct parameters.', + tool_call_id=id + ), + AssistantPromptMessage( + content='I have extracted the parameters, let\'s move on.', + ) + ]) + + prompt_messages = prompt_messages[:last_user_message_idx] + \ + example_messages + prompt_messages[last_user_message_idx:] + + # generate tool + tool = PromptMessageTool( + name=FUNCTION_CALLING_EXTRACTOR_NAME, + description='Extract parameters from the natural language text', + parameters=node_data.get_parameter_json_schema(), + ) + + return prompt_messages, [tool] + + def _generate_prompt_engineering_prompt(self, + data: ParameterExtractorNodeData, + query: str, + variable_pool: VariablePool, + model_config: ModelConfigWithCredentialsEntity, + memory: Optional[TokenBufferMemory], + ) -> list[PromptMessage]: + """ + Generate prompt engineering prompt. + """ + model_mode = ModelMode.value_of(data.model.mode) + + if model_mode == ModelMode.COMPLETION: + return self._generate_prompt_engineering_completion_prompt( + data, query, variable_pool, model_config, memory + ) + elif model_mode == ModelMode.CHAT: + return self._generate_prompt_engineering_chat_prompt( + data, query, variable_pool, model_config, memory + ) + else: + raise ValueError(f"Invalid model mode: {model_mode}") + + def _generate_prompt_engineering_completion_prompt(self, + node_data: ParameterExtractorNodeData, + query: str, + variable_pool: VariablePool, + model_config: ModelConfigWithCredentialsEntity, + memory: Optional[TokenBufferMemory], + ) -> list[PromptMessage]: + """ + Generate completion prompt. + """ + prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) + rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, '') + prompt_template = self._get_prompt_engineering_prompt_template(node_data, query, variable_pool, memory, rest_token) + prompt_messages = prompt_transform.get_prompt( + prompt_template=prompt_template, + inputs={ + 'structure': json.dumps(node_data.get_parameter_json_schema()) + }, + query='', + files=[], + context='', + memory_config=node_data.memory, + memory=memory, + model_config=model_config + ) + + return prompt_messages + + def _generate_prompt_engineering_chat_prompt(self, + node_data: ParameterExtractorNodeData, + query: str, + variable_pool: VariablePool, + model_config: ModelConfigWithCredentialsEntity, + memory: Optional[TokenBufferMemory], + ) -> list[PromptMessage]: + """ + Generate chat prompt. 
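# --- Illustrative sketch (editorial note, not part of this patch): the few-shot
# splice used by both prompt builders above. Example turns are inserted right
# before the final user message so the real query stays last in the conversation.
# Plain (role, text) tuples stand in for prompt message objects.
def splice_examples(messages: list[tuple[str, str]],
                    examples: list[tuple[str, str]]) -> list[tuple[str, str]]:
    last_user_idx = -1
    for i, (role, _) in enumerate(messages):
        if role == 'user':
            last_user_idx = i
    return messages[:last_user_idx] + examples + messages[last_user_idx:]


messages = [('system', 'extract parameters'), ('user', 'real query')]
examples = [('user', 'What is the weather in SF?'),
            ('assistant', '{"location": "San Francisco"}')]
assert splice_examples(messages, examples)[-1] == ('user', 'real query')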
+ """ + prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) + rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, '') + prompt_template = self._get_prompt_engineering_prompt_template( + node_data, + CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE.format( + structure=json.dumps(node_data.get_parameter_json_schema()), + text=query + ), + variable_pool, memory, rest_token + ) + + prompt_messages = prompt_transform.get_prompt( + prompt_template=prompt_template, + inputs={}, + query='', + files=[], + context='', + memory_config=node_data.memory, + memory=memory, + model_config=model_config + ) + + # find last user message + last_user_message_idx = -1 + for i, prompt_message in enumerate(prompt_messages): + if prompt_message.role == PromptMessageRole.USER: + last_user_message_idx = i + + # add example messages before last user message + example_messages = [] + for example in CHAT_EXAMPLE: + example_messages.extend([ + UserPromptMessage(content=CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE.format( + structure=json.dumps(example['user']['json']), + text=example['user']['query'], + )), + AssistantPromptMessage( + content=json.dumps(example['assistant']['json']), + ) + ]) + + prompt_messages = prompt_messages[:last_user_message_idx] + \ + example_messages + prompt_messages[last_user_message_idx:] + + return prompt_messages + + def _validate_result(self, data: ParameterExtractorNodeData, result: dict) -> dict: + """ + Validate result. + """ + if len(data.parameters) != len(result): + raise ValueError("Invalid number of parameters") + + for parameter in data.parameters: + if parameter.required and parameter.name not in result: + raise ValueError(f"Parameter {parameter.name} is required") + + if parameter.type == 'select' and parameter.options and result.get(parameter.name) not in parameter.options: + raise ValueError(f"Invalid `select` value for parameter {parameter.name}") + + if parameter.type == 'number' and not isinstance(result.get(parameter.name), int | float): + raise ValueError(f"Invalid `number` value for parameter {parameter.name}") + + if parameter.type == 'bool' and not isinstance(result.get(parameter.name), bool): + raise ValueError(f"Invalid `bool` value for parameter {parameter.name}") + + if parameter.type == 'string' and not isinstance(result.get(parameter.name), str): + raise ValueError(f"Invalid `string` value for parameter {parameter.name}") + + if parameter.type.startswith('array'): + if not isinstance(result.get(parameter.name), list): + raise ValueError(f"Invalid `array` value for parameter {parameter.name}") + nested_type = parameter.type[6:-1] + for item in result.get(parameter.name): + if nested_type == 'number' and not isinstance(item, int | float): + raise ValueError(f"Invalid `array[number]` value for parameter {parameter.name}") + if nested_type == 'string' and not isinstance(item, str): + raise ValueError(f"Invalid `array[string]` value for parameter {parameter.name}") + if nested_type == 'object' and not isinstance(item, dict): + raise ValueError(f"Invalid `array[object]` value for parameter {parameter.name}") + return result + + def _transform_result(self, data: ParameterExtractorNodeData, result: dict) -> dict: + """ + Transform result into standard format. 
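# --- Illustrative sketch (editorial note, not part of this patch): the lenient
# number coercion the transform step below applies. LLMs often return numbers as
# strings, so "3.5" becomes a float and "42" an int, while unparseable values are
# left for the default-filling pass.
from typing import Any, Optional


def coerce_number(value: Any) -> Optional[int | float]:
    if isinstance(value, (int, float)):
        return value
    if isinstance(value, str):
        try:
            return float(value) if '.' in value else int(value)
        except ValueError:
            return None
    return None


assert coerce_number('3.5') == 3.5
assert coerce_number('42') == 42
assert coerce_number('n/a') is None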
+ """ + transformed_result = {} + for parameter in data.parameters: + if parameter.name in result: + # transform value + if parameter.type == 'number': + if isinstance(result[parameter.name], int | float): + transformed_result[parameter.name] = result[parameter.name] + elif isinstance(result[parameter.name], str): + try: + if '.' in result[parameter.name]: + result[parameter.name] = float(result[parameter.name]) + else: + result[parameter.name] = int(result[parameter.name]) + except ValueError: + pass + else: + pass + # TODO: bool is not supported in the current version + # elif parameter.type == 'bool': + # if isinstance(result[parameter.name], bool): + # transformed_result[parameter.name] = bool(result[parameter.name]) + # elif isinstance(result[parameter.name], str): + # if result[parameter.name].lower() in ['true', 'false']: + # transformed_result[parameter.name] = bool(result[parameter.name].lower() == 'true') + # elif isinstance(result[parameter.name], int): + # transformed_result[parameter.name] = bool(result[parameter.name]) + elif parameter.type in ['string', 'select']: + if isinstance(result[parameter.name], str): + transformed_result[parameter.name] = result[parameter.name] + elif parameter.type.startswith('array'): + if isinstance(result[parameter.name], list): + nested_type = parameter.type[6:-1] + transformed_result[parameter.name] = [] + for item in result[parameter.name]: + if nested_type == 'number': + if isinstance(item, int | float): + transformed_result[parameter.name].append(item) + elif isinstance(item, str): + try: + if '.' in item: + transformed_result[parameter.name].append(float(item)) + else: + transformed_result[parameter.name].append(int(item)) + except ValueError: + pass + elif nested_type == 'string': + if isinstance(item, str): + transformed_result[parameter.name].append(item) + elif nested_type == 'object': + if isinstance(item, dict): + transformed_result[parameter.name].append(item) + + if parameter.name not in transformed_result: + if parameter.type == 'number': + transformed_result[parameter.name] = 0 + elif parameter.type == 'bool': + transformed_result[parameter.name] = False + elif parameter.type in ['string', 'select']: + transformed_result[parameter.name] = '' + elif parameter.type.startswith('array'): + transformed_result[parameter.name] = [] + + return transformed_result + + def _extract_complete_json_response(self, result: str) -> Optional[dict]: + """ + Extract complete json response. + """ + def extract_json(text): + """ + From a given JSON started from '{' or '[' extract the complete JSON object. + """ + stack = [] + for i, c in enumerate(text): + if c == '{' or c == '[': + stack.append(c) + elif c == '}' or c == ']': + # check if stack is empty + if not stack: + return text[:i] + # check if the last element in stack is matching + if (c == '}' and stack[-1] == '{') or (c == ']' and stack[-1] == '['): + stack.pop() + if not stack: + return text[:i+1] + else: + return text[:i] + return None + + # extract json from the text + for idx in range(len(result)): + if result[idx] == '{' or result[idx] == '[': + json_str = extract_json(result[idx:]) + if json_str: + try: + return json.loads(json_str) + except Exception: + pass + + def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCall) -> Optional[dict]: + """ + Extract json from tool call. 
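# --- Illustrative sketch (editorial note, not part of this patch): why the
# balanced-bracket scan in _extract_complete_json_response matters. json.loads()
# alone fails on chatty completions, while scanning for the first balanced {...}
# region recovers the embedded object. This simplified version handles objects
# only; the node's version also handles top-level arrays.
import json
from typing import Optional


def first_json_object(text: str) -> Optional[dict]:
    depth, start = 0, None
    for i, ch in enumerate(text):
        if ch == '{':
            if depth == 0:
                start = i
            depth += 1
        elif ch == '}' and depth:
            depth -= 1
            if depth == 0:
                try:
                    return json.loads(text[start:i + 1])
                except json.JSONDecodeError:
                    start = None  # false match, keep scanning
    return None


reply = 'Sure! Here is the JSON: {"food": "apple pie"} Hope that helps.'
assert first_json_object(reply) == {'food': 'apple pie'}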
+ """ + if not tool_call or not tool_call.function.arguments: + return None + + return json.loads(tool_call.function.arguments) + + def _generate_default_result(self, data: ParameterExtractorNodeData) -> dict: + """ + Generate default result. + """ + result = {} + for parameter in data.parameters: + if parameter.type == 'number': + result[parameter.name] = 0 + elif parameter.type == 'bool': + result[parameter.name] = False + elif parameter.type in ['string', 'select']: + result[parameter.name] = '' + + return result + + def _render_instruction(self, instruction: str, variable_pool: VariablePool) -> str: + """ + Render instruction. + """ + variable_template_parser = VariableTemplateParser(instruction) + inputs = {} + for selector in variable_template_parser.extract_variable_selectors(): + inputs[selector.variable] = variable_pool.get_variable_value(selector.value_selector) + + return variable_template_parser.format(inputs) + + def _get_function_calling_prompt_template(self, node_data: ParameterExtractorNodeData, query: str, + variable_pool: VariablePool, + memory: Optional[TokenBufferMemory], + max_token_limit: int = 2000) \ + -> list[ChatModelMessage]: + model_mode = ModelMode.value_of(node_data.model.mode) + input_text = query + memory_str = '' + instruction = self._render_instruction(node_data.instruction or '', variable_pool) + + if memory: + memory_str = memory.get_history_prompt_text(max_token_limit=max_token_limit, + message_limit=node_data.memory.window.size) + if model_mode == ModelMode.CHAT: + system_prompt_messages = ChatModelMessage( + role=PromptMessageRole.SYSTEM, + text=FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT.format(histories=memory_str, instruction=instruction) + ) + user_prompt_message = ChatModelMessage( + role=PromptMessageRole.USER, + text=input_text + ) + return [system_prompt_messages, user_prompt_message] + else: + raise ValueError(f"Model mode {model_mode} not support.") + + def _get_prompt_engineering_prompt_template(self, node_data: ParameterExtractorNodeData, query: str, + variable_pool: VariablePool, + memory: Optional[TokenBufferMemory], + max_token_limit: int = 2000) \ + -> list[ChatModelMessage]: + + model_mode = ModelMode.value_of(node_data.model.mode) + input_text = query + memory_str = '' + instruction = self._render_instruction(node_data.instruction or '', variable_pool) + + if memory: + memory_str = memory.get_history_prompt_text(max_token_limit=max_token_limit, + message_limit=node_data.memory.window.size) + if model_mode == ModelMode.CHAT: + system_prompt_messages = ChatModelMessage( + role=PromptMessageRole.SYSTEM, + text=FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT.format(histories=memory_str, instruction=instruction) + ) + user_prompt_message = ChatModelMessage( + role=PromptMessageRole.USER, + text=input_text + ) + return [system_prompt_messages, user_prompt_message] + elif model_mode == ModelMode.COMPLETION: + return CompletionModelPromptTemplate( + text=COMPLETION_GENERATE_JSON_PROMPT.format(histories=memory_str, + text=input_text, + instruction=instruction) + .replace('{γγγ', '') + .replace('}γγγ', '') + ) + else: + raise ValueError(f"Model mode {model_mode} not support.") + + def _calculate_rest_token(self, node_data: ParameterExtractorNodeData, query: str, + variable_pool: VariablePool, + model_config: ModelConfigWithCredentialsEntity, + context: Optional[str]) -> int: + prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) + + model_instance, model_config = self._fetch_model_config(node_data.model) + if not 
isinstance(model_instance.model_type_instance, LargeLanguageModel):
+            raise ValueError("Model is not a Large Language Model")
+
+        llm_model = model_instance.model_type_instance
+        model_schema = llm_model.get_model_schema(model_config.model, model_config.credentials)
+        if not model_schema:
+            raise ValueError("Model schema not found")
+
+        # deduplicated: the original listed ModelFeature.MULTI_TOOL_CALL twice,
+        # which reads like a typo for the TOOL_CALL / MULTI_TOOL_CALL pair
+        if set(model_schema.features or []) & set([ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL]):
+            prompt_template = self._get_function_calling_prompt_template(node_data, query, variable_pool, None, 2000)
+        else:
+            prompt_template = self._get_prompt_engineering_prompt_template(node_data, query, variable_pool, None, 2000)
+
+        prompt_messages = prompt_transform.get_prompt(
+            prompt_template=prompt_template,
+            inputs={},
+            query='',
+            files=[],
+            context=context,
+            memory_config=node_data.memory,
+            memory=None,
+            model_config=model_config
+        )
+        rest_tokens = 2000
+
+        model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
+        if model_context_tokens:
+            model_type_instance = model_config.provider_model_bundle.model_type_instance
+            model_type_instance = cast(LargeLanguageModel, model_type_instance)
+
+            curr_message_tokens = model_type_instance.get_num_tokens(
+                model_config.model,
+                model_config.credentials,
+                prompt_messages
+            ) + 1000  # add 1000 to ensure tool call messages
+
+            max_tokens = 0
+            for parameter_rule in model_config.model_schema.parameter_rules:
+                if (parameter_rule.name == 'max_tokens'
+                        or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
+                    max_tokens = (model_config.parameters.get(parameter_rule.name)
+                                  or model_config.parameters.get(parameter_rule.use_template)) or 0
+
+            rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
+            rest_tokens = max(rest_tokens, 0)
+
+        return rest_tokens
+
+    def _fetch_model_config(self, node_data_model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
+        """
+        Fetch model config.
+        """
+        if not self._model_instance or not self._model_config:
+            self._model_instance, self._model_config = super()._fetch_model_config(node_data_model)
+
+        return self._model_instance, self._model_config
+
+    @classmethod
+    def _extract_variable_selector_to_variable_mapping(cls, node_data: ParameterExtractorNodeData) -> dict[str, list[str]]:
+        """
+        Extract variable selector to variable mapping
+        :param node_data: node data
+        :return:
+        """
+        variable_mapping = {
+            'query': node_data.query
+        }
+
+        if node_data.instruction:
+            variable_template_parser = VariableTemplateParser(template=node_data.instruction)
+            for selector in variable_template_parser.extract_variable_selectors():
+                variable_mapping[selector.variable] = selector.value_selector
+
+        return variable_mapping
\ No newline at end of file
diff --git a/api/core/workflow/nodes/parameter_extractor/prompts.py b/api/core/workflow/nodes/parameter_extractor/prompts.py
new file mode 100644
index 0000000000..499c58d505
--- /dev/null
+++ b/api/core/workflow/nodes/parameter_extractor/prompts.py
@@ -0,0 +1,206 @@
+FUNCTION_CALLING_EXTRACTOR_NAME = 'extract_parameters'
+
+FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT = f"""You are a helpful assistant tasked with extracting structured information based on specific criteria provided. Follow the guidelines below to ensure consistency and accuracy.
+### Task
+Always call the `{FUNCTION_CALLING_EXTRACTOR_NAME}` function with the correct parameters.
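# --- Illustrative sketch (editorial note, not part of this patch): the token
# budgeting done by _calculate_rest_token above. The room left for history and
# examples is the context size minus the reserved completion budget (max_tokens)
# and the tokens already spent on the prompt, floored at zero. The 1000-token
# headroom for tool-call messages mirrors the code; the numbers are made up.
def rest_tokens(context_size: int, max_tokens: int, prompt_tokens: int) -> int:
    headroom = 1000  # extra room the node reserves for tool-call messages
    return max(context_size - max_tokens - (prompt_tokens + headroom), 0)


assert rest_tokens(context_size=8192, max_tokens=512, prompt_tokens=1200) == 5480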
Ensure that the information extraction is contextual and aligns with the provided criteria. +### Memory +Here is the chat history between the human and assistant, provided within tags: + +\x7bhistories\x7d + +### Instructions: +Some additional information is provided below. Always adhere to these instructions as closely as possible: + +\x7binstruction\x7d + +Steps: +1. Review the chat history provided within the tags. +2. Extract the relevant information based on the criteria given, output multiple values if there is multiple relevant information that match the criteria in the given text. +3. Generate a well-formatted output using the defined functions and arguments. +4. Use the `extract_parameter` function to create structured outputs with appropriate parameters. +5. Do not include any XML tags in your output. +### Example +To illustrate, if the task involves extracting a user's name and their request, your function call might look like this: Ensure your output follows a similar structure to examples. +### Final Output +Produce well-formatted function calls in json without XML tags, as shown in the example. +""" + +FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE = f"""extract structured information from context inside XML tags by calling the function {FUNCTION_CALLING_EXTRACTOR_NAME} with the correct parameters with structure inside XML tags. + +\x7bcontent\x7d + + + +\x7bstructure\x7d + +""" + +FUNCTION_CALLING_EXTRACTOR_EXAMPLE = [{ + 'user': { + 'query': 'What is the weather today in SF?', + 'function': { + 'name': FUNCTION_CALLING_EXTRACTOR_NAME, + 'parameters': { + 'type': 'object', + 'properties': { + 'location': { + 'type': 'string', + 'description': 'The location to get the weather information', + 'required': True + }, + }, + 'required': ['location'] + } + } + }, + 'assistant': { + 'text': 'I need always call the function with the correct parameters. in this case, I need to call the function with the location parameter.', + 'function_call' : { + 'name': FUNCTION_CALLING_EXTRACTOR_NAME, + 'parameters': { + 'location': 'San Francisco' + } + } + } +}, { + 'user': { + 'query': 'I want to eat some apple pie.', + 'function': { + 'name': FUNCTION_CALLING_EXTRACTOR_NAME, + 'parameters': { + 'type': 'object', + 'properties': { + 'food': { + 'type': 'string', + 'description': 'The food to eat', + 'required': True + } + }, + 'required': ['food'] + } + } + }, + 'assistant': { + 'text': 'I need always call the function with the correct parameters. in this case, I need to call the function with the food parameter.', + 'function_call' : { + 'name': FUNCTION_CALLING_EXTRACTOR_NAME, + 'parameters': { + 'food': 'apple pie' + } + } + } +}] + +COMPLETION_GENERATE_JSON_PROMPT = """### Instructions: +Some extra information are provided below, I should always follow the instructions as possible as I can. + +{instruction} + + +### Extract parameter Workflow +I need to extract the following information from the input text. The tag specifies the 'type', 'description' and 'required' of the information to be extracted. + +{{ structure }} + + +Step 1: Carefully read the input and understand the structure of the expected output. +Step 2: Extract relevant parameters from the provided text based on the name and description of object. +Step 3: Structure the extracted parameters to JSON object as specified in . +Step 4: Ensure that the JSON object is properly formatted and valid. The output should not contain any XML tags. Only the JSON object should be outputted. 
+ +### Memory +Here is the chat histories between human and assistant, inside XML tags. + +{histories} + + +### Structure +Here is the structure of the expected output, I should always follow the output structure. +{{γγγ + 'properties1': 'relevant text extracted from input', + 'properties2': 'relevant text extracted from input', +}}γγγ + +### Input Text +Inside XML tags, there is a text that I should extract parameters and convert to a JSON object. + +{text} + + +### Answer +I should always output a valid JSON object. Output nothing other than the JSON object. +```JSON +""" + +CHAT_GENERATE_JSON_PROMPT = """You should always follow the instructions and output a valid JSON object. +The structure of the JSON object you can found in the instructions. + +### Memory +Here is the chat histories between human and assistant, inside XML tags. + +{histories} + + +### Instructions: +Some extra information are provided below, you should always follow the instructions as possible as you can. + +{{instructions}} + +""" + +CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE = """### Structure +Here is the structure of the JSON object, you should always follow the structure. + +{structure} + + +### Text to be converted to JSON +Inside XML tags, there is a text that you should convert to a JSON object. + +{text} + +""" + +CHAT_EXAMPLE = [{ + 'user': { + 'query': 'What is the weather today in SF?', + 'json': { + 'type': 'object', + 'properties': { + 'location': { + 'type': 'string', + 'description': 'The location to get the weather information', + 'required': True + } + }, + 'required': ['location'] + } + }, + 'assistant': { + 'text': 'I need to output a valid JSON object.', + 'json': { + 'location': 'San Francisco' + } + } +}, { + 'user': { + 'query': 'I want to eat some apple pie.', + 'json': { + 'type': 'object', + 'properties': { + 'food': { + 'type': 'string', + 'description': 'The food to eat', + 'required': True + } + }, + 'required': ['food'] + } + }, + 'assistant': { + 'text': 'I need to output a valid JSON object.', + 'json': { + 'result': 'apple pie' + } + } +}] \ No newline at end of file diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index 9ec0df721c..76f3dec836 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -1,16 +1,18 @@ +import json import logging from typing import Optional, Union, cast from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.memory.token_buffer_memory import TokenBufferMemory +from core.model_manager import ModelInstance from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole from core.model_runtime.entities.model_entities import ModelPropertyKey -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.utils.encoders import jsonable_encoder from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate from core.prompt.simple_prompt_transform import ModelMode from core.prompt.utils.prompt_message_util import PromptMessageUtil +from core.prompt.utils.prompt_template_parser import PromptTemplateParser from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities 
import NodeRunMetadataKey, NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool @@ -25,6 +27,7 @@ from core.workflow.nodes.question_classifier.template_prompts import ( QUESTION_CLASSIFIER_USER_PROMPT_2, QUESTION_CLASSIFIER_USER_PROMPT_3, ) +from core.workflow.utils.variable_template_parser import VariableTemplateParser from libs.json_in_md_parser import parse_and_check_json_markdown from models.workflow import WorkflowNodeExecutionStatus @@ -46,6 +49,9 @@ class QuestionClassifierNode(LLMNode): model_instance, model_config = self._fetch_model_config(node_data.model) # fetch memory memory = self._fetch_memory(node_data.memory, variable_pool, model_instance) + # fetch instruction + instruction = self._format_instruction(node_data.instruction, variable_pool) if node_data.instruction else '' + node_data.instruction = instruction # fetch prompt messages prompt_messages, stop = self._fetch_prompt( node_data=node_data, @@ -62,13 +68,20 @@ class QuestionClassifierNode(LLMNode): prompt_messages=prompt_messages, stop=stop ) - categories = [_class.name for _class in node_data.classes] + category_name = node_data.classes[0].name + category_id = node_data.classes[0].id try: result_text_json = parse_and_check_json_markdown(result_text, []) - #result_text_json = json.loads(result_text.strip('```JSON\n')) - categories_result = result_text_json.get('categories', []) - if categories_result: - categories = categories_result + # result_text_json = json.loads(result_text.strip('```JSON\n')) + if 'category_name' in result_text_json and 'category_id' in result_text_json: + category_id_result = result_text_json['category_id'] + classes = node_data.classes + classes_map = {class_.id: class_.name for class_ in classes} + category_ids = [_class.id for _class in classes] + if category_id_result in category_ids: + category_name = classes_map[category_id_result] + category_id = category_id_result + except Exception: logging.error(f"Failed to parse result text: {result_text}") try: @@ -81,17 +94,15 @@ class QuestionClassifierNode(LLMNode): 'usage': jsonable_encoder(usage), } outputs = { - 'class_name': categories[0] if categories else '' + 'class_name': category_name } - classes = node_data.classes - classes_map = {class_.name: class_.id for class_ in classes} return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=process_data, outputs=outputs, - edge_source_handle=classes_map.get(categories[0], None), + edge_source_handle=category_id, metadata={ NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens, NodeRunMetadataKey.TOTAL_PRICE: usage.total_price, @@ -116,6 +127,12 @@ class QuestionClassifierNode(LLMNode): node_data = node_data node_data = cast(cls._node_data_cls, node_data) variable_mapping = {'query': node_data.query_variable_selector} + variable_selectors = [] + if node_data.instruction: + variable_template_parser = VariableTemplateParser(template=node_data.instruction) + variable_selectors.extend(variable_template_parser.extract_variable_selectors()) + for variable_selector in variable_selectors: + variable_mapping[variable_selector.variable] = variable_selector.value_selector return variable_mapping @classmethod @@ -183,12 +200,12 @@ class QuestionClassifierNode(LLMNode): model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) if model_context_tokens: - model_type_instance = model_config.provider_model_bundle.model_type_instance - model_type_instance = cast(LargeLanguageModel, 
model_type_instance) + model_instance = ModelInstance( + provider_model_bundle=model_config.provider_model_bundle, + model=model_config.model + ) - curr_message_tokens = model_type_instance.get_num_tokens( - model_config.model, - model_config.credentials, + curr_message_tokens = model_instance.get_llm_num_tokens( prompt_messages ) @@ -210,8 +227,13 @@ class QuestionClassifierNode(LLMNode): -> Union[list[ChatModelMessage], CompletionModelPromptTemplate]: model_mode = ModelMode.value_of(node_data.model.mode) classes = node_data.classes - class_names = [class_.name for class_ in classes] - class_names_str = ','.join(f'"{name}"' for name in class_names) + categories = [] + for class_ in classes: + category = { + 'category_id': class_.id, + 'category_name': class_.name + } + categories.append(category) instruction = node_data.instruction if node_data.instruction else '' input_text = query memory_str = '' @@ -248,7 +270,7 @@ class QuestionClassifierNode(LLMNode): user_prompt_message_3 = ChatModelMessage( role=PromptMessageRole.USER, text=QUESTION_CLASSIFIER_USER_PROMPT_3.format(input_text=input_text, - categories=json.dumps(categories, ensure_ascii=False), classification_instructions=instruction) ) prompt_messages.append(user_prompt_message_3) @@ -257,9 +279,30 @@ class QuestionClassifierNode(LLMNode): return CompletionModelPromptTemplate( text=QUESTION_CLASSIFIER_COMPLETION_PROMPT.format(histories=memory_str, input_text=input_text, - categories=class_names_str, - classification_instructions=instruction) + categories=json.dumps(categories, ensure_ascii=False), + classification_instructions=instruction) ) else: raise ValueError(f"Model mode {model_mode} not supported.") + + def _format_instruction(self, instruction: str, variable_pool: VariablePool) -> str: + inputs = {} + + variable_selectors = [] + variable_template_parser = VariableTemplateParser(template=instruction) + variable_selectors.extend(variable_template_parser.extract_variable_selectors()) + for variable_selector in variable_selectors: + variable_value = variable_pool.get_variable_value(variable_selector.value_selector) + if variable_value is None: + raise ValueError(f'Variable {variable_selector.variable} not found') + + inputs[variable_selector.variable] = variable_value + + prompt_template = PromptTemplateParser(template=instruction, with_variable_tmpl=True) + prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} + + instruction = prompt_template.format( + prompt_inputs + ) + return instruction diff --git a/api/core/workflow/nodes/question_classifier/template_prompts.py b/api/core/workflow/nodes/question_classifier/template_prompts.py index 5bef0250e3..23bd05a809 100644 --- a/api/core/workflow/nodes/question_classifier/template_prompts.py +++ b/api/core/workflow/nodes/question_classifier/template_prompts.py @@ -6,7 +6,7 @@ QUESTION_CLASSIFIER_SYSTEM_PROMPT = """ ### Task Your task is to assign one category ONLY to the input text, and only one category may be returned in the output. Additionally, you need to extract the key words from the text that are related to the classification. ### Format - The input text is in the variable text_field.Categories are specified as a comma-separated list in the variable categories or left empty for automatic determination.Classification instructions may be included to improve the classification accuracy.
+ The input text is in the variable text_field. Categories are specified as a list of objects with two fields, category_id and category_name, in the variable categories. Classification instructions may be included to improve the classification accuracy. ### Constraint DO NOT include anything other than the JSON array in your response. ### Memory @@ -18,33 +18,35 @@ QUESTION_CLASSIFIER_SYSTEM_PROMPT = """ QUESTION_CLASSIFIER_USER_PROMPT_1 = """ { "input_text": ["I recently had a great experience with your company. The service was prompt and the staff was very friendly."], - "categories": ["Customer Service", "Satisfaction", "Sales", "Product"], + "categories": [{"category_id":"f5660049-284f-41a7-b301-fd24176a711c","category_name":"Customer Service"},{"category_id":"8d007d06-f2c9-4be5-8ff6-cd4381c13c60","category_name":"Satisfaction"},{"category_id":"5fbbbb18-9843-466d-9b8e-b9bfbb9482c8","category_name":"Sales"},{"category_id":"23623c75-7184-4a2e-8226-466c2e4631e4","category_name":"Product"}], "classification_instructions": ["classify the text based on the feedback provided by customer"]} """ QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1 = """ ```json {"keywords": ["recently", "great experience", "company", "service", "prompt", "staff", "friendly"], - "categories": ["Customer Service"]} + "category_id": "f5660049-284f-41a7-b301-fd24176a711c", + "category_name": "Customer Service"} ``` """ QUESTION_CLASSIFIER_USER_PROMPT_2 = """ {"input_text": ["bad service, slow to bring the food"], - "categories": ["Food Quality", "Experience", "Price" ], + "categories": [{"category_id":"80fb86a0-4454-4bf5-924c-f253fdd83c02","category_name":"Food Quality"},{"category_id":"f6ff5bc3-aca0-4e4a-8627-e760d0aca78f","category_name":"Experience"},{"category_id":"cc771f63-74e7-4c61-882e-3eda9d8ba5d7","category_name":"Price"}], "classification_instructions": []} """ QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2 = """ ```json {"keywords": ["bad service", "slow", "food", "tip", "terrible", "waitresses"], - "categories": ["Experience"]} + "category_id": "f6ff5bc3-aca0-4e4a-8627-e760d0aca78f", + "category_name": "Experience"} ``` """ QUESTION_CLASSIFIER_USER_PROMPT_3 = """ '{{"input_text": ["{input_text}"],', - '"categories": ["{categories}" ], ', + '"categories": {categories}, ', '"classification_instructions": ["{classification_instructions}"]}}' """ @@ -54,16 +56,16 @@ You are a text classification engine that analyzes text data and assigns categor ### Task Your task is to assign one category ONLY to the input text, and only one category may be returned in the output. Additionally, you need to extract the key words from the text that are related to the classification. ### Format -The input text is in the variable text_field. Categories are specified as a comma-separated list in the variable categories or left empty for automatic determination. Classification instructions may be included to improve the classification accuracy. +The input text is in the variable text_field. Categories are specified as a list of objects with two fields, category_id and category_name, in the variable categories. Classification instructions may be included to improve the classification accuracy. ### Constraint DO NOT include anything other than the JSON array in your response. ### Example Here is the chat example between human and assistant, inside XML tags. -User:{{"input_text": ["I recently had a great experience with your company.
The service was prompt and the staff was very friendly."],"categories": ["Customer Service, Satisfaction, Sales, Product"], "classification_instructions": ["classify the text based on the feedback provided by customer"]}} -Assistant:{{"keywords": ["recently", "great experience", "company", "service", "prompt", "staff", "friendly"],"categories": ["Customer Service"]}} -User:{{"input_text": ["bad service, slow to bring the food"],"categories": ["Food Quality, Experience, Price" ], "classification_instructions": []}} -Assistant:{{"keywords": ["recently", "great experience", "company", "service", "prompt", "staff", "friendly"],"categories": ["Customer Service"]}}{{"keywords": ["bad service", "slow", "food", "tip", "terrible", "waitresses"],"categories": ["Experience""]}} +User:{{"input_text": ["I recently had a great experience with your company. The service was prompt and the staff was very friendly."], "categories": [{{"category_id":"f5660049-284f-41a7-b301-fd24176a711c","category_name":"Customer Service"}},{{"category_id":"8d007d06-f2c9-4be5-8ff6-cd4381c13c60","category_name":"Satisfaction"}},{{"category_id":"5fbbbb18-9843-466d-9b8e-b9bfbb9482c8","category_name":"Sales"}},{{"category_id":"23623c75-7184-4a2e-8226-466c2e4631e4","category_name":"Product"}}], "classification_instructions": ["classify the text based on the feedback provided by customer"]}} +Assistant:{{"keywords": ["recently", "great experience", "company", "service", "prompt", "staff", "friendly"],"category_id": "f5660049-284f-41a7-b301-fd24176a711c","category_name": "Customer Service"}} +User:{{"input_text": ["bad service, slow to bring the food"], "categories": [{{"category_id":"80fb86a0-4454-4bf5-924c-f253fdd83c02","category_name":"Food Quality"}},{{"category_id":"f6ff5bc3-aca0-4e4a-8627-e760d0aca78f","category_name":"Experience"}},{{"category_id":"cc771f63-74e7-4c61-882e-3eda9d8ba5d7","category_name":"Price"}}], "classification_instructions": []}} +Assistant:{{"keywords": ["bad service", "slow", "food", "tip", "terrible", "waitresses"],"category_id": "f6ff5bc3-aca0-4e4a-8627-e760d0aca78f","category_name": "Experience"}} ### Memory Here is the chat histories between human and assistant, inside XML tags. 
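
The prompt hunks above change the classifier's category contract from bare name strings to objects carrying `category_id` and `category_name`, so the model's answer can be matched back to a workflow edge by id. A minimal sketch of that round trip, using illustrative data rather than the project's real node classes:

```python
import json

# Classes as configured on a hypothetical classifier node; the ids are what
# the outgoing workflow edges are keyed on.
classes = [
    {"category_id": "f5660049-284f-41a7-b301-fd24176a711c", "category_name": "Customer Service"},
    {"category_id": "8d007d06-f2c9-4be5-8ff6-cd4381c13c60", "category_name": "Satisfaction"},
]

# Serialized the way the updated templates expect for {categories}.
categories_json = json.dumps(classes, ensure_ascii=False)
assert '"category_id"' in categories_json

# A well-formed model answer under the new contract.
answer = json.loads(
    '{"keywords": ["great", "service"],'
    ' "category_id": "f5660049-284f-41a7-b301-fd24176a711c",'
    ' "category_name": "Customer Service"}'
)

# Default to the first class, then overwrite only if the returned id is
# recognized; this mirrors the fallback in the reworked node logic.
id_to_name = {c["category_id"]: c["category_name"] for c in classes}
category_id = classes[0]["category_id"]
category_name = classes[0]["category_name"]
if answer.get("category_id") in id_to_name:
    category_id = answer["category_id"]
    category_name = id_to_name[category_id]

assert (category_id, category_name) == (classes[0]["category_id"], "Customer Service")
```

Matching on ids instead of display names keeps the routing stable when a category is renamed, since the edge binding no longer depends on the label text.
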
@@ -71,6 +73,6 @@ Here is the chat histories between human and assistant, inside ### User Input -{{"input_text" : ["{input_text}"], "categories" : ["{categories}"],"classification_instruction" : ["{classification_instructions}"]}} +{{"input_text" : ["{input_text}"], "categories" : {categories},"classification_instruction" : ["{classification_instructions}"]}} ### Assistant Output -""" \ No newline at end of file +""" diff --git a/api/core/workflow/nodes/template_transform/template_transform_node.py b/api/core/workflow/nodes/template_transform/template_transform_node.py index 9e5cc0c889..2c4a2257f5 100644 --- a/api/core/workflow/nodes/template_transform/template_transform_node.py +++ b/api/core/workflow/nodes/template_transform/template_transform_node.py @@ -1,7 +1,7 @@ import os from typing import Optional, cast -from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor +from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor, CodeLanguage from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode @@ -53,7 +53,7 @@ class TemplateTransformNode(BaseNode): # Run code try: result = CodeExecutor.execute_workflow_code_template( - language='jinja2', + language=CodeLanguage.JINJA2, code=node_data.template, inputs=variables ) diff --git a/api/core/workflow/nodes/tool/entities.py b/api/core/workflow/nodes/tool/entities.py index 97fbe8a999..98b28ac4f1 100644 --- a/api/core/workflow/nodes/tool/entities.py +++ b/api/core/workflow/nodes/tool/entities.py @@ -7,7 +7,7 @@ from core.workflow.entities.base_node_data_entities import BaseNodeData class ToolEntity(BaseModel): provider_id: str - provider_type: Literal['builtin', 'api'] + provider_type: Literal['builtin', 'api', 'workflow'] provider_name: str # redundancy tool_name: str tool_label: str # redundancy diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index d183dbe17b..2a472fc8d2 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -1,13 +1,14 @@ from os import path -from typing import cast +from typing import Optional, cast from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler from core.file.file_obj import FileTransferMethod, FileType, FileVar -from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter +from core.tools.tool.tool import Tool from core.tools.tool_engine import ToolEngine from core.tools.tool_manager import ToolManager from core.tools.utils.message_transformer import ToolFileMessageTransformer -from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType +from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.tool.entities import ToolNodeData @@ -35,21 +36,23 @@ class ToolNode(BaseNode): 'provider_id': node_data.provider_id } - # get parameters - parameters = self._generate_parameters(variable_pool, node_data) # get tool runtime try: - self.app_id - tool_runtime = ToolManager.get_workflow_tool_runtime(self.tenant_id, self.app_id, self.node_id, node_data) + tool_runtime = 
ToolManager.get_workflow_tool_runtime( + self.tenant_id, self.app_id, self.node_id, node_data, self.invoke_from + ) except Exception as e: return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, - inputs=parameters, + inputs={}, metadata={ NodeRunMetadataKey.TOOL_INFO: tool_info }, error=f'Failed to get tool runtime: {str(e)}' ) + + # get parameters + parameters = self._generate_parameters(variable_pool, node_data, tool_runtime) try: messages = ToolEngine.workflow_invoke( @@ -57,7 +60,8 @@ class ToolNode(BaseNode): tool_parameters=parameters, user_id=self.user_id, workflow_id=self.workflow_id, - workflow_tool_callback=DifyWorkflowCallbackHandler() + workflow_tool_callback=DifyWorkflowCallbackHandler(), + workflow_call_depth=self.workflow_call_depth, ) except Exception as e: return NodeRunResult( @@ -84,19 +88,32 @@ class ToolNode(BaseNode): inputs=parameters ) - def _generate_parameters(self, variable_pool: VariablePool, node_data: ToolNodeData) -> dict: + def _generate_parameters(self, variable_pool: VariablePool, node_data: ToolNodeData, tool_runtime: Tool) -> dict: """ Generate parameters """ + tool_parameters = tool_runtime.get_all_runtime_parameters() + + def fetch_parameter(name: str) -> Optional[ToolParameter]: + return next((parameter for parameter in tool_parameters if parameter.name == name), None) + result = {} for parameter_name in node_data.tool_parameters: - input = node_data.tool_parameters[parameter_name] - if input.type == 'mixed': - result[parameter_name] = self._format_variable_template(input.value, variable_pool) - elif input.type == 'variable': - result[parameter_name] = variable_pool.get_variable_value(input.value) - elif input.type == 'constant': - result[parameter_name] = input.value + parameter = fetch_parameter(parameter_name) + if not parameter: + continue + if parameter.type == ToolParameter.ToolParameterType.FILE: + result[parameter_name] = [ + v.to_dict() for v in self._fetch_files(variable_pool) + ] + else: + input = node_data.tool_parameters[parameter_name] + if input.type == 'mixed': + result[parameter_name] = self._format_variable_template(input.value, variable_pool) + elif input.type == 'variable': + result[parameter_name] = variable_pool.get_variable_value(input.value) + elif input.type == 'constant': + result[parameter_name] = input.value return result @@ -110,6 +127,13 @@ class ToolNode(BaseNode): inputs[selector.variable] = variable_pool.get_variable_value(selector.value_selector) return template_parser.format(inputs) + + def _fetch_files(self, variable_pool: VariablePool) -> list[FileVar]: + files = variable_pool.get_variable_value(['sys', SystemVariable.FILES.value]) + if not files: + return [] + + return files def _convert_tool_messages(self, messages: list[ToolInvokeMessage]) -> tuple[str, list[FileVar]]: """ @@ -174,11 +198,12 @@ class ToolNode(BaseNode): """ Extract tool response text """ - return ''.join([ - f'{message.message}\n' if message.type == ToolInvokeMessage.MessageType.TEXT else - f'Link: {message.message}\n' if message.type == ToolInvokeMessage.MessageType.LINK else '' + return '\n'.join([ + f'{message.message}' if message.type == ToolInvokeMessage.MessageType.TEXT else + f'Link: {message.message}' if message.type == ToolInvokeMessage.MessageType.LINK else '' for message in tool_response ]) + @classmethod def _extract_variable_selector_to_variable_mapping(cls, node_data: ToolNodeData) -> dict[str, list[str]]: diff --git a/api/core/workflow/nodes/variable_aggregator/__init__.py 
b/api/core/workflow/nodes/variable_aggregator/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/workflow/nodes/variable_aggregator/entities.py b/api/core/workflow/nodes/variable_aggregator/entities.py new file mode 100644 index 0000000000..d38e934451 --- /dev/null +++ b/api/core/workflow/nodes/variable_aggregator/entities.py @@ -0,0 +1,33 @@ + + +from typing import Literal, Optional + +from pydantic import BaseModel + +from core.workflow.entities.base_node_data_entities import BaseNodeData + + +class AdvancedSettings(BaseModel): + """ + Advanced settings. + """ + group_enabled: bool + + class Group(BaseModel): + """ + Group. + """ + output_type: Literal['string', 'number', 'array', 'object'] + variables: list[list[str]] + group_name: str + + groups: list[Group] + +class VariableAssignerNodeData(BaseNodeData): + """ + Variable Aggregator Node Data. + """ + type: str = 'variable-assigner' + output_type: str + variables: list[list[str]] + advanced_settings: Optional[AdvancedSettings] \ No newline at end of file diff --git a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py new file mode 100644 index 0000000000..63ce790625 --- /dev/null +++ b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py @@ -0,0 +1,54 @@ +from typing import cast + +from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.entities.node_entities import NodeRunResult, NodeType +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.nodes.base_node import BaseNode +from core.workflow.nodes.variable_aggregator.entities import VariableAssignerNodeData +from models.workflow import WorkflowNodeExecutionStatus + + +class VariableAggregatorNode(BaseNode): + _node_data_cls = VariableAssignerNodeData + _node_type = NodeType.VARIABLE_AGGREGATOR + + def _run(self, variable_pool: VariablePool) -> NodeRunResult: + node_data = cast(VariableAssignerNodeData, self.node_data) + # Get variables + outputs = {} + inputs = {} + + if not node_data.advanced_settings or not node_data.advanced_settings.group_enabled: + for variable in node_data.variables: + value = variable_pool.get_variable_value(variable) + + if value is not None: + outputs = { + "output": value + } + + inputs = { + '.'.join(variable[1:]): value + } + break + else: + for group in node_data.advanced_settings.groups: + for variable in group.variables: + value = variable_pool.get_variable_value(variable) + + if value is not None: + outputs[group.group_name] = { + 'output': value + } + inputs['.'.join(variable[1:])] = value + break + + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + outputs=outputs, + inputs=inputs + ) + + @classmethod + def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: + return {} diff --git a/api/core/workflow/nodes/variable_assigner/entities.py b/api/core/workflow/nodes/variable_assigner/entities.py deleted file mode 100644 index 035618bd66..0000000000 --- a/api/core/workflow/nodes/variable_assigner/entities.py +++ /dev/null @@ -1,12 +0,0 @@ - - -from core.workflow.entities.base_node_data_entities import BaseNodeData - - -class VariableAssignerNodeData(BaseNodeData): - """ - Knowledge retrieval Node Data.
- """ - type: str = 'variable-assigner' - output_type: str - variables: list[list[str]] diff --git a/api/core/workflow/nodes/variable_assigner/variable_assigner_node.py b/api/core/workflow/nodes/variable_assigner/variable_assigner_node.py deleted file mode 100644 index d0a1c9789c..0000000000 --- a/api/core/workflow/nodes/variable_assigner/variable_assigner_node.py +++ /dev/null @@ -1,41 +0,0 @@ -from typing import cast - -from core.workflow.entities.base_node_data_entities import BaseNodeData -from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.entities.variable_pool import VariablePool -from core.workflow.nodes.base_node import BaseNode -from core.workflow.nodes.variable_assigner.entities import VariableAssignerNodeData -from models.workflow import WorkflowNodeExecutionStatus - - -class VariableAssignerNode(BaseNode): - _node_data_cls = VariableAssignerNodeData - _node_type = NodeType.VARIABLE_ASSIGNER - - def _run(self, variable_pool: VariablePool) -> NodeRunResult: - node_data: VariableAssignerNodeData = cast(self._node_data_cls, self.node_data) - # Get variables - outputs = {} - inputs = {} - for variable in node_data.variables: - value = variable_pool.get_variable_value(variable) - - if value is not None: - outputs = { - "output": value - } - - inputs = { - '.'.join(variable[1:]): value - } - break - - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs=outputs, - inputs=inputs - ) - - @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: - return {} diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py index 9390ffa2a4..afd19abfde 100644 --- a/api/core/workflow/workflow_engine_manager.py +++ b/api/core/workflow/workflow_engine_manager.py @@ -2,8 +2,11 @@ import logging import time from typing import Optional, cast +from flask import current_app + from core.app.app_config.entities import FileExtraConfig from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException +from core.app.entities.app_invoke_entities import InvokeFrom from core.file.file_obj import FileTransferMethod, FileType, FileVar from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType @@ -11,19 +14,22 @@ from core.workflow.entities.variable_pool import VariablePool, VariableValue from core.workflow.entities.workflow_entities import WorkflowNodeAndResult, WorkflowRunState from core.workflow.errors import WorkflowNodeRunFailedError from core.workflow.nodes.answer.answer_node import AnswerNode -from core.workflow.nodes.base_node import BaseNode, UserFrom +from core.workflow.nodes.base_node import BaseIterationNode, BaseNode, UserFrom from core.workflow.nodes.code.code_node import CodeNode from core.workflow.nodes.end.end_node import EndNode from core.workflow.nodes.http_request.http_request_node import HttpRequestNode from core.workflow.nodes.if_else.if_else_node import IfElseNode +from core.workflow.nodes.iteration.entities import IterationState +from core.workflow.nodes.iteration.iteration_node import IterationNode from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode from core.workflow.nodes.llm.entities import LLMNodeData from core.workflow.nodes.llm.llm_node import LLMNode +from core.workflow.nodes.parameter_extractor.parameter_extractor_node import 
ParameterExtractorNode from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode from core.workflow.nodes.start.start_node import StartNode from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode from core.workflow.nodes.tool.tool_node import ToolNode -from core.workflow.nodes.variable_assigner.variable_assigner_node import VariableAssignerNode +from core.workflow.nodes.variable_aggregator.variable_aggregator_node import VariableAggregatorNode from extensions.ext_database import db from models.workflow import ( Workflow, @@ -42,7 +48,10 @@ node_classes = { NodeType.QUESTION_CLASSIFIER: QuestionClassifierNode, NodeType.HTTP_REQUEST: HttpRequestNode, NodeType.TOOL: ToolNode, - NodeType.VARIABLE_ASSIGNER: VariableAssignerNode, + NodeType.VARIABLE_AGGREGATOR: VariableAggregatorNode, + NodeType.VARIABLE_ASSIGNER: VariableAggregatorNode, + NodeType.ITERATION: IterationNode, + NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode } logger = logging.getLogger(__name__) @@ -81,18 +90,20 @@ class WorkflowEngineManager: def run_workflow(self, workflow: Workflow, user_id: str, user_from: UserFrom, + invoke_from: InvokeFrom, user_inputs: dict, system_inputs: Optional[dict] = None, - callbacks: list[BaseWorkflowCallback] = None) -> None: + callbacks: list[BaseWorkflowCallback] = None, + call_depth: Optional[int] = 0, + variable_pool: Optional[VariablePool] = None) -> None: """ - Run workflow :param workflow: Workflow instance :param user_id: user id :param user_from: user from :param user_inputs: user variables inputs :param system_inputs: system inputs, like: query, files :param callbacks: workflow callbacks - :return: + :param call_depth: call depth """ # fetch workflow graph graph = workflow.graph_dict @@ -107,54 +118,185 @@ class WorkflowEngineManager: if not isinstance(graph.get('edges'), list): raise ValueError('edges in workflow graph must be a list') + + # init variable pool + if not variable_pool: + variable_pool = VariablePool( + system_variables=system_inputs, + user_inputs=user_inputs + ) + + workflow_call_max_depth = current_app.config.get("WORKFLOW_CALL_MAX_DEPTH") + if call_depth > workflow_call_max_depth: + raise ValueError('Max workflow call depth {} reached.'.format(workflow_call_max_depth)) + + # init workflow run state + workflow_run_state = WorkflowRunState( + workflow=workflow, + start_at=time.perf_counter(), + variable_pool=variable_pool, + user_id=user_id, + user_from=user_from, + invoke_from=invoke_from, + workflow_call_depth=call_depth + ) # init workflow run if callbacks: for callback in callbacks: callback.on_workflow_run_started() - # init workflow run state - workflow_run_state = WorkflowRunState( + # run workflow + self._run_workflow( workflow=workflow, - start_at=time.perf_counter(), - variable_pool=VariablePool( - system_variables=system_inputs, - user_inputs=user_inputs - ), - user_id=user_id, - user_from=user_from + workflow_run_state=workflow_run_state, + callbacks=callbacks, ) + def _run_workflow(self, workflow: Workflow, + workflow_run_state: WorkflowRunState, + callbacks: list[BaseWorkflowCallback] = None, + start_at: Optional[str] = None, + end_at: Optional[str] = None) -> None: + """ + Run workflow + :param workflow: Workflow instance + :param user_id: user id + :param user_from: user from + :param user_inputs: user variables inputs + :param system_inputs: system inputs, like: query, files + :param callbacks: workflow callbacks + :param call_depth: call depth + :param start_at: 
force specific start node + :param end_at: force specific end node + :return: + """ + graph = workflow.graph_dict + try: - predecessor_node = None + predecessor_node: BaseNode = None + current_iteration_node: BaseIterationNode = None has_entry_node = False + max_execution_steps = current_app.config.get("WORKFLOW_MAX_EXECUTION_STEPS") + max_execution_time = current_app.config.get("WORKFLOW_MAX_EXECUTION_TIME") while True: # get next node, multiple target nodes in the future - next_node = self._get_next_node( + next_node = self._get_next_overall_node( workflow_run_state=workflow_run_state, graph=graph, predecessor_node=predecessor_node, - callbacks=callbacks + callbacks=callbacks, + start_at=start_at, + end_at=end_at ) + if not next_node: + # reached loop/iteration end or overall end + if current_iteration_node and workflow_run_state.current_iteration_state: + # reached loop/iteration end + # get next iteration + next_iteration = current_iteration_node.get_next_iteration( + variable_pool=workflow_run_state.variable_pool, + state=workflow_run_state.current_iteration_state + ) + self._workflow_iteration_next( + graph=graph, + current_iteration_node=current_iteration_node, + workflow_run_state=workflow_run_state, + callbacks=callbacks + ) + if isinstance(next_iteration, NodeRunResult): + if next_iteration.outputs: + for variable_key, variable_value in next_iteration.outputs.items(): + # append variables to variable pool recursively + self._append_variables_recursively( + variable_pool=workflow_run_state.variable_pool, + node_id=current_iteration_node.node_id, + variable_key_list=[variable_key], + variable_value=variable_value + ) + self._workflow_iteration_completed( + current_iteration_node=current_iteration_node, + workflow_run_state=workflow_run_state, + callbacks=callbacks + ) + # iteration has ended + next_node = self._get_next_overall_node( + workflow_run_state=workflow_run_state, + graph=graph, + predecessor_node=current_iteration_node, + callbacks=callbacks, + start_at=start_at, + end_at=end_at + ) + current_iteration_node = None + workflow_run_state.current_iteration_state = None + # continue overall process + elif isinstance(next_iteration, str): + # move to next iteration + next_node_id = next_iteration + # get next id + next_node = self._get_node(workflow_run_state, graph, next_node_id, callbacks) + if not next_node: break # check is already ran - if next_node.node_id in [node_and_result.node.node_id - for node_and_result in workflow_run_state.workflow_nodes_and_results]: + if self._check_node_has_ran(workflow_run_state, next_node.node_id): predecessor_node = next_node continue has_entry_node = True - # max steps 30 reached - if len(workflow_run_state.workflow_nodes_and_results) > 30: - raise ValueError('Max steps 30 reached.') + # max steps reached + if workflow_run_state.workflow_node_steps > max_execution_steps: + raise ValueError('Max steps {} reached.'.format(max_execution_steps)) - # or max execution time 10min reached - if self._is_timed_out(start_at=workflow_run_state.start_at, max_execution_time=600): - raise ValueError('Max execution time 10min reached.') + # or max execution time reached + if self._is_timed_out(start_at=workflow_run_state.start_at, max_execution_time=max_execution_time): + raise ValueError('Max execution time {}s reached.'.format(max_execution_time)) + + # handle iteration nodes + if isinstance(next_node, BaseIterationNode): + current_iteration_node = next_node + workflow_run_state.current_iteration_state = next_node.run( + 
variable_pool=workflow_run_state.variable_pool + ) + self._workflow_iteration_started( + graph=graph, + current_iteration_node=current_iteration_node, + workflow_run_state=workflow_run_state, + predecessor_node_id=predecessor_node.node_id if predecessor_node else None, + callbacks=callbacks + ) + predecessor_node = next_node + # move to start node of iteration + next_node_id = next_node.get_next_iteration( + variable_pool=workflow_run_state.variable_pool, + state=workflow_run_state.current_iteration_state + ) + self._workflow_iteration_next( + graph=graph, + current_iteration_node=current_iteration_node, + workflow_run_state=workflow_run_state, + callbacks=callbacks + ) + if isinstance(next_node_id, NodeRunResult): + # iteration has ended + current_iteration_node.set_output( + variable_pool=workflow_run_state.variable_pool, + state=workflow_run_state.current_iteration_state + ) + self._workflow_iteration_completed( + current_iteration_node=current_iteration_node, + workflow_run_state=workflow_run_state, + callbacks=callbacks + ) + current_iteration_node = None + workflow_run_state.current_iteration_state = None + continue + else: + next_node = self._get_node(workflow_run_state, graph, next_node_id, callbacks) # run workflow, run multiple target nodes in the future self._run_workflow_node( @@ -231,7 +373,9 @@ class WorkflowEngineManager: workflow_id=workflow.id, user_id=user_id, user_from=UserFrom.ACCOUNT, - config=node_config + invoke_from=InvokeFrom.DEBUGGER, + config=node_config, + workflow_call_depth=0 ) try: @@ -247,49 +391,14 @@ class WorkflowEngineManager: except NotImplementedError: variable_mapping = {} - for variable_key, variable_selector in variable_mapping.items(): - if variable_key not in user_inputs: - raise ValueError(f'Variable key {variable_key} not found in user inputs.') - - # fetch variable node id from variable selector - variable_node_id = variable_selector[0] - variable_key_list = variable_selector[1:] - - # get value - value = user_inputs.get(variable_key) - - # temp fix for image type - if node_type == NodeType.LLM: - new_value = [] - if isinstance(value, list): - node_data = node_instance.node_data - node_data = cast(LLMNodeData, node_data) - - detail = node_data.vision.configs.detail if node_data.vision.configs else None - - for item in value: - if isinstance(item, dict) and 'type' in item and item['type'] == 'image': - transfer_method = FileTransferMethod.value_of(item.get('transfer_method')) - file = FileVar( - tenant_id=workflow.tenant_id, - type=FileType.IMAGE, - transfer_method=transfer_method, - url=item.get('url') if transfer_method == FileTransferMethod.REMOTE_URL else None, - related_id=item.get( - 'upload_file_id') if transfer_method == FileTransferMethod.LOCAL_FILE else None, - extra_config=FileExtraConfig(image_config={'detail': detail} if detail else None), - ) - new_value.append(file) - - if new_value: - value = new_value - - # append variable and value to variable pool - variable_pool.append_variable( - node_id=variable_node_id, - variable_key_list=variable_key_list, - value=value - ) + self._mapping_user_inputs_to_variable_pool( + variable_mapping=variable_mapping, + user_inputs=user_inputs, + variable_pool=variable_pool, + tenant_id=workflow.tenant_id, + node_instance=node_instance + ) + # run node node_run_result = node_instance.run( variable_pool=variable_pool @@ -307,6 +416,126 @@ class WorkflowEngineManager: return node_instance, node_run_result + def single_step_run_iteration_workflow_node(self, workflow: Workflow, + node_id: str, + user_id: 
str, + user_inputs: dict, + callbacks: list[BaseWorkflowCallback] = None, + ) -> None: + """ + Single iteration run workflow node + """ + # fetch node info from workflow graph + graph = workflow.graph_dict + if not graph: + raise ValueError('workflow graph not found') + + nodes = graph.get('nodes') + if not nodes: + raise ValueError('nodes not found in workflow graph') + + for node in nodes: + if node.get('id') == node_id: + if node.get('data', {}).get('type') in [ + NodeType.ITERATION.value, + NodeType.LOOP.value, + ]: + node_config = node + else: + raise ValueError('node id is not an iteration node') + + # init variable pool + variable_pool = VariablePool( + system_variables={}, + user_inputs={} + ) + + # variable selector to variable mapping + iteration_nested_nodes = [ + node for node in nodes + if node.get('data', {}).get('iteration_id') == node_id or node.get('id') == node_id + ] + iteration_nested_node_ids = [node.get('id') for node in iteration_nested_nodes] + + if not iteration_nested_nodes: + raise ValueError('iteration has no nested nodes') + + # init workflow run + if callbacks: + for callback in callbacks: + callback.on_workflow_run_started() + + for node_config in iteration_nested_nodes: + # mapping user inputs to variable pool + node_cls = node_classes.get(NodeType.value_of(node_config.get('data', {}).get('type'))) + try: + variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(node_config) + except NotImplementedError: + variable_mapping = {} + + # remove iteration variables + variable_mapping = { + f'{node_config.get("id")}.{key}': value for key, value in variable_mapping.items() + if value[0] != node_id + } + + # remove variable out from iteration + variable_mapping = { + key: value for key, value in variable_mapping.items() + if value[0] not in iteration_nested_node_ids + } + + # append variables to variable pool + node_instance = node_cls( + tenant_id=workflow.tenant_id, + app_id=workflow.app_id, + workflow_id=workflow.id, + user_id=user_id, + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + config=node_config, + callbacks=callbacks, + workflow_call_depth=0 + ) + + self._mapping_user_inputs_to_variable_pool( + variable_mapping=variable_mapping, + user_inputs=user_inputs, + variable_pool=variable_pool, + tenant_id=workflow.tenant_id, + node_instance=node_instance + ) + + # fetch end node of iteration + end_node_id = None + for edge in graph.get('edges'): + if edge.get('source') == node_id: + end_node_id = edge.get('target') + break + + if not end_node_id: + raise ValueError('end node of iteration not found') + + # init workflow run state + workflow_run_state = WorkflowRunState( + workflow=workflow, + start_at=time.perf_counter(), + variable_pool=variable_pool, + user_id=user_id, + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + workflow_call_depth=0 + ) + + # run workflow + self._run_workflow( + workflow=workflow, + workflow_run_state=workflow_run_state, + callbacks=callbacks, + start_at=node_id, + end_at=end_node_id + ) + def _workflow_run_success(self, callbacks: list[BaseWorkflowCallback] = None) -> None: """ Workflow run success @@ -332,10 +561,96 @@ class WorkflowEngineManager: error=error ) - def _get_next_node(self, workflow_run_state: WorkflowRunState, + def _workflow_iteration_started(self, graph: dict, + current_iteration_node: BaseIterationNode, + workflow_run_state: WorkflowRunState, + predecessor_node_id: Optional[str] = None, + callbacks: list[BaseWorkflowCallback] = None) -> None: + """ + Workflow 
iteration started + :param current_iteration_node: current iteration node + :param workflow_run_state: workflow run state + :param callbacks: workflow callbacks + :return: + """ + # get nested nodes + iteration_nested_nodes = [ + node for node in graph.get('nodes') + if node.get('data', {}).get('iteration_id') == current_iteration_node.node_id + ] + + if not iteration_nested_nodes: + raise ValueError('iteration has no nested nodes') + + if callbacks: + if isinstance(workflow_run_state.current_iteration_state, IterationState): + for callback in callbacks: + callback.on_workflow_iteration_started( + node_id=current_iteration_node.node_id, + node_type=NodeType.ITERATION, + node_run_index=workflow_run_state.workflow_node_steps, + node_data=current_iteration_node.node_data, + inputs=workflow_run_state.current_iteration_state.inputs, + predecessor_node_id=predecessor_node_id, + metadata=workflow_run_state.current_iteration_state.metadata.dict() + ) + + # add steps + workflow_run_state.workflow_node_steps += 1 + + def _workflow_iteration_next(self, graph: dict, + current_iteration_node: BaseIterationNode, + workflow_run_state: WorkflowRunState, + callbacks: list[BaseWorkflowCallback] = None) -> None: + """ + Workflow iteration next + :param workflow_run_state: workflow run state + :return: + """ + if callbacks: + if isinstance(workflow_run_state.current_iteration_state, IterationState): + for callback in callbacks: + callback.on_workflow_iteration_next( + node_id=current_iteration_node.node_id, + node_type=NodeType.ITERATION, + index=workflow_run_state.current_iteration_state.index, + node_run_index=workflow_run_state.workflow_node_steps, + output=workflow_run_state.current_iteration_state.get_current_output() + ) + # clear ran nodes + workflow_run_state.workflow_node_runs = [ + node_run for node_run in workflow_run_state.workflow_node_runs + if node_run.iteration_node_id != current_iteration_node.node_id + ] + + # clear variables in current iteration + nodes = graph.get('nodes') + nodes = [node for node in nodes if node.get('data', {}).get('iteration_id') == current_iteration_node.node_id] + + for node in nodes: + workflow_run_state.variable_pool.clear_node_variables(node_id=node.get('id')) + + def _workflow_iteration_completed(self, current_iteration_node: BaseIterationNode, + workflow_run_state: WorkflowRunState, + callbacks: list[BaseWorkflowCallback] = None) -> None: + if callbacks: + if isinstance(workflow_run_state.current_iteration_state, IterationState): + for callback in callbacks: + callback.on_workflow_iteration_completed( + node_id=current_iteration_node.node_id, + node_type=NodeType.ITERATION, + node_run_index=workflow_run_state.workflow_node_steps, + outputs={ + 'output': workflow_run_state.current_iteration_state.outputs + } + ) + + def _get_next_overall_node(self, workflow_run_state: WorkflowRunState, graph: dict, predecessor_node: Optional[BaseNode] = None, - callbacks: list[BaseWorkflowCallback] = None) -> Optional[BaseNode]: + callbacks: list[BaseWorkflowCallback] = None, + start_at: Optional[str] = None, + end_at: Optional[str] = None) -> Optional[BaseNode]: """ Get next node multiple target nodes in the future. 
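
The iteration hooks above also reset per-pass bookkeeping: run records created inside the iteration are discarded and variables produced by the nested nodes are cleared before the next pass starts. A simplified, self-contained sketch of that cleanup; `NodeRun` and `RunState` here are stand-ins for the engine's `WorkflowRunState` and `VariablePool`, not the real classes:

```python
from dataclasses import dataclass, field

@dataclass
class NodeRun:
    node_id: str
    iteration_node_id: str  # empty for nodes outside any iteration

@dataclass
class RunState:
    node_runs: list[NodeRun] = field(default_factory=list)
    variables: dict[str, dict] = field(default_factory=dict)

def reset_iteration_pass(state: RunState, iteration_node_id: str,
                         nested_node_ids: list[str]) -> None:
    # Drop run records from the finished pass so nested nodes are not
    # treated as "already ran" when the next pass visits them again.
    state.node_runs = [run for run in state.node_runs
                       if run.iteration_node_id != iteration_node_id]
    # Clear the variables the nested nodes produced during the pass.
    for node_id in nested_node_ids:
        state.variables.pop(node_id, None)

state = RunState(
    node_runs=[NodeRun("llm-1", "iter-1"), NodeRun("start", "")],
    variables={"llm-1": {"text": "draft 1"}, "start": {"query": "q"}},
)
reset_iteration_pass(state, "iter-1", nested_node_ids=["llm-1"])
assert [run.node_id for run in state.node_runs] == ["start"]
assert "llm-1" not in state.variables and "start" in state.variables
```

Without this reset, the "check is already ran" guard earlier in the run loop would skip every nested node from the second pass onward.
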
@@ -350,16 +665,26 @@ class WorkflowEngineManager: if not predecessor_node: for node_config in nodes: - if node_config.get('data', {}).get('type', '') == NodeType.START.value: - return StartNode( + node_cls = None + if start_at: + if node_config.get('id') == start_at: + node_cls = node_classes.get(NodeType.value_of(node_config.get('data', {}).get('type'))) + else: + if node_config.get('data', {}).get('type', '') == NodeType.START.value: + node_cls = StartNode + if node_cls: + return node_cls( tenant_id=workflow_run_state.tenant_id, app_id=workflow_run_state.app_id, workflow_id=workflow_run_state.workflow_id, user_id=workflow_run_state.user_id, user_from=workflow_run_state.user_from, + invoke_from=workflow_run_state.invoke_from, config=node_config, - callbacks=callbacks + callbacks=callbacks, + workflow_call_depth=workflow_run_state.workflow_call_depth ) + else: edges = graph.get('edges') source_node_id = predecessor_node.node_id @@ -386,6 +711,9 @@ class WorkflowEngineManager: target_node_id = outgoing_edge.get('target') + if end_at and target_node_id == end_at: + return None + # fetch target node from target node id target_node_config = None for node in nodes: @@ -405,9 +733,40 @@ class WorkflowEngineManager: workflow_id=workflow_run_state.workflow_id, user_id=workflow_run_state.user_id, user_from=workflow_run_state.user_from, + invoke_from=workflow_run_state.invoke_from, config=target_node_config, - callbacks=callbacks + callbacks=callbacks, + workflow_call_depth=workflow_run_state.workflow_call_depth ) + + def _get_node(self, workflow_run_state: WorkflowRunState, + graph: dict, + node_id: str, + callbacks: list[BaseWorkflowCallback]) -> Optional[BaseNode]: + """ + Get node from graph by node id + """ + nodes = graph.get('nodes') + if not nodes: + return None + + for node_config in nodes: + if node_config.get('id') == node_id: + node_type = NodeType.value_of(node_config.get('data', {}).get('type')) + node_cls = node_classes.get(node_type) + return node_cls( + tenant_id=workflow_run_state.tenant_id, + app_id=workflow_run_state.app_id, + workflow_id=workflow_run_state.workflow_id, + user_id=workflow_run_state.user_id, + user_from=workflow_run_state.user_from, + invoke_from=workflow_run_state.invoke_from, + config=node_config, + callbacks=callbacks, + workflow_call_depth=workflow_run_state.workflow_call_depth + ) + + return None def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool: """ @@ -418,6 +777,15 @@ class WorkflowEngineManager: """ return time.perf_counter() - start_at > max_execution_time + def _check_node_has_ran(self, workflow_run_state: WorkflowRunState, node_id: str) -> bool: + """ + Check node has ran + """ + return bool([ + node_and_result for node_and_result in workflow_run_state.workflow_node_runs + if node_and_result.node_id == node_id + ]) + def _run_workflow_node(self, workflow_run_state: WorkflowRunState, node: BaseNode, predecessor_node: Optional[BaseNode] = None, @@ -428,7 +796,7 @@ class WorkflowEngineManager: node_id=node.node_id, node_type=node.node_type, node_data=node.node_data, - node_run_index=len(workflow_run_state.workflow_nodes_and_results) + 1, + node_run_index=workflow_run_state.workflow_node_steps, predecessor_node_id=predecessor_node.node_id if predecessor_node else None ) @@ -442,6 +810,16 @@ class WorkflowEngineManager: # add to workflow_nodes_and_results workflow_run_state.workflow_nodes_and_results.append(workflow_nodes_and_result) + # add steps + workflow_run_state.workflow_node_steps += 1 + + # mark node as running + if 
workflow_run_state.current_iteration_state: + workflow_run_state.workflow_node_runs.append(WorkflowRunState.NodeRun( + node_id=node.node_id, + iteration_node_id=workflow_run_state.current_iteration_state.iteration_node_id + )) + try: # run node, result must have inputs, process_data, outputs, execution_metadata node_run_result = node.run( @@ -561,3 +939,53 @@ class WorkflowEngineManager: new_value[key] = new_val return new_value + + def _mapping_user_inputs_to_variable_pool(self, + variable_mapping: dict, + user_inputs: dict, + variable_pool: VariablePool, + tenant_id: str, + node_instance: BaseNode): + for variable_key, variable_selector in variable_mapping.items(): + if variable_key not in user_inputs: + raise ValueError(f'Variable key {variable_key} not found in user inputs.') + + # fetch variable node id from variable selector + variable_node_id = variable_selector[0] + variable_key_list = variable_selector[1:] + + # get value + value = user_inputs.get(variable_key) + + # temp fix for image type + if node_instance.node_type == NodeType.LLM: + new_value = [] + if isinstance(value, list): + node_data = node_instance.node_data + node_data = cast(LLMNodeData, node_data) + + detail = node_data.vision.configs.detail if node_data.vision.configs else None + + for item in value: + if isinstance(item, dict) and 'type' in item and item['type'] == 'image': + transfer_method = FileTransferMethod.value_of(item.get('transfer_method')) + file = FileVar( + tenant_id=tenant_id, + type=FileType.IMAGE, + transfer_method=transfer_method, + url=item.get('url') if transfer_method == FileTransferMethod.REMOTE_URL else None, + related_id=item.get( + 'upload_file_id') if transfer_method == FileTransferMethod.LOCAL_FILE else None, + extra_config=FileExtraConfig(image_config={'detail': detail} if detail else None), + ) + new_value.append(file) + + if new_value: + value = new_value + + # append variable and value to variable pool + variable_pool.append_variable( + node_id=variable_node_id, + variable_key_list=variable_key_list, + value=value + ) \ No newline at end of file diff --git a/api/events/event_handlers/__init__.py b/api/events/event_handlers/__init__.py index 688b80aa8c..ceb50a252b 100644 --- a/api/events/event_handlers/__init__.py +++ b/api/events/event_handlers/__init__.py @@ -6,6 +6,7 @@ from .create_site_record_when_app_created import handle from .deduct_quota_when_messaeg_created import handle from .delete_installed_app_when_app_deleted import handle from .delete_tool_parameters_cache_when_sync_draft_workflow import handle +from .delete_workflow_as_tool_when_app_deleted import handle from .update_app_dataset_join_when_app_model_config_updated import handle from .update_app_dataset_join_when_app_published_workflow_updated import handle from .update_provider_last_used_at_when_messaeg_created import handle diff --git a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py index 2a127d903e..1f6da34ee2 100644 --- a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py +++ b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py @@ -10,18 +10,22 @@ def handle(sender, **kwargs): app = sender for node_data in kwargs.get('synced_draft_workflow').graph_dict.get('nodes', []): if node_data.get('data', {}).get('type') == NodeType.TOOL.value: - tool_entity = ToolEntity(**node_data["data"]) - tool_runtime = ToolManager.get_tool_runtime( - 
provider_type=tool_entity.provider_type, - provider_name=tool_entity.provider_id, - tool_name=tool_entity.tool_name, - tenant_id=app.tenant_id, - ) - manager = ToolParameterConfigurationManager( - tenant_id=app.tenant_id, - tool_runtime=tool_runtime, - provider_name=tool_entity.provider_name, - provider_type=tool_entity.provider_type, - identity_id=f'WORKFLOW.{app.id}.{node_data.get("id")}' - ) - manager.delete_tool_parameters_cache() + try: + tool_entity = ToolEntity(**node_data["data"]) + tool_runtime = ToolManager.get_tool_runtime( + provider_type=tool_entity.provider_type, + provider_id=tool_entity.provider_id, + tool_name=tool_entity.tool_name, + tenant_id=app.tenant_id, + ) + manager = ToolParameterConfigurationManager( + tenant_id=app.tenant_id, + tool_runtime=tool_runtime, + provider_name=tool_entity.provider_name, + provider_type=tool_entity.provider_type, + identity_id=f'WORKFLOW.{app.id}.{node_data.get("id")}' + ) + manager.delete_tool_parameters_cache() + except Exception: + # tool does not exist + pass diff --git a/api/events/event_handlers/delete_workflow_as_tool_when_app_deleted.py b/api/events/event_handlers/delete_workflow_as_tool_when_app_deleted.py new file mode 100644 index 0000000000..0c56688ff6 --- /dev/null +++ b/api/events/event_handlers/delete_workflow_as_tool_when_app_deleted.py @@ -0,0 +1,14 @@ +from events.app_event import app_was_deleted +from extensions.ext_database import db +from models.tools import WorkflowToolProvider + + +@app_was_deleted.connect +def handle(sender, **kwargs): + app = sender + workflow_tools = db.session.query(WorkflowToolProvider).filter( + WorkflowToolProvider.app_id == app.id + ).all() + for workflow_tool in workflow_tools: + db.session.delete(workflow_tool) + db.session.commit() diff --git a/api/extensions/ext_mail.py b/api/extensions/ext_mail.py index d2c6e32dfd..ec3a5cc112 100644 --- a/api/extensions/ext_mail.py +++ b/api/extensions/ext_mail.py @@ -1,3 +1,4 @@ +import logging from typing import Optional import resend @@ -16,7 +17,7 @@ class Mail: if app.config.get('MAIL_TYPE'): if app.config.get('MAIL_DEFAULT_SEND_FROM'): self._default_send_from = app.config.get('MAIL_DEFAULT_SEND_FROM') - + if app.config.get('MAIL_TYPE') == 'resend': api_key = app.config.get('RESEND_API_KEY') if not api_key: @@ -32,16 +33,22 @@ class Mail: from libs.smtp import SMTPClient if not app.config.get('SMTP_SERVER') or not app.config.get('SMTP_PORT'): raise ValueError('SMTP_SERVER and SMTP_PORT are required for smtp mail type') + if not app.config.get('SMTP_USE_TLS') and app.config.get('SMTP_OPPORTUNISTIC_TLS'): + raise ValueError('SMTP_OPPORTUNISTIC_TLS is not supported without enabling SMTP_USE_TLS') self._client = SMTPClient( server=app.config.get('SMTP_SERVER'), port=app.config.get('SMTP_PORT'), username=app.config.get('SMTP_USERNAME'), password=app.config.get('SMTP_PASSWORD'), _from=app.config.get('MAIL_DEFAULT_SEND_FROM'), - use_tls=app.config.get('SMTP_USE_TLS') + use_tls=app.config.get('SMTP_USE_TLS'), + opportunistic_tls=app.config.get('SMTP_OPPORTUNISTIC_TLS') ) else: raise ValueError('Unsupported mail type {}'.format(app.config.get('MAIL_TYPE'))) + else: + logging.warning('MAIL_TYPE is not set') + def send(self, to: str, subject: str, html: str, from_: Optional[str] = None): if not self._client: diff --git a/api/extensions/storage/azure_storage.py b/api/extensions/storage/azure_storage.py index 01de8bab94..b9809de640 100644 --- a/api/extensions/storage/azure_storage.py +++ b/api/extensions/storage/azure_storage.py @@ -5,38 +5,39 @@ from datetime import
datetime, timedelta, timezone from azure.storage.blob import AccountSasPermissions, BlobServiceClient, ResourceTypes, generate_account_sas from flask import Flask +from extensions.ext_redis import redis_client from extensions.storage.base_storage import BaseStorage class AzureStorage(BaseStorage): """Implementation for azure storage. """ + def __init__(self, app: Flask): super().__init__(app) app_config = self.app.config - self.bucket_name = app_config.get('AZURE_STORAGE_CONTAINER_NAME') - sas_token = generate_account_sas( - account_name=app_config.get('AZURE_BLOB_ACCOUNT_NAME'), - account_key=app_config.get('AZURE_BLOB_ACCOUNT_KEY'), - resource_types=ResourceTypes(service=True, container=True, object=True), - permission=AccountSasPermissions(read=True, write=True, delete=True, list=True, add=True, create=True), - expiry=datetime.now(timezone.utc).replace(tzinfo=None) + timedelta(hours=1) - ) - self.client = BlobServiceClient(account_url=app_config.get('AZURE_BLOB_ACCOUNT_URL'), - credential=sas_token) + self.bucket_name = app_config.get('AZURE_BLOB_CONTAINER_NAME') + self.account_url = app_config.get('AZURE_BLOB_ACCOUNT_URL') + self.account_name = app_config.get('AZURE_BLOB_ACCOUNT_NAME') + self.account_key = app_config.get('AZURE_BLOB_ACCOUNT_KEY') + def save(self, filename, data): - blob_container = self.client.get_container_client(container=self.bucket_name) + client = self._sync_client() + blob_container = client.get_container_client(container=self.bucket_name) blob_container.upload_blob(filename, data) def load_once(self, filename: str) -> bytes: - blob = self.client.get_container_client(container=self.bucket_name) + client = self._sync_client() + blob = client.get_container_client(container=self.bucket_name) blob = blob.get_blob_client(blob=filename) data = blob.download_blob().readall() return data def load_stream(self, filename: str) -> Generator: + client = self._sync_client() + def generate(filename: str = filename) -> Generator: - blob = self.client.get_blob_client(container=self.bucket_name, blob=filename) + blob = client.get_blob_client(container=self.bucket_name, blob=filename) with closing(blob.download_blob()) as blob_stream: while chunk := blob_stream.readall(4096): yield chunk @@ -44,15 +45,37 @@ class AzureStorage(BaseStorage): return generate() def download(self, filename, target_filepath): - blob = self.client.get_blob_client(container=self.bucket_name, blob=filename) + client = self._sync_client() + + blob = client.get_blob_client(container=self.bucket_name, blob=filename) with open(target_filepath, "wb") as my_blob: blob_data = blob.download_blob() blob_data.readinto(my_blob) def exists(self, filename): - blob = self.client.get_blob_client(container=self.bucket_name, blob=filename) + client = self._sync_client() + + blob = client.get_blob_client(container=self.bucket_name, blob=filename) return blob.exists() def delete(self, filename): - blob_container = self.client.get_container_client(container=self.bucket_name) - blob_container.delete_blob(filename) \ No newline at end of file + client = self._sync_client() + + blob_container = client.get_container_client(container=self.bucket_name) + blob_container.delete_blob(filename) + + def _sync_client(self): + cache_key = 'azure_blob_sas_token_{}_{}'.format(self.account_name, self.account_key) + cache_result = redis_client.get(cache_key) + if cache_result is not None: + sas_token = cache_result.decode('utf-8') + else: + sas_token = generate_account_sas( + account_name=self.account_name, + account_key=self.account_key, + 
resource_types=ResourceTypes(service=True, container=True, object=True), + permission=AccountSasPermissions(read=True, write=True, delete=True, list=True, add=True, create=True), + expiry=datetime.now(timezone.utc).replace(tzinfo=None) + timedelta(hours=1) + ) + redis_client.set(cache_key, sas_token, ex=3000) + return BlobServiceClient(account_url=self.account_url, credential=sas_token) diff --git a/api/extensions/storage/google_storage.py b/api/extensions/storage/google_storage.py index f6c69eb0ae..97004fddab 100644 --- a/api/extensions/storage/google_storage.py +++ b/api/extensions/storage/google_storage.py @@ -1,4 +1,6 @@ import base64 +import io +import json from collections.abc import Generator from contextlib import closing @@ -15,14 +16,19 @@ class GoogleStorage(BaseStorage): super().__init__(app) app_config = self.app.config self.bucket_name = app_config.get('GOOGLE_STORAGE_BUCKET_NAME') - service_account_json = base64.b64decode(app_config.get('GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64')).decode( - 'utf-8') - self.client = GoogleCloudStorage.Client().from_service_account_json(service_account_json) + service_account_json_str = app_config.get('GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64') + # if service_account_json_str is empty, use Application Default Credentials + if service_account_json_str: + service_account_json = base64.b64decode(service_account_json_str).decode('utf-8') + self.client = GoogleCloudStorage.Client.from_service_account_info(json.loads(service_account_json)) + else: + self.client = GoogleCloudStorage.Client() def save(self, filename, data): bucket = self.client.get_bucket(self.bucket_name) blob = bucket.blob(filename) - blob.upload_from_file(data) + with io.BytesIO(data) as stream: + blob.upload_from_file(stream) def load_once(self, filename: str) -> bytes: bucket = self.client.get_bucket(self.bucket_name) diff --git a/api/fields/app_fields.py b/api/fields/app_fields.py index c7cfdd7939..212c3e7f17 100644 --- a/api/fields/app_fields.py +++ b/api/fields/app_fields.py @@ -113,6 +113,7 @@ site_fields = { 'customize_domain': fields.String, 'copyright': fields.String, 'privacy_policy': fields.String, + 'custom_disclaimer': fields.String, 'customize_token_strategy': fields.String, 'prompt_public': fields.Boolean, 'app_base_url': fields.String, @@ -146,6 +147,7 @@ app_site_fields = { 'customize_domain': fields.String, 'copyright': fields.String, 'privacy_policy': fields.String, + 'custom_disclaimer': fields.String, 'customize_token_strategy': fields.String, 'prompt_public': fields.Boolean } diff --git a/api/fields/document_fields.py b/api/fields/document_fields.py index 94d905eafe..e8215255b3 100644 --- a/api/fields/document_fields.py +++ b/api/fields/document_fields.py @@ -8,6 +8,7 @@ document_fields = { 'position': fields.Integer, 'data_source_type': fields.String, 'data_source_info': fields.Raw(attribute='data_source_info_dict'), + 'data_source_detail_dict': fields.Raw(attribute='data_source_detail_dict'), 'dataset_process_rule_id': fields.String, 'name': fields.String, 'created_from': fields.String, @@ -31,6 +32,7 @@ document_with_segments_fields = { 'position': fields.Integer, 'data_source_type': fields.String, 'data_source_info': fields.Raw(attribute='data_source_info_dict'), + 'data_source_detail_dict': fields.Raw(attribute='data_source_detail_dict'), 'dataset_process_rule_id': fields.String, 'name': fields.String, 'created_from': fields.String, diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py index 9919a440e8..54d7ed55f8 100644 ---
a/api/fields/workflow_fields.py +++ b/api/fields/workflow_fields.py @@ -7,8 +7,10 @@ workflow_fields = { 'id': fields.String, 'graph': fields.Raw(attribute='graph_dict'), 'features': fields.Raw(attribute='features_dict'), + 'hash': fields.String(attribute='unique_hash'), 'created_by': fields.Nested(simple_account_fields, attribute='created_by_account'), 'created_at': TimestampField, 'updated_by': fields.Nested(simple_account_fields, attribute='updated_by_account', allow_null=True), - 'updated_at': TimestampField + 'updated_at': TimestampField, + 'tool_published': fields.Boolean, } diff --git a/api/libs/gmpy2_pkcs10aep_cipher.py b/api/libs/gmpy2_pkcs10aep_cipher.py index c22546f602..9856875c16 100644 --- a/api/libs/gmpy2_pkcs10aep_cipher.py +++ b/api/libs/gmpy2_pkcs10aep_cipher.py @@ -48,7 +48,7 @@ class PKCS1OAEP_Cipher: `Crypto.Hash.SHA1` is used. mgfunc : callable A mask generation function that accepts two parameters: a string to - use as seed, and the lenth of the mask to generate, in bytes. + use as seed, and the length of the mask to generate, in bytes. If not specified, the standard MGF1 consistent with ``hashAlgo`` is used (a safe choice). label : bytes/bytearray/memoryview A label to apply to this particular encryption. If not specified, @@ -218,7 +218,7 @@ def new(key, hashAlgo=None, mgfunc=None, label=b'', randfunc=None): :param mgfunc: A mask generation function that accepts two parameters: a string to - use as seed, and the lenth of the mask to generate, in bytes. + use as seed, and the length of the mask to generate, in bytes. If not specified, the standard MGF1 consistent with ``hashAlgo`` is used (a safe choice). :type mgfunc: callable diff --git a/api/libs/helper.py b/api/libs/helper.py index f9cf590b7a..f4be9c5531 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -46,7 +46,13 @@ def uuid_value(value): error = ('{value} is not a valid uuid.' 
.format(value=value)) raise ValueError(error) - + +def alphanumeric(value: str): + # check that the value contains only alphanumeric characters and underscores + if re.match(r'^[a-zA-Z0-9_]+$', value): + return value + + raise ValueError(f'{value} is not a valid alphanumeric value') def timestamp_value(timestamp): try: diff --git a/api/libs/smtp.py b/api/libs/smtp.py index 30a795bd70..bf3a1a92e9 100644 --- a/api/libs/smtp.py +++ b/api/libs/smtp.py @@ -1,27 +1,50 @@ +import logging import smtplib from email.mime.multipart import MIMEMultipart from email.mime.text import MIMEText class SMTPClient: - def __init__(self, server: str, port: int, username: str, password: str, _from: str, use_tls=False): + def __init__(self, server: str, port: int, username: str, password: str, _from: str, use_tls=False, opportunistic_tls=False): self.server = server self.port = port self._from = _from self.username = username self.password = password - self._use_tls = use_tls + self.use_tls = use_tls + self.opportunistic_tls = opportunistic_tls def send(self, mail: dict): - smtp = smtplib.SMTP(self.server, self.port) - if self._use_tls: - smtp.starttls() - if self.username and self.password: - smtp.login(self.username, self.password) - msg = MIMEMultipart() - msg['Subject'] = mail['subject'] - msg['From'] = self._from - msg['To'] = mail['to'] - msg.attach(MIMEText(mail['html'], 'html')) - smtp.sendmail(self.username, mail['to'], msg.as_string()) - smtp.quit() + smtp = None + try: + if self.use_tls: + if self.opportunistic_tls: + smtp = smtplib.SMTP(self.server, self.port, timeout=10) + smtp.starttls() + else: + smtp = smtplib.SMTP_SSL(self.server, self.port, timeout=10) + else: + smtp = smtplib.SMTP(self.server, self.port, timeout=10) + + if self.username and self.password: + smtp.login(self.username, self.password) + + msg = MIMEMultipart() + msg['Subject'] = mail['subject'] + msg['From'] = self._from + msg['To'] = mail['to'] + msg.attach(MIMEText(mail['html'], 'html')) + + smtp.sendmail(self._from, mail['to'], msg.as_string()) + except smtplib.SMTPException as e: + logging.error(f"SMTP error occurred: {str(e)}") + raise + except TimeoutError as e: + logging.error(f"Timeout occurred while sending email: {str(e)}") + raise + except Exception as e: + logging.error(f"Unexpected error occurred while sending email: {str(e)}") + raise + finally: + if smtp: + smtp.quit() diff --git a/api/migrations/env.py b/api/migrations/env.py index 18485c1885..ad3a122c04 100644 --- a/api/migrations/env.py +++ b/api/migrations/env.py @@ -110,3 +110,4 @@ if context.is_offline_mode(): run_migrations_offline() else: run_migrations_online() + diff --git a/api/migrations/versions/03f98355ba0e_add_workflow_tool_label_and_tool_.py b/api/migrations/versions/03f98355ba0e_add_workflow_tool_label_and_tool_.py new file mode 100644 index 0000000000..0fba6a87eb --- /dev/null +++ b/api/migrations/versions/03f98355ba0e_add_workflow_tool_label_and_tool_.py @@ -0,0 +1,32 @@ +"""add workflow tool label and tool bindings idx + +Revision ID: 03f98355ba0e +Revises: 9e98fbaffb88 +Create Date: 2024-05-25 07:17:00.539125 + +""" +import sqlalchemy as sa +from alembic import op + +import models as models + +# revision identifiers, used by Alembic.
+revision = '03f98355ba0e' +down_revision = '9e98fbaffb88' +branch_labels = None +depends_on = None + + +def upgrade(): + with op.batch_alter_table('tool_label_bindings', schema=None) as batch_op: + batch_op.create_unique_constraint('unique_tool_label_bind', ['tool_id', 'label_name']) + + with op.batch_alter_table('tool_workflow_providers', schema=None) as batch_op: + batch_op.add_column(sa.Column('label', sa.String(length=255), server_default='', nullable=False)) + +def downgrade(): + with op.batch_alter_table('tool_workflow_providers', schema=None) as batch_op: + batch_op.drop_column('label') + + with op.batch_alter_table('tool_label_bindings', schema=None) as batch_op: + batch_op.drop_constraint('unique_tool_label_bind', type_='unique') diff --git a/api/migrations/versions/3b18fea55204_add_tool_label_bings.py b/api/migrations/versions/3b18fea55204_add_tool_label_bings.py new file mode 100644 index 0000000000..db3119badf --- /dev/null +++ b/api/migrations/versions/3b18fea55204_add_tool_label_bings.py @@ -0,0 +1,42 @@ +"""add tool label bindings + +Revision ID: 3b18fea55204 +Revises: 7bdef072e63a +Create Date: 2024-05-14 09:27:18.857890 + +""" +import sqlalchemy as sa +from alembic import op + +import models as models + +# revision identifiers, used by Alembic. +revision = '3b18fea55204' +down_revision = '7bdef072e63a' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('tool_label_bindings', + sa.Column('id', models.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tool_id', sa.String(length=64), nullable=False), + sa.Column('tool_type', sa.String(length=40), nullable=False), + sa.Column('label_name', sa.String(length=40), nullable=False), + sa.PrimaryKeyConstraint('id', name='tool_label_bind_pkey') + ) + + with op.batch_alter_table('tool_workflow_providers', schema=None) as batch_op: + batch_op.add_column(sa.Column('privacy_policy', sa.String(length=255), server_default='', nullable=True)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tool_workflow_providers', schema=None) as batch_op: + batch_op.drop_column('privacy_policy') + + op.drop_table('tool_label_bindings') + # ### end Alembic commands ### diff --git a/api/migrations/versions/47cc7df8c4f3_modify_default_model_name_length.py b/api/migrations/versions/47cc7df8c4f3_modify_default_model_name_length.py new file mode 100644 index 0000000000..b37928d3c0 --- /dev/null +++ b/api/migrations/versions/47cc7df8c4f3_modify_default_model_name_length.py @@ -0,0 +1,39 @@ +"""modify default model name length + +Revision ID: 47cc7df8c4f3 +Revises: 3c7cac9521c6 +Create Date: 2024-05-10 09:48:09.046298 + +""" +import sqlalchemy as sa +from alembic import op + +import models as models + +# revision identifiers, used by Alembic. +revision = '47cc7df8c4f3' +down_revision = '3c7cac9521c6' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tenant_default_models', schema=None) as batch_op: + batch_op.alter_column('model_name', + existing_type=sa.VARCHAR(length=40), + type_=sa.String(length=255), + existing_nullable=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust!
### + with op.batch_alter_table('tenant_default_models', schema=None) as batch_op: + batch_op.alter_column('model_name', + existing_type=sa.String(length=255), + type_=sa.VARCHAR(length=40), + existing_nullable=False) + + # ### end Alembic commands ### diff --git a/api/migrations/versions/4e99a8df00ff_add_load_balancing.py b/api/migrations/versions/4e99a8df00ff_add_load_balancing.py new file mode 100644 index 0000000000..67d7b9fbf5 --- /dev/null +++ b/api/migrations/versions/4e99a8df00ff_add_load_balancing.py @@ -0,0 +1,126 @@ +"""add load balancing + +Revision ID: 4e99a8df00ff +Revises: 47cc7df8c4f3 +Create Date: 2024-05-10 12:08:09.812736 + +""" +import sqlalchemy as sa +from alembic import op + +import models as models + +# revision identifiers, used by Alembic. +revision = '4e99a8df00ff' +down_revision = '64a70a7aab8b' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('load_balancing_model_configs', + sa.Column('id', models.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', models.StringUUID(), nullable=False), + sa.Column('provider_name', sa.String(length=255), nullable=False), + sa.Column('model_name', sa.String(length=255), nullable=False), + sa.Column('model_type', sa.String(length=40), nullable=False), + sa.Column('name', sa.String(length=255), nullable=False), + sa.Column('encrypted_config', sa.Text(), nullable=True), + sa.Column('enabled', sa.Boolean(), server_default=sa.text('true'), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='load_balancing_model_config_pkey') + ) + with op.batch_alter_table('load_balancing_model_configs', schema=None) as batch_op: + batch_op.create_index('load_balancing_model_config_tenant_provider_model_idx', ['tenant_id', 'provider_name', 'model_type'], unique=False) + + op.create_table('provider_model_settings', + sa.Column('id', models.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', models.StringUUID(), nullable=False), + sa.Column('provider_name', sa.String(length=255), nullable=False), + sa.Column('model_name', sa.String(length=255), nullable=False), + sa.Column('model_type', sa.String(length=40), nullable=False), + sa.Column('enabled', sa.Boolean(), server_default=sa.text('true'), nullable=False), + sa.Column('load_balancing_enabled', sa.Boolean(), server_default=sa.text('false'), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='provider_model_setting_pkey') + ) + with op.batch_alter_table('provider_model_settings', schema=None) as batch_op: + batch_op.create_index('provider_model_setting_tenant_provider_model_idx', ['tenant_id', 'provider_name', 'model_type'], unique=False) + + with op.batch_alter_table('provider_models', schema=None) as batch_op: + batch_op.alter_column('provider_name', + existing_type=sa.VARCHAR(length=40), + type_=sa.String(length=255), + existing_nullable=False) + + with op.batch_alter_table('provider_orders', schema=None) as batch_op: + batch_op.alter_column('provider_name', + 
existing_type=sa.VARCHAR(length=40), + type_=sa.String(length=255), + existing_nullable=False) + + with op.batch_alter_table('providers', schema=None) as batch_op: + batch_op.alter_column('provider_name', + existing_type=sa.VARCHAR(length=40), + type_=sa.String(length=255), + existing_nullable=False) + + with op.batch_alter_table('tenant_default_models', schema=None) as batch_op: + batch_op.alter_column('provider_name', + existing_type=sa.VARCHAR(length=40), + type_=sa.String(length=255), + existing_nullable=False) + + with op.batch_alter_table('tenant_preferred_model_providers', schema=None) as batch_op: + batch_op.alter_column('provider_name', + existing_type=sa.VARCHAR(length=40), + type_=sa.String(length=255), + existing_nullable=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tenant_preferred_model_providers', schema=None) as batch_op: + batch_op.alter_column('provider_name', + existing_type=sa.String(length=255), + type_=sa.VARCHAR(length=40), + existing_nullable=False) + + with op.batch_alter_table('tenant_default_models', schema=None) as batch_op: + batch_op.alter_column('provider_name', + existing_type=sa.String(length=255), + type_=sa.VARCHAR(length=40), + existing_nullable=False) + + with op.batch_alter_table('providers', schema=None) as batch_op: + batch_op.alter_column('provider_name', + existing_type=sa.String(length=255), + type_=sa.VARCHAR(length=40), + existing_nullable=False) + + with op.batch_alter_table('provider_orders', schema=None) as batch_op: + batch_op.alter_column('provider_name', + existing_type=sa.String(length=255), + type_=sa.VARCHAR(length=40), + existing_nullable=False) + + with op.batch_alter_table('provider_models', schema=None) as batch_op: + batch_op.alter_column('provider_name', + existing_type=sa.String(length=255), + type_=sa.VARCHAR(length=40), + existing_nullable=False) + + with op.batch_alter_table('provider_model_settings', schema=None) as batch_op: + batch_op.drop_index('provider_model_setting_tenant_provider_model_idx') + + op.drop_table('provider_model_settings') + with op.batch_alter_table('load_balancing_model_configs', schema=None) as batch_op: + batch_op.drop_index('load_balancing_model_config_tenant_provider_model_idx') + + op.drop_table('load_balancing_model_configs') + # ### end Alembic commands ### diff --git a/api/migrations/versions/5fda94355fce_custom_disclaimer.py b/api/migrations/versions/5fda94355fce_custom_disclaimer.py new file mode 100644 index 0000000000..73bcdc4500 --- /dev/null +++ b/api/migrations/versions/5fda94355fce_custom_disclaimer.py @@ -0,0 +1,45 @@ +"""Custom Disclaimer + +Revision ID: 5fda94355fce +Revises: 47cc7df8c4f3 +Create Date: 2024-05-10 20:04:45.806549 + +""" +import sqlalchemy as sa +from alembic import op + +import models as models + +# revision identifiers, used by Alembic. +revision = '5fda94355fce' +down_revision = '47cc7df8c4f3' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('recommended_apps', schema=None) as batch_op: + batch_op.add_column(sa.Column('custom_disclaimer', sa.String(length=255), nullable=True)) + + with op.batch_alter_table('sites', schema=None) as batch_op: + batch_op.add_column(sa.Column('custom_disclaimer', sa.String(length=255), nullable=True)) + + with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: + batch_op.add_column(sa.Column('custom_disclaimer', sa.String(length=255), nullable=True)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: + batch_op.drop_column('custom_disclaimer') + + with op.batch_alter_table('sites', schema=None) as batch_op: + batch_op.drop_column('custom_disclaimer') + + with op.batch_alter_table('recommended_apps', schema=None) as batch_op: + batch_op.drop_column('custom_disclaimer') + + # ### end Alembic commands ### diff --git a/api/migrations/versions/64a70a7aab8b_add_workflow_run_index.py b/api/migrations/versions/64a70a7aab8b_add_workflow_run_index.py new file mode 100644 index 0000000000..73242908f4 --- /dev/null +++ b/api/migrations/versions/64a70a7aab8b_add_workflow_run_index.py @@ -0,0 +1,32 @@ +"""add workflow run index + +Revision ID: 64a70a7aab8b +Revises: 03f98355ba0e +Create Date: 2024-05-28 12:32:00.276061 + +""" +from alembic import op + +import models as models + +# revision identifiers, used by Alembic. +revision = '64a70a7aab8b' +down_revision = '03f98355ba0e' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('workflow_runs', schema=None) as batch_op: + batch_op.create_index('workflow_run_tenant_app_sequence_idx', ['tenant_id', 'app_id', 'sequence_number'], unique=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('workflow_runs', schema=None) as batch_op: + batch_op.drop_index('workflow_run_tenant_app_sequence_idx') + + # ### end Alembic commands ### diff --git a/api/migrations/versions/7bdef072e63a_add_workflow_tool.py b/api/migrations/versions/7bdef072e63a_add_workflow_tool.py new file mode 100644 index 0000000000..67b61e5c76 --- /dev/null +++ b/api/migrations/versions/7bdef072e63a_add_workflow_tool.py @@ -0,0 +1,42 @@ +"""add workflow tool + +Revision ID: 7bdef072e63a +Revises: 5fda94355fce +Create Date: 2024-05-04 09:47:19.366961 + +""" +import sqlalchemy as sa +from alembic import op + +import models as models + +# revision identifiers, used by Alembic. +revision = '7bdef072e63a' +down_revision = '5fda94355fce' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table('tool_workflow_providers', + sa.Column('id', models.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('name', sa.String(length=40), nullable=False), + sa.Column('icon', sa.String(length=255), nullable=False), + sa.Column('app_id', models.StringUUID(), nullable=False), + sa.Column('user_id', models.StringUUID(), nullable=False), + sa.Column('tenant_id', models.StringUUID(), nullable=False), + sa.Column('description', sa.Text(), nullable=False), + sa.Column('parameter_configuration', sa.Text(), server_default='[]', nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='tool_workflow_provider_pkey'), + sa.UniqueConstraint('name', 'tenant_id', name='unique_workflow_tool_provider'), + sa.UniqueConstraint('tenant_id', 'app_id', name='unique_workflow_tool_provider_app_id') + ) + # ### end Alembic commands ### + + +def downgrade(): + op.drop_table('tool_workflow_providers') + # ### end Alembic commands ### diff --git a/api/migrations/versions/9e98fbaffb88_add_workflow_tool_version.py b/api/migrations/versions/9e98fbaffb88_add_workflow_tool_version.py new file mode 100644 index 0000000000..bfda7d619c --- /dev/null +++ b/api/migrations/versions/9e98fbaffb88_add_workflow_tool_version.py @@ -0,0 +1,26 @@ +"""add workflow tool version + +Revision ID: 9e98fbaffb88 +Revises: 3b18fea55204 +Create Date: 2024-05-21 10:25:40.434162 + +""" +import sqlalchemy as sa +from alembic import op + +import models as models + +# revision identifiers, used by Alembic. +revision = '9e98fbaffb88' +down_revision = '3b18fea55204' +branch_labels = None +depends_on = None + + +def upgrade(): + with op.batch_alter_table('tool_workflow_providers', schema=None) as batch_op: + batch_op.add_column(sa.Column('version', sa.String(length=255), server_default='', nullable=False)) + +def downgrade(): + with op.batch_alter_table('tool_workflow_providers', schema=None) as batch_op: + batch_op.drop_column('version') diff --git a/api/models/dataset.py b/api/models/dataset.py index 01b068fa2a..7f98bbde15 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -1,8 +1,15 @@ +import base64 +import hashlib +import hmac import json import logging +import os import pickle +import re +import time from json import JSONDecodeError +from flask import current_app from sqlalchemy import func from sqlalchemy.dialects.postgresql import JSONB @@ -69,7 +76,8 @@ class Dataset(db.Model): @property def app_count(self): - return db.session.query(func.count(AppDatasetJoin.id)).filter(AppDatasetJoin.dataset_id == self.id).scalar() + return db.session.query(func.count(AppDatasetJoin.id)).filter(AppDatasetJoin.dataset_id == self.id, + App.id == AppDatasetJoin.app_id).scalar() @property def document_count(self): @@ -414,6 +422,26 @@ class DocumentSegment(db.Model): DocumentSegment.position == self.position + 1 ).first() + def get_sign_content(self): + pattern = r"/files/([a-f0-9\-]+)/image-preview" + text = self.content + match = re.search(pattern, text) + + if match: + upload_file_id = match.group(1) + nonce = os.urandom(16).hex() + timestamp = str(int(time.time())) + data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}" + secret_key = current_app.config['SECRET_KEY'].encode() + sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() + encoded_sign = 
base64.urlsafe_b64encode(sign).decode() + + params = f"timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}" + replacement = r"\g<0>?{params}".format(params=params) + text = re.sub(pattern, replacement, text) + return text + + class AppDatasetJoin(db.Model): __tablename__ = 'app_dataset_joins' diff --git a/api/models/model.py b/api/models/model.py index 59b88eb3b1..657db5a5c2 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -100,7 +100,7 @@ class App(db.Model): return None @property - def workflow(self): + def workflow(self) -> Optional['Workflow']: if self.workflow_id: from .workflow import Workflow return db.session.query(Workflow).filter(Workflow.id == self.workflow_id).first() @@ -435,6 +435,7 @@ class RecommendedApp(db.Model): description = db.Column(db.JSON, nullable=False) copyright = db.Column(db.String(255), nullable=False) privacy_policy = db.Column(db.String(255), nullable=False) + custom_disclaimer = db.Column(db.String(255), nullable=True) category = db.Column(db.String(255), nullable=False) position = db.Column(db.Integer, nullable=False, default=0) is_listed = db.Column(db.Boolean, nullable=False, default=True) @@ -796,7 +797,7 @@ class Message(db.Model): if message_file.transfer_method == 'local_file': upload_file = (db.session.query(UploadFile) .filter( - UploadFile.id == message_file.related_id + UploadFile.id == message_file.upload_file_id ).first()) url = UploadFileParser.get_image_data( @@ -1042,6 +1043,7 @@ class Site(db.Model): default_language = db.Column(db.String(255), nullable=False) copyright = db.Column(db.String(255)) privacy_policy = db.Column(db.String(255)) + custom_disclaimer = db.Column(db.String(255), nullable=True) customize_domain = db.Column(db.String(255)) customize_token_strategy = db.Column(db.String(255), nullable=False) prompt_public = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) diff --git a/api/models/provider.py b/api/models/provider.py index 413e8f9d67..4c14c33f09 100644 --- a/api/models/provider.py +++ b/api/models/provider.py @@ -47,7 +47,7 @@ class Provider(db.Model): id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) tenant_id = db.Column(StringUUID, nullable=False) - provider_name = db.Column(db.String(40), nullable=False) + provider_name = db.Column(db.String(255), nullable=False) provider_type = db.Column(db.String(40), nullable=False, server_default=db.text("'custom'::character varying")) encrypted_config = db.Column(db.Text, nullable=True) is_valid = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) @@ -94,7 +94,7 @@ class ProviderModel(db.Model): id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) tenant_id = db.Column(StringUUID, nullable=False) - provider_name = db.Column(db.String(40), nullable=False) + provider_name = db.Column(db.String(255), nullable=False) model_name = db.Column(db.String(255), nullable=False) model_type = db.Column(db.String(40), nullable=False) encrypted_config = db.Column(db.Text, nullable=True) @@ -112,8 +112,8 @@ class TenantDefaultModel(db.Model): id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) tenant_id = db.Column(StringUUID, nullable=False) - provider_name = db.Column(db.String(40), nullable=False) - model_name = db.Column(db.String(40), nullable=False) + provider_name = db.Column(db.String(255), nullable=False) + model_name = db.Column(db.String(255), nullable=False) model_type = db.Column(db.String(40), nullable=False) created_at = db.Column(db.DateTime, nullable=False, 
server_default=db.text('CURRENT_TIMESTAMP(0)')) updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) @@ -128,7 +128,7 @@ class TenantPreferredModelProvider(db.Model): id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) tenant_id = db.Column(StringUUID, nullable=False) - provider_name = db.Column(db.String(40), nullable=False) + provider_name = db.Column(db.String(255), nullable=False) preferred_provider_type = db.Column(db.String(40), nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) @@ -143,7 +143,7 @@ class ProviderOrder(db.Model): id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) tenant_id = db.Column(StringUUID, nullable=False) - provider_name = db.Column(db.String(40), nullable=False) + provider_name = db.Column(db.String(255), nullable=False) account_id = db.Column(StringUUID, nullable=False) payment_product_id = db.Column(db.String(191), nullable=False) payment_id = db.Column(db.String(191)) @@ -157,3 +157,46 @@ class ProviderOrder(db.Model): refunded_at = db.Column(db.DateTime) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + + +class ProviderModelSetting(db.Model): + """ + Provider model settings that record each model's enabled status and load balancing status. + """ + __tablename__ = 'provider_model_settings' + __table_args__ = ( + db.PrimaryKeyConstraint('id', name='provider_model_setting_pkey'), + db.Index('provider_model_setting_tenant_provider_model_idx', 'tenant_id', 'provider_name', 'model_type'), + ) + + id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + tenant_id = db.Column(StringUUID, nullable=False) + provider_name = db.Column(db.String(255), nullable=False) + model_name = db.Column(db.String(255), nullable=False) + model_type = db.Column(db.String(40), nullable=False) + enabled = db.Column(db.Boolean, nullable=False, server_default=db.text('true')) + load_balancing_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + + +class LoadBalancingModelConfig(db.Model): + """ + Configurations for load balancing models.
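+    Each row stores one named credential entry for a model; the entry named '__inherit__' reuses the provider's own custom credentials (see ModelLoadBalancingService below).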
+ """ + __tablename__ = 'load_balancing_model_configs' + __table_args__ = ( + db.PrimaryKeyConstraint('id', name='load_balancing_model_config_pkey'), + db.Index('load_balancing_model_config_tenant_provider_model_idx', 'tenant_id', 'provider_name', 'model_type'), + ) + + id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + tenant_id = db.Column(StringUUID, nullable=False) + provider_name = db.Column(db.String(255), nullable=False) + model_name = db.Column(db.String(255), nullable=False) + model_type = db.Column(db.String(40), nullable=False) + name = db.Column(db.String(255), nullable=False) + encrypted_config = db.Column(db.Text, nullable=True) + enabled = db.Column(db.Boolean, nullable=False, server_default=db.text('true')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) diff --git a/api/models/tools.py b/api/models/tools.py index 8a133679e0..49212916ec 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -3,8 +3,8 @@ import json from sqlalchemy import ForeignKey from core.tools.entities.common_entities import I18nObject -from core.tools.entities.tool_bundle import ApiBasedToolBundle -from core.tools.entities.tool_entities import ApiProviderSchemaType +from core.tools.entities.tool_bundle import ApiToolBundle +from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration from extensions.ext_database import db from models import StringUUID from models.model import Account, App, Tenant @@ -107,6 +107,8 @@ class ApiToolProvider(db.Model): credentials_str = db.Column(db.Text, nullable=False) # privacy policy privacy_policy = db.Column(db.String(255), nullable=True) + # custom_disclaimer + custom_disclaimer = db.Column(db.String(255), nullable=True) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) @@ -116,8 +118,8 @@ class ApiToolProvider(db.Model): return ApiProviderSchemaType.value_of(self.schema_type_str) @property - def tools(self) -> list[ApiBasedToolBundle]: - return [ApiBasedToolBundle(**tool) for tool in json.loads(self.tools_str)] + def tools(self) -> list[ApiToolBundle]: + return [ApiToolBundle(**tool) for tool in json.loads(self.tools_str)] @property def credentials(self) -> dict: @@ -130,7 +132,84 @@ class ApiToolProvider(db.Model): @property def tenant(self) -> Tenant: return db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first() + +class ToolLabelBinding(db.Model): + """ + The table stores the labels for tools. + """ + __tablename__ = 'tool_label_bindings' + __table_args__ = ( + db.PrimaryKeyConstraint('id', name='tool_label_bind_pkey'), + db.UniqueConstraint('tool_id', 'label_name', name='unique_tool_label_bind'), + ) + + id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + # tool id + tool_id = db.Column(db.String(64), nullable=False) + # tool type + tool_type = db.Column(db.String(40), nullable=False) + # label name + label_name = db.Column(db.String(40), nullable=False) + +class WorkflowToolProvider(db.Model): + """ + The table stores the workflow providers. 
+ """ + __tablename__ = 'tool_workflow_providers' + __table_args__ = ( + db.PrimaryKeyConstraint('id', name='tool_workflow_provider_pkey'), + db.UniqueConstraint('name', 'tenant_id', name='unique_workflow_tool_provider'), + db.UniqueConstraint('tenant_id', 'app_id', name='unique_workflow_tool_provider_app_id'), + ) + + id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + # name of the workflow provider + name = db.Column(db.String(40), nullable=False) + # label of the workflow provider + label = db.Column(db.String(255), nullable=False, server_default='') + # icon + icon = db.Column(db.String(255), nullable=False) + # app id of the workflow provider + app_id = db.Column(StringUUID, nullable=False) + # version of the workflow provider + version = db.Column(db.String(255), nullable=False, server_default='') + # who created this tool + user_id = db.Column(StringUUID, nullable=False) + # tenant id + tenant_id = db.Column(StringUUID, nullable=False) + # description of the provider + description = db.Column(db.Text, nullable=False) + # parameter configuration + parameter_configuration = db.Column(db.Text, nullable=False, server_default='[]') + # privacy policy + privacy_policy = db.Column(db.String(255), nullable=True, server_default='') + + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + + @property + def schema_type(self) -> ApiProviderSchemaType: + return ApiProviderSchemaType.value_of(self.schema_type_str) + @property + def user(self) -> Account: + return db.session.query(Account).filter(Account.id == self.user_id).first() + + @property + def tenant(self) -> Tenant: + return db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first() + + @property + def parameter_configurations(self) -> list[WorkflowToolParameterConfiguration]: + return [ + WorkflowToolParameterConfiguration(**config) + for config in json.loads(self.parameter_configuration) + ] + + @property + def app(self) -> App: + return db.session.query(App).filter(App.id == self.app_id).first() + class ToolModelInvoke(db.Model): """ store the invoke logs from tool invoke diff --git a/api/models/workflow.py b/api/models/workflow.py index f261c67c77..d9bc784878 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -2,8 +2,8 @@ import json from enum import Enum from typing import Optional, Union -from core.tools.tool_manager import ToolManager from extensions.ext_database import db +from libs import helper from models import StringUUID from models.account import Account @@ -156,6 +156,27 @@ class Workflow(db.Model): return variables + @property + def unique_hash(self) -> str: + """ + Get hash of workflow. 
+ + :return: hash + """ + entity = { + 'graph': self.graph_dict, + 'features': self.features_dict + } + + return helper.generate_text_hash(json.dumps(entity, sort_keys=True)) + + @property + def tool_published(self) -> bool: + from models.tools import WorkflowToolProvider + return db.session.query(WorkflowToolProvider).filter( + WorkflowToolProvider.app_id == self.app_id + ).first() is not None + class WorkflowRunTriggeredFrom(Enum): """ Workflow Run Triggered From Enum @@ -242,6 +263,7 @@ class WorkflowRun(db.Model): __table_args__ = ( db.PrimaryKeyConstraint('id', name='workflow_run_pkey'), db.Index('workflow_run_triggerd_from_idx', 'tenant_id', 'app_id', 'triggered_from'), + db.Index('workflow_run_tenant_app_sequence_idx', 'tenant_id', 'app_id', 'sequence_number'), ) id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) @@ -457,6 +479,7 @@ class WorkflowNodeExecution(db.Model): @property def extras(self): + from core.tools.tool_manager import ToolManager extras = {} if self.execution_metadata_dict: from core.workflow.entities.node_entities import NodeType diff --git a/api/pyproject.toml b/api/pyproject.toml index 1e9f53cdb3..a8920139c6 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -7,12 +7,24 @@ exclude = [ line-length = 120 [tool.ruff.lint] -ignore-init-module-imports = true select = [ + "B", # flake8-bugbear rules "F", # pyflakes rules - "I001", # unsorted-imports - "I002", # missing-required-import + "I", # isort rules "UP", # pyupgrade rules + "E101", # mixed-spaces-and-tabs + "E111", # indentation-with-invalid-multiple + "E112", # no-indented-block + "E113", # unexpected-indentation + "E115", # no-indented-block-comment + "E116", # unexpected-indentation-comment + "E117", # over-indented + "RUF019", # unnecessary-key-check + "RUF100", # unused-noqa + "RUF101", # redirected-noqa + "S506", # unsafe-yaml-load + "W191", # tab-indentation + "W605", # invalid-escape-sequence ] ignore = [ "F403", # undefined-local-with-import-star @@ -21,6 +33,13 @@ ignore = [ "F841", # unused-variable "UP007", # non-pep604-annotation "UP032", # f-string + "B005", # strip-with-multi-characters + "B006", # mutable-argument-default + "B007", # unused-loop-control-variable + "B026", # star-arg-unpacking-after-keyword-arg + "B901", # return-in-generator + "B904", # raise-without-from-inside-except + "B905", # zip-without-explicit-strict ] [tool.ruff.lint.per-file-ignores] diff --git a/api/requirements.txt b/api/requirements.txt index 9d79afa4ec..1749b4a2df 100644 --- a/api/requirements.txt +++ b/api/requirements.txt @@ -9,8 +9,8 @@ flask-restful~=0.3.10 flask-cors~=4.0.0 gunicorn~=22.0.0 gevent~=23.9.1 -openai~=1.13.3 -tiktoken~=0.6.0 +openai~=1.29.0 +tiktoken~=0.7.0 psycopg2-binary~=2.9.6 pycryptodome==3.19.1 python-dotenv==1.0.0 @@ -26,7 +26,6 @@ sympy==1.12 jieba==0.42.1 celery~=5.3.6 redis[hiredis]~=5.0.3 -openpyxl==3.1.2 chardet~=5.1.0 python-docx~=1.1.0 pypdfium2~=4.17.0 @@ -42,7 +41,6 @@ google-api-python-client==2.90.0 google-auth==2.29.0 google-auth-httplib2==0.2.0 google-generativeai==0.5.0 -google-search-results==2.4.2 googleapis-common-protos==1.63.0 google-cloud-storage==2.16.0 replicate~=0.22.0 @@ -51,7 +49,7 @@ dashscope[tokenizer]~=1.17.0 huggingface_hub~=0.16.4 transformers~=4.35.0 tokenizers~=0.15.0 -pandas==1.5.3 +pandas[performance,excel]~=2.2.2 xinference-client==0.9.4 safetensors~=0.4.3 zhipuai==1.0.7 @@ -66,20 +64,24 @@ bs4~=0.0.1 markdown~=3.5.1 httpx[socks]~=0.24.1 matplotlib~=3.8.2 -yfinance~=0.2.35 +yfinance~=0.2.40 pydub~=0.25.1 gmpy2~=2.1.5 
numexpr~=2.9.0 -duckduckgo-search==5.2.2 +duckduckgo-search~=6.1.5 arxiv==2.1.0 yarl~=1.9.4 twilio~=9.0.4 qrcode~=7.4.2 -azure-storage-blob==12.9.0 +azure-storage-blob==12.13.0 azure-identity==1.15.0 lxml==5.1.0 -xlrd~=2.0.1 pydantic~=1.10.0 pgvecto-rs==0.1.4 firecrawl-py==0.0.5 -oss2==2.15.0 +oss2==2.18.5 +pgvector==0.2.5 +pymysql==1.1.1 +tidb-vector==0.0.9 +google-cloud-aiplatform==1.49.0 +vanna[postgres,mysql,clickhouse,duckdb]==0.5.5 diff --git a/api/services/annotation_service.py b/api/services/annotation_service.py index 8b00b28c4f..addcde44ed 100644 --- a/api/services/annotation_service.py +++ b/api/services/annotation_service.py @@ -31,7 +31,7 @@ class AppAnnotationService: if not app: raise NotFound("App not found") - if 'message_id' in args and args['message_id']: + if args.get('message_id'): message_id = str(args['message_id']) # get message info message = db.session.query(Message).filter( diff --git a/api/services/app_generate_service.py b/api/services/app_generate_service.py index 185d9ba89f..f73a6dcbb6 100644 --- a/api/services/app_generate_service.py +++ b/api/services/app_generate_service.py @@ -75,6 +75,35 @@ class AppGenerateService: else: raise ValueError(f'Invalid app mode {app_model.mode}') + @classmethod + def generate_single_iteration(cls, app_model: App, + user: Union[Account, EndUser], + node_id: str, + args: Any, + streaming: bool = True): + if app_model.mode == AppMode.ADVANCED_CHAT.value: + workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER) + return AdvancedChatAppGenerator().single_iteration_generate( + app_model=app_model, + workflow=workflow, + node_id=node_id, + user=user, + args=args, + stream=streaming + ) + elif app_model.mode == AppMode.WORKFLOW.value: + workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER) + return WorkflowAppGenerator().single_iteration_generate( + app_model=app_model, + workflow=workflow, + node_id=node_id, + user=user, + args=args, + stream=streaming + ) + else: + raise ValueError(f'Invalid app mode {app_model.mode}') + @classmethod def generate_more_like_this(cls, app_model: App, user: Union[Account, EndUser], message_id: str, invoke_from: InvokeFrom, streaming: bool = True) \ diff --git a/api/services/app_service.py b/api/services/app_service.py index 11073af09e..23c00740c8 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -47,10 +47,10 @@ class AppService: elif args['mode'] == 'channel': filters.append(App.mode == AppMode.CHANNEL.value) - if 'name' in args and args['name']: + if args.get('name'): name = args['name'][:30] filters.append(App.name.ilike(f'%{name}%')) - if 'tag_ids' in args and args['tag_ids']: + if args.get('tag_ids'): target_ids = TagService.get_target_ids_by_tag_ids('app', tenant_id, args['tag_ids']) @@ -196,6 +196,7 @@ class AppService: app_model=app, graph=workflow.get('graph'), features=workflow.get('features'), + unique_hash=None, account=account ) workflow_service.publish_workflow( diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 17399c8ac8..06d3e9ec40 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -4,7 +4,7 @@ import logging import random import time import uuid -from typing import Optional, cast +from typing import Optional from flask import current_app from flask_login import current_user @@ -13,7 +13,6 @@ from sqlalchemy import func from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.model_manager import ModelManager from 
core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.models.document import Document as RAGDocument from events.dataset_event import dataset_was_deleted @@ -43,6 +42,7 @@ from services.vector_service import VectorService from tasks.clean_notion_document_task import clean_notion_document_task from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task from tasks.delete_segment_from_index_task import delete_segment_from_index_task +from tasks.disable_segment_from_index_task import disable_segment_from_index_task from tasks.document_indexing_task import document_indexing_task from tasks.document_indexing_update_task import document_indexing_update_task from tasks.duplicate_document_indexing_task import duplicate_document_indexing_task @@ -450,6 +450,27 @@ class DocumentService: db.session.delete(document) db.session.commit() + @staticmethod + def rename_document(dataset_id: str, document_id: str, name: str) -> Document: + dataset = DatasetService.get_dataset(dataset_id) + if not dataset: + raise ValueError('Dataset not found.') + + document = DocumentService.get_document(dataset_id, document_id) + + if not document: + raise ValueError('Document not found.') + + if document.tenant_id != current_user.current_tenant_id: + raise ValueError('No permission.') + + document.name = name + + db.session.add(document) + db.session.commit() + + return document + @staticmethod def pause_document(document): if document.indexing_status not in ["waiting", "parsing", "cleaning", "splitting", "indexing"]: @@ -568,7 +589,7 @@ class DocumentService: documents = [] batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999)) - if 'original_document_id' in document_data and document_data["original_document_id"]: + if document_data.get("original_document_id"): document = DocumentService.update_document_with_dataset_id(dataset, document_data, account) documents.append(document) else: @@ -749,10 +770,10 @@ class DocumentService: if document.display_status != 'available': raise ValueError("Document is not available") # update document name - if 'name' in document_data and document_data['name']: + if document_data.get('name'): document.name = document_data['name'] # save process rule - if 'process_rule' in document_data and document_data['process_rule']: + if document_data.get('process_rule'): process_rule = document_data["process_rule"] if process_rule["mode"] == "custom": dataset_process_rule = DatasetProcessRule( @@ -772,7 +793,7 @@ class DocumentService: db.session.commit() document.dataset_process_rule_id = dataset_process_rule.id # update document data source - if 'data_source' in document_data and document_data['data_source']: + if document_data.get('data_source'): file_name = '' data_source_info = {} if document_data["data_source"]["type"] == "upload_file": @@ -870,7 +891,7 @@ class DocumentService: embedding_model.model ) dataset_collection_binding_id = dataset_collection_binding.id - if 'retrieval_model' in document_data and document_data['retrieval_model']: + if document_data.get('retrieval_model'): retrieval_model = document_data['retrieval_model'] else: default_retrieval_model = { @@ -920,9 +941,9 @@ class DocumentService: and ('process_rule' not in args and not args['process_rule']): raise ValueError("Data source or Process rule is required") else: - if 'data_source' in args and 
args['data_source']: + if args.get('data_source'): DocumentService.data_source_args_validate(args) - if 'process_rule' in args and args['process_rule']: + if args.get('process_rule'): DocumentService.process_rule_args_validate(args) @classmethod @@ -1122,10 +1143,7 @@ class SegmentService: model=dataset.embedding_model ) # calc embedding use tokens - model_type_instance = cast(TextEmbeddingModel, embedding_model.model_type_instance) - tokens = model_type_instance.get_num_tokens( - model=embedding_model.model, - credentials=embedding_model.credentials, + tokens = embedding_model.get_text_embedding_num_tokens( texts=[content] ) lock_name = 'add_segment_lock_document_id_{}'.format(document.id) @@ -1193,10 +1211,7 @@ class SegmentService: tokens = 0 if dataset.indexing_technique == 'high_quality' and embedding_model: # calc embedding use tokens - model_type_instance = cast(TextEmbeddingModel, embedding_model.model_type_instance) - tokens = model_type_instance.get_num_tokens( - model=embedding_model.model, - credentials=embedding_model.credentials, + tokens = embedding_model.get_text_embedding_num_tokens( texts=[content] ) segment_document = DocumentSegment( @@ -1241,15 +1256,35 @@ class SegmentService: cache_result = redis_client.get(indexing_cache_key) if cache_result is not None: raise ValueError("Segment is indexing, please try again later") + if 'enabled' in args and args['enabled'] is not None: + action = args['enabled'] + if segment.enabled != action: + if not action: + segment.enabled = action + segment.disabled_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + segment.disabled_by = current_user.id + db.session.add(segment) + db.session.commit() + # Set cache to prevent indexing the same segment multiple times + redis_client.setex(indexing_cache_key, 600, 1) + disable_segment_from_index_task.delay(segment.id) + return segment + if not segment.enabled: + if 'enabled' in args and args['enabled'] is not None: + if not args['enabled']: + raise ValueError("Can't update disabled segment") + else: + raise ValueError("Can't update disabled segment") try: content = args['content'] if segment.content == content: if document.doc_form == 'qa_model': segment.answer = args['answer'] - if 'keywords' in args and args['keywords']: + if args.get('keywords'): segment.keywords = args['keywords'] - if 'enabled' in args and args['enabled'] is not None: - segment.enabled = args['enabled'] + segment.enabled = True + segment.disabled_at = None + segment.disabled_by = None db.session.add(segment) db.session.commit() # update segment index task @@ -1279,10 +1314,7 @@ class SegmentService: ) # calc embedding use tokens - model_type_instance = cast(TextEmbeddingModel, embedding_model.model_type_instance) - tokens = model_type_instance.get_num_tokens( - model=embedding_model.model, - credentials=embedding_model.credentials, + tokens = embedding_model.get_text_embedding_num_tokens( texts=[content] ) segment.content = content @@ -1294,12 +1326,16 @@ class SegmentService: segment.completed_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) segment.updated_by = current_user.id segment.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + segment.enabled = True + segment.disabled_at = None + segment.disabled_by = None if document.doc_form == 'qa_model': segment.answer = args['answer'] db.session.add(segment) db.session.commit() # update segment vector index VectorService.update_segment_vector(args['keywords'], segment, dataset) + except Exception as e: 
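+            # if re-indexing fails, log it and mark the segment disabled so the failure is visible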
logging.exception("update segment index failed") segment.enabled = False diff --git a/api/services/enterprise/enterprise_feature_service.py b/api/services/enterprise/enterprise_feature_service.py deleted file mode 100644 index fe33349aa8..0000000000 --- a/api/services/enterprise/enterprise_feature_service.py +++ /dev/null @@ -1,28 +0,0 @@ -from flask import current_app -from pydantic import BaseModel - -from services.enterprise.enterprise_service import EnterpriseService - - -class EnterpriseFeatureModel(BaseModel): - sso_enforced_for_signin: bool = False - sso_enforced_for_signin_protocol: str = '' - - -class EnterpriseFeatureService: - - @classmethod - def get_enterprise_features(cls) -> EnterpriseFeatureModel: - features = EnterpriseFeatureModel() - - if current_app.config['ENTERPRISE_ENABLED']: - cls._fulfill_params_from_enterprise(features) - - return features - - @classmethod - def _fulfill_params_from_enterprise(cls, features): - enterprise_info = EnterpriseService.get_info() - - features.sso_enforced_for_signin = enterprise_info['sso_enforced_for_signin'] - features.sso_enforced_for_signin_protocol = enterprise_info['sso_enforced_for_signin_protocol'] diff --git a/api/services/enterprise/enterprise_sso_service.py b/api/services/enterprise/enterprise_sso_service.py deleted file mode 100644 index d8e19f23bf..0000000000 --- a/api/services/enterprise/enterprise_sso_service.py +++ /dev/null @@ -1,60 +0,0 @@ -import logging - -from models.account import Account, AccountStatus -from services.account_service import AccountService, TenantService -from services.enterprise.base import EnterpriseRequest - -logger = logging.getLogger(__name__) - - -class EnterpriseSSOService: - - @classmethod - def get_sso_saml_login(cls) -> str: - return EnterpriseRequest.send_request('GET', '/sso/saml/login') - - @classmethod - def post_sso_saml_acs(cls, saml_response: str) -> str: - response = EnterpriseRequest.send_request('POST', '/sso/saml/acs', json={'SAMLResponse': saml_response}) - if 'email' not in response or response['email'] is None: - logger.exception(response) - raise Exception('Saml response is invalid') - - return cls.login_with_email(response.get('email')) - - @classmethod - def get_sso_oidc_login(cls): - return EnterpriseRequest.send_request('GET', '/sso/oidc/login') - - @classmethod - def get_sso_oidc_callback(cls, args: dict): - state_from_query = args['state'] - code_from_query = args['code'] - state_from_cookies = args['oidc-state'] - - if state_from_cookies != state_from_query: - raise Exception('invalid state or code') - - response = EnterpriseRequest.send_request('GET', '/sso/oidc/callback', params={'code': code_from_query}) - if 'email' not in response or response['email'] is None: - logger.exception(response) - raise Exception('OIDC response is invalid') - - return cls.login_with_email(response.get('email')) - - @classmethod - def login_with_email(cls, email: str) -> str: - account = Account.query.filter_by(email=email).first() - if account is None: - raise Exception('account not found, please contact system admin to invite you to join in a workspace') - - if account.status == AccountStatus.BANNED: - raise Exception('account is banned, please contact system admin') - - tenants = TenantService.get_join_tenants(account) - if len(tenants) == 0: - raise Exception("workspace not found, please contact system admin to invite you to join in a workspace") - - token = AccountService.get_account_jwt_token(account) - - return token diff --git a/api/services/entities/model_provider_entities.py 
b/api/services/entities/model_provider_entities.py index 6cdd5090ae..77bb5e08c3 100644 --- a/api/services/entities/model_provider_entities.py +++ b/api/services/entities/model_provider_entities.py @@ -4,10 +4,10 @@ from typing import Optional from flask import current_app from pydantic import BaseModel -from core.entities.model_entities import ModelStatus, ModelWithProviderEntity +from core.entities.model_entities import ModelWithProviderEntity, ProviderModelWithStatusEntity from core.entities.provider_entities import QuotaConfiguration from core.model_runtime.entities.common_entities import I18nObject -from core.model_runtime.entities.model_entities import ModelType, ProviderModel +from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.provider_entities import ( ConfigurateMethod, ModelCredentialSchema, @@ -79,13 +79,6 @@ class ProviderResponse(BaseModel): ) -class ModelResponse(ProviderModel): - """ - Model class for model response. - """ - status: ModelStatus - - class ProviderWithModelsResponse(BaseModel): """ Model class for provider with models response. @@ -95,7 +88,7 @@ class ProviderWithModelsResponse(BaseModel): icon_small: Optional[I18nObject] = None icon_large: Optional[I18nObject] = None status: CustomConfigurationStatus - models: list[ModelResponse] + models: list[ProviderModelWithStatusEntity] def __init__(self, **data) -> None: super().__init__(**data) diff --git a/api/services/errors/app.py b/api/services/errors/app.py index 7c4ca99c2a..87e9e9247d 100644 --- a/api/services/errors/app.py +++ b/api/services/errors/app.py @@ -1,2 +1,6 @@ class MoreLikeThisDisabledError(Exception): pass + + +class WorkflowHashNotEqualError(Exception): + pass diff --git a/api/services/feature_service.py b/api/services/feature_service.py index 3cf51d11a0..36cbc3902b 100644 --- a/api/services/feature_service.py +++ b/api/services/feature_service.py @@ -2,6 +2,7 @@ from flask import current_app from pydantic import BaseModel from services.billing_service import BillingService +from services.enterprise.enterprise_service import EnterpriseService class SubscriptionModel(BaseModel): @@ -28,6 +29,14 @@ class FeatureModel(BaseModel): documents_upload_quota: LimitationModel = LimitationModel(size=0, limit=50) docs_processing: str = 'standard' can_replace_logo: bool = False + model_load_balancing_enabled: bool = False + + +class SystemFeatureModel(BaseModel): + sso_enforced_for_signin: bool = False + sso_enforced_for_signin_protocol: str = '' + sso_enforced_for_web: bool = False + sso_enforced_for_web_protocol: str = '' class FeatureService: @@ -43,9 +52,19 @@ class FeatureService: return features + @classmethod + def get_system_features(cls) -> SystemFeatureModel: + system_features = SystemFeatureModel() + + if current_app.config['ENTERPRISE_ENABLED']: + cls._fulfill_params_from_enterprise(system_features) + + return system_features + @classmethod def _fulfill_params_from_env(cls, features: FeatureModel): features.can_replace_logo = current_app.config['CAN_REPLACE_LOGO'] + features.model_load_balancing_enabled = current_app.config['MODEL_LB_ENABLED'] @classmethod def _fulfill_params_from_billing_api(cls, features: FeatureModel, tenant_id: str): @@ -55,21 +74,40 @@ class FeatureService: features.billing.subscription.plan = billing_info['subscription']['plan'] features.billing.subscription.interval = billing_info['subscription']['interval'] - features.members.size = billing_info['members']['size'] - features.members.limit = billing_info['members']['limit'] + if 
'members' in billing_info: + features.members.size = billing_info['members']['size'] + features.members.limit = billing_info['members']['limit'] - features.apps.size = billing_info['apps']['size'] - features.apps.limit = billing_info['apps']['limit'] + if 'apps' in billing_info: + features.apps.size = billing_info['apps']['size'] + features.apps.limit = billing_info['apps']['limit'] - features.vector_space.size = billing_info['vector_space']['size'] - features.vector_space.limit = billing_info['vector_space']['limit'] + if 'vector_space' in billing_info: + features.vector_space.size = billing_info['vector_space']['size'] + features.vector_space.limit = billing_info['vector_space']['limit'] - features.documents_upload_quota.size = billing_info['documents_upload_quota']['size'] - features.documents_upload_quota.limit = billing_info['documents_upload_quota']['limit'] + if 'documents_upload_quota' in billing_info: + features.documents_upload_quota.size = billing_info['documents_upload_quota']['size'] + features.documents_upload_quota.limit = billing_info['documents_upload_quota']['limit'] - features.annotation_quota_limit.size = billing_info['annotation_quota_limit']['size'] - features.annotation_quota_limit.limit = billing_info['annotation_quota_limit']['limit'] + if 'annotation_quota_limit' in billing_info: + features.annotation_quota_limit.size = billing_info['annotation_quota_limit']['size'] + features.annotation_quota_limit.limit = billing_info['annotation_quota_limit']['limit'] - features.docs_processing = billing_info['docs_processing'] - features.can_replace_logo = billing_info['can_replace_logo'] + if 'docs_processing' in billing_info: + features.docs_processing = billing_info['docs_processing'] + if 'can_replace_logo' in billing_info: + features.can_replace_logo = billing_info['can_replace_logo'] + + if 'model_load_balancing_enabled' in billing_info: + features.model_load_balancing_enabled = billing_info['model_load_balancing_enabled'] + + @classmethod + def _fulfill_params_from_enterprise(cls, features): + enterprise_info = EnterpriseService.get_info() + + features.sso_enforced_for_signin = enterprise_info['sso_enforced_for_signin'] + features.sso_enforced_for_signin_protocol = enterprise_info['sso_enforced_for_signin_protocol'] + features.sso_enforced_for_web = enterprise_info['sso_enforced_for_web'] + features.sso_enforced_for_web_protocol = enterprise_info['sso_enforced_for_web_protocol'] diff --git a/api/services/model_load_balancing_service.py b/api/services/model_load_balancing_service.py new file mode 100644 index 0000000000..c684c2862b --- /dev/null +++ b/api/services/model_load_balancing_service.py @@ -0,0 +1,565 @@ +import datetime +import json +import logging +from json import JSONDecodeError +from typing import Optional + +from core.entities.provider_configuration import ProviderConfiguration +from core.helper import encrypter +from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType +from core.model_manager import LBModelManager +from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.entities.provider_entities import ( + ModelCredentialSchema, + ProviderCredentialSchema, +) +from core.model_runtime.model_providers import model_provider_factory +from core.provider_manager import ProviderManager +from extensions.ext_database import db +from models.provider import LoadBalancingModelConfig + +logger = logging.getLogger(__name__) + + +class ModelLoadBalancingService: + + def __init__(self) -> None: 
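+        # ProviderManager resolves the per-workspace provider configurations used by the methods below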
+ self.provider_manager = ProviderManager() + + def enable_model_load_balancing(self, tenant_id: str, provider: str, model: str, model_type: str) -> None: + """ + enable model load balancing. + + :param tenant_id: workspace id + :param provider: provider name + :param model: model name + :param model_type: model type + :return: + """ + # Get all provider configurations of the current workspace + provider_configurations = self.provider_manager.get_configurations(tenant_id) + + # Get provider configuration + provider_configuration = provider_configurations.get(provider) + if not provider_configuration: + raise ValueError(f"Provider {provider} does not exist.") + + # Enable model load balancing + provider_configuration.enable_model_load_balancing( + model=model, + model_type=ModelType.value_of(model_type) + ) + + def disable_model_load_balancing(self, tenant_id: str, provider: str, model: str, model_type: str) -> None: + """ + disable model load balancing. + + :param tenant_id: workspace id + :param provider: provider name + :param model: model name + :param model_type: model type + :return: + """ + # Get all provider configurations of the current workspace + provider_configurations = self.provider_manager.get_configurations(tenant_id) + + # Get provider configuration + provider_configuration = provider_configurations.get(provider) + if not provider_configuration: + raise ValueError(f"Provider {provider} does not exist.") + + # Disable model load balancing + provider_configuration.disable_model_load_balancing( + model=model, + model_type=ModelType.value_of(model_type) + ) + + def get_load_balancing_configs(self, tenant_id: str, provider: str, model: str, model_type: str) \ + -> tuple[bool, list[dict]]: + """ + Get load balancing configurations. + :param tenant_id: workspace id + :param provider: provider name + :param model: model name + :param model_type: model type + :return: + """ + # Get all provider configurations of the current workspace + provider_configurations = self.provider_manager.get_configurations(tenant_id) + + # Get provider configuration + provider_configuration = provider_configurations.get(provider) + if not provider_configuration: + raise ValueError(f"Provider {provider} does not exist.") + + # Convert model type to ModelType + model_type = ModelType.value_of(model_type) + + # Get provider model setting + provider_model_setting = provider_configuration.get_provider_model_setting( + model_type=model_type, + model=model, + ) + + is_load_balancing_enabled = False + if provider_model_setting and provider_model_setting.load_balancing_enabled: + is_load_balancing_enabled = True + + # Get load balancing configurations + load_balancing_configs = db.session.query(LoadBalancingModelConfig) \ + .filter( + LoadBalancingModelConfig.tenant_id == tenant_id, + LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider, + LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), + LoadBalancingModelConfig.model_name == model + ).order_by(LoadBalancingModelConfig.created_at).all() + + if provider_configuration.custom_configuration.provider: + # check if the inherit configuration exists; + # the inherit configuration represents the provider's or model's own custom credentials + inherit_config_exists = False + for load_balancing_config in load_balancing_configs: + if load_balancing_config.name == '__inherit__': + inherit_config_exists = True + break + + if not inherit_config_exists: + # Initialize the inherit configuration + inherit_config =
self._init_inherit_config(tenant_id, provider, model, model_type) + + # prepend the inherit configuration + load_balancing_configs.insert(0, inherit_config) + else: + # move the inherit configuration to the first + for i, load_balancing_config in enumerate(load_balancing_configs): + if load_balancing_config.name == '__inherit__': + inherit_config = load_balancing_configs.pop(i) + load_balancing_configs.insert(0, inherit_config) + + # Get credential form schemas from model credential schema or provider credential schema + credential_schemas = self._get_credential_schema(provider_configuration) + + # Get decoding rsa key and cipher for decrypting credentials + decoding_rsa_key, decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id) + + # fetch status and ttl for each config + datas = [] + for load_balancing_config in load_balancing_configs: + in_cooldown, ttl = LBModelManager.get_config_in_cooldown_and_ttl( + tenant_id=tenant_id, + provider=provider, + model=model, + model_type=model_type, + config_id=load_balancing_config.id + ) + + try: + if load_balancing_config.encrypted_config: + credentials = json.loads(load_balancing_config.encrypted_config) + else: + credentials = {} + except JSONDecodeError: + credentials = {} + + # Get provider credential secret variables + credential_secret_variables = provider_configuration.extract_secret_variables( + credential_schemas.credential_form_schemas + ) + + # decrypt credentials + for variable in credential_secret_variables: + if variable in credentials: + try: + credentials[variable] = encrypter.decrypt_token_with_decoding( + credentials.get(variable), + decoding_rsa_key, + decoding_cipher_rsa + ) + except ValueError: + pass + + # Obfuscate credentials + credentials = provider_configuration.obfuscated_credentials( + credentials=credentials, + credential_form_schemas=credential_schemas.credential_form_schemas + ) + + datas.append({ + 'id': load_balancing_config.id, + 'name': load_balancing_config.name, + 'credentials': credentials, + 'enabled': load_balancing_config.enabled, + 'in_cooldown': in_cooldown, + 'ttl': ttl + }) + + return is_load_balancing_enabled, datas + + def get_load_balancing_config(self, tenant_id: str, provider: str, model: str, model_type: str, config_id: str) \ + -> Optional[dict]: + """ + Get load balancing configuration. 
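+ Returns None if no configuration with the given id exists for the model.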
+ :param tenant_id: workspace id + :param provider: provider name + :param model: model name + :param model_type: model type + :param config_id: load balancing config id + :return: + """ + # Get all provider configurations of the current workspace + provider_configurations = self.provider_manager.get_configurations(tenant_id) + + # Get provider configuration + provider_configuration = provider_configurations.get(provider) + if not provider_configuration: + raise ValueError(f"Provider {provider} does not exist.") + + # Convert model type to ModelType + model_type = ModelType.value_of(model_type) + + # Get load balancing configurations + load_balancing_model_config = db.session.query(LoadBalancingModelConfig) \ + .filter( + LoadBalancingModelConfig.tenant_id == tenant_id, + LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider, + LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), + LoadBalancingModelConfig.model_name == model, + LoadBalancingModelConfig.id == config_id + ).first() + + if not load_balancing_model_config: + return None + + try: + if load_balancing_model_config.encrypted_config: + credentials = json.loads(load_balancing_model_config.encrypted_config) + else: + credentials = {} + except JSONDecodeError: + credentials = {} + + # Get credential form schemas from model credential schema or provider credential schema + credential_schemas = self._get_credential_schema(provider_configuration) + + # Obfuscate credentials + credentials = provider_configuration.obfuscated_credentials( + credentials=credentials, + credential_form_schemas=credential_schemas.credential_form_schemas + ) + + return { + 'id': load_balancing_model_config.id, + 'name': load_balancing_model_config.name, + 'credentials': credentials, + 'enabled': load_balancing_model_config.enabled + } + + def _init_inherit_config(self, tenant_id: str, provider: str, model: str, model_type: ModelType) \ + -> LoadBalancingModelConfig: + """ + Initialize the inherit configuration. + :param tenant_id: workspace id + :param provider: provider name + :param model: model name + :param model_type: model type + :return: + """ + # Initialize the inherit configuration + inherit_config = LoadBalancingModelConfig( + tenant_id=tenant_id, + provider_name=provider, + model_type=model_type.to_origin_model_type(), + model_name=model, + name='__inherit__' + ) + db.session.add(inherit_config) + db.session.commit() + + return inherit_config + + def update_load_balancing_configs(self, tenant_id: str, + provider: str, + model: str, + model_type: str, + configs: list[dict]) -> None: + """ + Update load balancing configurations. 
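+ Existing configs omitted from the given list are deleted; entries without an id are created as new configs.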
+ :param tenant_id: workspace id + :param provider: provider name + :param model: model name + :param model_type: model type + :param configs: load balancing configs + :return: + """ + # Get all provider configurations of the current workspace + provider_configurations = self.provider_manager.get_configurations(tenant_id) + + # Get provider configuration + provider_configuration = provider_configurations.get(provider) + if not provider_configuration: + raise ValueError(f"Provider {provider} does not exist.") + + # Convert model type to ModelType + model_type = ModelType.value_of(model_type) + + if not isinstance(configs, list): + raise ValueError('Invalid load balancing configs') + + current_load_balancing_configs = db.session.query(LoadBalancingModelConfig) \ + .filter( + LoadBalancingModelConfig.tenant_id == tenant_id, + LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider, + LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), + LoadBalancingModelConfig.model_name == model + ).all() + + # id as key, config as value + current_load_balancing_configs_dict = {config.id: config for config in current_load_balancing_configs} + updated_config_ids = set() + + for config in configs: + if not isinstance(config, dict): + raise ValueError('Invalid load balancing config') + + config_id = config.get('id') + name = config.get('name') + credentials = config.get('credentials') + enabled = config.get('enabled') + + if not name: + raise ValueError('Invalid load balancing config name') + + if enabled is None: + raise ValueError('Invalid load balancing config enabled') + + # does the config already exist + if config_id: + config_id = str(config_id) + + if config_id not in current_load_balancing_configs_dict: + raise ValueError('Invalid load balancing config id: {}'.format(config_id)) + + updated_config_ids.add(config_id) + + load_balancing_config = current_load_balancing_configs_dict[config_id] + + # check duplicate name + for current_load_balancing_config in current_load_balancing_configs: + if current_load_balancing_config.id != config_id and current_load_balancing_config.name == name: + raise ValueError('Load balancing config name {} already exists'.format(name)) + + if credentials: + if not isinstance(credentials, dict): + raise ValueError('Invalid load balancing config credentials') + + # validate custom provider config + credentials = self._custom_credentials_validate( + tenant_id=tenant_id, + provider_configuration=provider_configuration, + model_type=model_type, + model=model, + credentials=credentials, + load_balancing_model_config=load_balancing_config, + validate=False + ) + + # update load balancing config + load_balancing_config.encrypted_config = json.dumps(credentials) + + load_balancing_config.name = name + load_balancing_config.enabled = enabled + load_balancing_config.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + db.session.commit() + + self._clear_credentials_cache(tenant_id, config_id) + else: + # create load balancing config + if name == '__inherit__': + raise ValueError('Invalid load balancing config name') + + # check duplicate name + for current_load_balancing_config in current_load_balancing_configs: + if current_load_balancing_config.name == name: + raise ValueError('Load balancing config name {} already exists'.format(name)) + + if not credentials: + raise ValueError('Invalid load balancing config credentials') + + if not isinstance(credentials, dict): + raise ValueError('Invalid load balancing config
credentials') + + # validate custom provider config + credentials = self._custom_credentials_validate( + tenant_id=tenant_id, + provider_configuration=provider_configuration, + model_type=model_type, + model=model, + credentials=credentials, + validate=False + ) + + # create load balancing config + load_balancing_model_config = LoadBalancingModelConfig( + tenant_id=tenant_id, + provider_name=provider_configuration.provider.provider, + model_type=model_type.to_origin_model_type(), + model_name=model, + name=name, + encrypted_config=json.dumps(credentials) + ) + + db.session.add(load_balancing_model_config) + db.session.commit() + + # get deleted config ids + deleted_config_ids = set(current_load_balancing_configs_dict.keys()) - updated_config_ids + for config_id in deleted_config_ids: + db.session.delete(current_load_balancing_configs_dict[config_id]) + db.session.commit() + + self._clear_credentials_cache(tenant_id, config_id) + + def validate_load_balancing_credentials(self, tenant_id: str, + provider: str, + model: str, + model_type: str, + credentials: dict, + config_id: Optional[str] = None) -> None: + """ + Validate load balancing credentials. + :param tenant_id: workspace id + :param provider: provider name + :param model_type: model type + :param model: model name + :param credentials: credentials + :param config_id: load balancing config id + :return: + """ + # Get all provider configurations of the current workspace + provider_configurations = self.provider_manager.get_configurations(tenant_id) + + # Get provider configuration + provider_configuration = provider_configurations.get(provider) + if not provider_configuration: + raise ValueError(f"Provider {provider} does not exist.") + + # Convert model type to ModelType + model_type = ModelType.value_of(model_type) + + load_balancing_model_config = None + if config_id: + # Get load balancing config + load_balancing_model_config = db.session.query(LoadBalancingModelConfig) \ + .filter( + LoadBalancingModelConfig.tenant_id == tenant_id, + LoadBalancingModelConfig.provider_name == provider, + LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), + LoadBalancingModelConfig.model_name == model, + LoadBalancingModelConfig.id == config_id + ).first() + + if not load_balancing_model_config: + raise ValueError(f"Load balancing config {config_id} does not exist.") + + # Validate custom provider config + self._custom_credentials_validate( + tenant_id=tenant_id, + provider_configuration=provider_configuration, + model_type=model_type, + model=model, + credentials=credentials, + load_balancing_model_config=load_balancing_model_config + ) + + def _custom_credentials_validate(self, tenant_id: str, + provider_configuration: ProviderConfiguration, + model_type: ModelType, + model: str, + credentials: dict, + load_balancing_model_config: Optional[LoadBalancingModelConfig] = None, + validate: bool = True) -> dict: + """ + Validate custom credentials. 
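+ When validate is False, secret values are only decrypted and re-encrypted; the provider validation call is skipped.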
+ :param tenant_id: workspace id + :param provider_configuration: provider configuration + :param model_type: model type + :param model: model name + :param credentials: credentials + :param load_balancing_model_config: load balancing model config + :param validate: validate credentials + :return: + """ + # Get credential form schemas from model credential schema or provider credential schema + credential_schemas = self._get_credential_schema(provider_configuration) + + # Get provider credential secret variables + provider_credential_secret_variables = provider_configuration.extract_secret_variables( + credential_schemas.credential_form_schemas + ) + + if load_balancing_model_config: + try: + # load the original credentials + if load_balancing_model_config.encrypted_config: + original_credentials = json.loads(load_balancing_model_config.encrypted_config) + else: + original_credentials = {} + except JSONDecodeError: + original_credentials = {} + + # restore secret inputs that were masked as [__HIDDEN__] + for key, value in credentials.items(): + if key in provider_credential_secret_variables: + # if [__HIDDEN__] was sent for a secret input, keep the original value + if value == '[__HIDDEN__]' and key in original_credentials: + credentials[key] = encrypter.decrypt_token(tenant_id, original_credentials[key]) + + if validate: + if isinstance(credential_schemas, ModelCredentialSchema): + credentials = model_provider_factory.model_credentials_validate( + provider=provider_configuration.provider.provider, + model_type=model_type, + model=model, + credentials=credentials + ) + else: + credentials = model_provider_factory.provider_credentials_validate( + provider=provider_configuration.provider.provider, + credentials=credentials + ) + + # encrypt secret variables before returning + for key, value in credentials.items(): + if key in provider_credential_secret_variables: + credentials[key] = encrypter.encrypt_token(tenant_id, value) + + return credentials + + def _get_credential_schema(self, provider_configuration: ProviderConfiguration) \ + -> ModelCredentialSchema | ProviderCredentialSchema: + """ + Get form schemas. + :param provider_configuration: provider configuration + :return: + """ + # Get credential form schemas from model credential schema or provider credential schema + if provider_configuration.provider.model_credential_schema: + credential_schema = provider_configuration.provider.model_credential_schema + else: + credential_schema = provider_configuration.provider.provider_credential_schema + + return credential_schema + + def _clear_credentials_cache(self, tenant_id: str, config_id: str) -> None: + """ + Clear credentials cache.
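+ Removes the cached decrypted credentials for a single load balancing config.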
+ :param tenant_id: workspace id + :param config_id: load balancing config id + :return: + """ + provider_model_credentials_cache = ProviderCredentialsCache( + tenant_id=tenant_id, + identity_id=config_id, + cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL + ) + + provider_model_credentials_cache.delete() diff --git a/api/services/model_provider_service.py b/api/services/model_provider_service.py index 5a4342ae03..385af685f9 100644 --- a/api/services/model_provider_service.py +++ b/api/services/model_provider_service.py @@ -6,7 +6,7 @@ from typing import Optional, cast import requests from flask import current_app -from core.entities.model_entities import ModelStatus +from core.entities.model_entities import ModelStatus, ProviderModelWithStatusEntity from core.model_runtime.entities.model_entities import ModelType, ParameterRule from core.model_runtime.model_providers import model_provider_factory from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel @@ -16,7 +16,6 @@ from services.entities.model_provider_entities import ( CustomConfigurationResponse, CustomConfigurationStatus, DefaultModelResponse, - ModelResponse, ModelWithProviderEntityResponse, ProviderResponse, ProviderWithModelsResponse, @@ -303,6 +302,9 @@ class ModelProviderService: if model.deprecated: continue + if model.status != ModelStatus.ACTIVE: + continue + provider_models[model.provider.provider].append(model) # convert to ProviderWithModelsResponse list @@ -313,24 +315,22 @@ class ModelProviderService: first_model = models[0] - has_active_models = any([model.status == ModelStatus.ACTIVE for model in models]) - providers_with_models.append( ProviderWithModelsResponse( provider=provider, label=first_model.provider.label, icon_small=first_model.provider.icon_small, icon_large=first_model.provider.icon_large, - status=CustomConfigurationStatus.ACTIVE - if has_active_models else CustomConfigurationStatus.NO_CONFIGURE, - models=[ModelResponse( + status=CustomConfigurationStatus.ACTIVE, + models=[ProviderModelWithStatusEntity( model=model.model, label=model.label, model_type=model.model_type, features=model.features, fetch_from=model.fetch_from, model_properties=model.model_properties, - status=model.status + status=model.status, + load_balancing_enabled=model.load_balancing_enabled ) for model in models] ) ) @@ -486,6 +486,54 @@ class ModelProviderService: # Switch preferred provider type provider_configuration.switch_preferred_provider_type(preferred_provider_type_enum) + def enable_model(self, tenant_id: str, provider: str, model: str, model_type: str) -> None: + """ + enable model. + + :param tenant_id: workspace id + :param provider: provider name + :param model: model name + :param model_type: model type + :return: + """ + # Get all provider configurations of the current workspace + provider_configurations = self.provider_manager.get_configurations(tenant_id) + + # Get provider configuration + provider_configuration = provider_configurations.get(provider) + if not provider_configuration: + raise ValueError(f"Provider {provider} does not exist.") + + # Enable model + provider_configuration.enable_model( + model=model, + model_type=ModelType.value_of(model_type) + ) + + def disable_model(self, tenant_id: str, provider: str, model: str, model_type: str) -> None: + """ + disable model. 
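+ The model is marked as disabled in the provider's model settings.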
+ + :param tenant_id: workspace id + :param provider: provider name + :param model: model name + :param model_type: model type + :return: + """ + # Get all provider configurations of the current workspace + provider_configurations = self.provider_manager.get_configurations(tenant_id) + + # Get provider configuration + provider_configuration = provider_configurations.get(provider) + if not provider_configuration: + raise ValueError(f"Provider {provider} does not exist.") + + # Disable model + provider_configuration.disable_model( + model=model, + model_type=ModelType.value_of(model_type) + ) + def free_quota_submit(self, tenant_id: str, provider: str): api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY") api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL") diff --git a/api/services/recommended_app_service.py b/api/services/recommended_app_service.py index 3d36fb80af..6a155922b4 100644 --- a/api/services/recommended_app_service.py +++ b/api/services/recommended_app_service.py @@ -86,6 +86,7 @@ class RecommendedAppService: 'description': site.description, 'copyright': site.copyright, 'privacy_policy': site.privacy_policy, + 'custom_disclaimer': site.custom_disclaimer, 'category': recommended_app.category, 'position': recommended_app.position, 'is_listed': recommended_app.is_listed @@ -94,7 +95,7 @@ class RecommendedAppService: categories.add(recommended_app.category) # add category to categories - return {'recommended_apps': recommended_apps_result, 'categories': list(categories)} + return {'recommended_apps': recommended_apps_result, 'categories': sorted(list(categories))} @classmethod def _fetch_recommended_apps_from_dify_official(cls, language: str) -> dict: diff --git a/api/services/tools_manage_service.py b/api/services/tools/api_tools_manage_service.py similarity index 57% rename from api/services/tools_manage_service.py rename to api/services/tools/api_tools_manage_service.py index ec4e89bd14..9a0d6ca8d9 100644 --- a/api/services/tools_manage_service.py +++ b/api/services/tools/api_tools_manage_service.py @@ -4,97 +4,30 @@ import logging from httpx import get from core.model_runtime.utils.encoders import jsonable_encoder +from core.tools.entities.api_entities import UserTool, UserToolProvider from core.tools.entities.common_entities import I18nObject -from core.tools.entities.tool_bundle import ApiBasedToolBundle +from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_entities import ( ApiProviderAuthType, ApiProviderSchemaType, ToolCredentialsOption, ToolProviderCredentials, ) -from core.tools.entities.user_entities import UserTool, UserToolProvider -from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError -from core.tools.provider.api_tool_provider import ApiBasedToolProviderController -from core.tools.provider.builtin._positions import BuiltinToolProviderSort -from core.tools.provider.tool_provider import ToolProviderController +from core.tools.provider.api_tool_provider import ApiToolProviderController +from core.tools.tool_label_manager import ToolLabelManager from core.tools.tool_manager import ToolManager from core.tools.utils.configuration import ToolConfigurationManager from core.tools.utils.parser import ApiBasedToolSchemaParser from extensions.ext_database import db -from models.tools import ApiToolProvider, BuiltinToolProvider -from services.model_provider_service import ModelProviderService -from services.tools_transform_service import ToolTransformService +from models.tools
import ApiToolProvider +from services.tools.tools_transform_service import ToolTransformService logger = logging.getLogger(__name__) -class ToolManageService: +class ApiToolManageService: @staticmethod - def list_tool_providers(user_id: str, tenant_id: str): - """ - list tool providers - - :return: the list of tool providers - """ - providers = ToolManager.user_list_providers( - user_id, tenant_id - ) - - # add icon - for provider in providers: - ToolTransformService.repack_provider(provider) - - result = [provider.to_dict() for provider in providers] - - return result - - @staticmethod - def list_builtin_tool_provider_tools( - user_id: str, tenant_id: str, provider: str - ) -> list[UserTool]: - """ - list builtin tool provider tools - """ - provider_controller: ToolProviderController = ToolManager.get_builtin_provider(provider) - tools = provider_controller.get_tools() - - tool_provider_configurations = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) - # check if user has added the provider - builtin_provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter( - BuiltinToolProvider.tenant_id == tenant_id, - BuiltinToolProvider.provider == provider, - ).first() - - credentials = {} - if builtin_provider is not None: - # get credentials - credentials = builtin_provider.credentials - credentials = tool_provider_configurations.decrypt_tool_credentials(credentials) - - result = [] - for tool in tools: - result.append(ToolTransformService.tool_to_user_tool( - tool=tool, credentials=credentials, tenant_id=tenant_id - )) - - return result - - @staticmethod - def list_builtin_provider_credentials_schema( - provider_name - ): - """ - list builtin provider credentials schema - - :return: the list of tool providers - """ - provider = ToolManager.get_builtin_provider(provider_name) - return jsonable_encoder([ - v for _, v in (provider.credentials_schema or {}).items() - ]) - - @staticmethod - def parser_api_schema(schema: str) -> list[ApiBasedToolBundle]: + def parser_api_schema(schema: str) -> list[ApiToolBundle]: """ parse api schema to tool bundle """ @@ -162,7 +95,7 @@ class ToolManageService: raise ValueError(f'invalid schema: {str(e)}') @staticmethod - def convert_schema_to_tool_bundles(schema: str, extra_info: dict = None) -> list[ApiBasedToolBundle]: + def convert_schema_to_tool_bundles(schema: str, extra_info: dict = None) -> list[ApiToolBundle]: """ convert schema to tool bundles @@ -177,7 +110,7 @@ class ToolManageService: @staticmethod def create_api_tool_provider( user_id: str, tenant_id: str, provider_name: str, icon: dict, credentials: dict, - schema_type: str, schema: str, privacy_policy: str + schema_type: str, schema: str, privacy_policy: str, custom_disclaimer: str, labels: list[str] ): """ create api tool provider @@ -197,7 +130,7 @@ class ToolManageService: # parse openapi to tool bundle extra_info = {} # extra info like description will be set here - tool_bundles, schema_type = ToolManageService.convert_schema_to_tool_bundles(schema, extra_info) + tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info) if len(tool_bundles) > 100: raise ValueError('the number of apis should be less than 100') @@ -213,7 +146,8 @@ class ToolManageService: schema_type_str=schema_type, tools_str=json.dumps(jsonable_encoder(tool_bundles)), credentials_str={}, - privacy_policy=privacy_policy + privacy_policy=privacy_policy, + custom_disclaimer=custom_disclaimer ) if 'auth_type' not in credentials: @@ 
-223,7 +157,7 @@ class ToolManageService: auth_type = ApiProviderAuthType.value_of(credentials['auth_type']) # create provider entity - provider_controller = ApiBasedToolProviderController.from_db(db_provider, auth_type) + provider_controller = ApiToolProviderController.from_db(db_provider, auth_type) # load tools into provider entity provider_controller.load_bundled_tools(tool_bundles) @@ -235,6 +169,9 @@ class ToolManageService: db.session.add(db_provider) db.session.commit() + # update labels + ToolLabelManager.update_tool_labels(provider_controller, labels) + return { 'result': 'success' } @staticmethod @@ -256,7 +193,7 @@ class ToolManageService: schema = response.text # try to parse schema, avoid SSRF attack - ToolManageService.parser_api_schema(schema) + ApiToolManageService.parser_api_schema(schema) except Exception as e: logger.error(f"parse api schema error: {str(e)}") raise ValueError('invalid schema, please check the url you provided') @@ -280,91 +217,20 @@ class ToolManageService: if provider is None: raise ValueError(f'you have not added provider {provider}') - return [ - ToolTransformService.tool_to_user_tool(tool_bundle) for tool_bundle in provider.tools - ] - - @staticmethod - def update_builtin_tool_provider( - user_id: str, tenant_id: str, provider_name: str, credentials: dict - ): - """ - update builtin tool provider - """ - # get if the provider exists - provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter( - BuiltinToolProvider.tenant_id == tenant_id, - BuiltinToolProvider.provider == provider_name, - ).first() - - try: - # get provider - provider_controller = ToolManager.get_builtin_provider(provider_name) - if not provider_controller.need_credentials: - raise ValueError(f'provider {provider_name} does not need credentials') - tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) - # get original credentials if exists - if provider is not None: - original_credentials = tool_configuration.decrypt_tool_credentials(provider.credentials) - masked_credentials = tool_configuration.mask_tool_credentials(original_credentials) - # check if the credential has changed, save the original credential - for name, value in credentials.items(): - if name in masked_credentials and value == masked_credentials[name]: - credentials[name] = original_credentials[name] - # validate credentials - provider_controller.validate_credentials(credentials) - # encrypt credentials - credentials = tool_configuration.encrypt_tool_credentials(credentials) - except (ToolProviderNotFoundError, ToolNotFoundError, ToolProviderCredentialValidationError) as e: - raise ValueError(str(e)) - - if provider is None: - # create provider - provider = BuiltinToolProvider( - tenant_id=tenant_id, - user_id=user_id, - provider=provider_name, - encrypted_credentials=json.dumps(credentials), - ) - - db.session.add(provider) - db.session.commit() - - else: - provider.encrypted_credentials = json.dumps(credentials) - db.session.add(provider) - db.session.commit() - - # delete cache - tool_configuration.delete_tool_credentials_cache() - - return { 'result': 'success' } - - @staticmethod - def get_builtin_tool_provider_credentials( - user_id: str, tenant_id: str, provider: str - ): - """ - get builtin tool provider credentials - """ - provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter( - BuiltinToolProvider.tenant_id == tenant_id, - BuiltinToolProvider.provider == provider, - ).first() - - if provider is None: - return {} 
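+ # build the provider controller first so tool labels can be attached to the returned tools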
+ controller = ToolTransformService.api_provider_to_controller(db_provider=provider) + labels = ToolLabelManager.get_tool_labels(controller) - provider_controller = ToolManager.get_builtin_provider(provider.provider) - tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) - credentials = tool_configuration.decrypt_tool_credentials(provider.credentials) - credentials = tool_configuration.mask_tool_credentials(credentials) - return credentials + return [ + ToolTransformService.tool_to_user_tool( + tool_bundle, + labels=labels, + ) for tool_bundle in provider.tools + ] @staticmethod def update_api_tool_provider( user_id: str, tenant_id: str, provider_name: str, original_provider: str, icon: dict, credentials: dict, - schema_type: str, schema: str, privacy_policy: str + schema_type: str, schema: str, privacy_policy: str, custom_disclaimer: str, labels: list[str] ): """ update api tool provider @@ -384,7 +250,7 @@ class ToolManageService: # parse openapi to tool bundle extra_info = {} # extra info like description will be set here - tool_bundles, schema_type = ToolManageService.convert_schema_to_tool_bundles(schema, extra_info) + tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info) # update db provider provider.name = provider_name @@ -394,6 +260,7 @@ class ToolManageService: provider.schema_type_str = ApiProviderSchemaType.OPENAPI.value provider.tools_str = json.dumps(jsonable_encoder(tool_bundles)) provider.privacy_policy = privacy_policy + provider.custom_disclaimer = custom_disclaimer if 'auth_type' not in credentials: raise ValueError('auth_type is required') @@ -402,7 +269,7 @@ class ToolManageService: auth_type = ApiProviderAuthType.value_of(credentials['auth_type']) # create provider entity - provider_controller = ApiBasedToolProviderController.from_db(provider, auth_type) + provider_controller = ApiToolProviderController.from_db(provider, auth_type) # load tools into provider entity provider_controller.load_bundled_tools(tool_bundles) @@ -425,84 +292,11 @@ class ToolManageService: # delete cache tool_configuration.delete_tool_credentials_cache() - return { 'result': 'success' } - - @staticmethod - def delete_builtin_tool_provider( - user_id: str, tenant_id: str, provider_name: str - ): - """ - delete tool provider - """ - provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter( - BuiltinToolProvider.tenant_id == tenant_id, - BuiltinToolProvider.provider == provider_name, - ).first() - - if provider is None: - raise ValueError(f'you have not added provider {provider_name}') - - db.session.delete(provider) - db.session.commit() - - # delete cache - provider_controller = ToolManager.get_builtin_provider(provider_name) - tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) - tool_configuration.delete_tool_credentials_cache() + # update labels + ToolLabelManager.update_tool_labels(provider_controller, labels) return { 'result': 'success' } - @staticmethod - def get_builtin_tool_provider_icon( - provider: str - ): - """ - get tool provider icon and it's mimetype - """ - icon_path, mime_type = ToolManager.get_builtin_provider_icon(provider) - with open(icon_path, 'rb') as f: - icon_bytes = f.read() - - return icon_bytes, mime_type - - @staticmethod - def get_model_tool_provider_icon( - provider: str - ): - """ - get tool provider icon and it's mimetype - """ - - service = ModelProviderService() - icon_bytes, mime_type 
= service.get_model_provider_icon(provider=provider, icon_type='icon_small', lang='en_US') - - if icon_bytes is None: - raise ValueError(f'provider {provider} does not exists') - - return icon_bytes, mime_type - - @staticmethod - def list_model_tool_provider_tools( - user_id: str, tenant_id: str, provider: str - ) -> list[UserTool]: - """ - list model tool provider tools - """ - provider_controller = ToolManager.get_model_provider(tenant_id=tenant_id, provider_name=provider) - tools = provider_controller.get_tools(user_id=user_id, tenant_id=tenant_id) - - result = [ - UserTool( - author=tool.identity.author, - name=tool.identity.name, - label=tool.identity.label, - description=tool.description.human, - parameters=tool.parameters or [] - ) for tool in tools - ] - - return jsonable_encoder(result) - @staticmethod def delete_api_tool_provider( user_id: str, tenant_id: str, provider_name: str @@ -581,7 +375,7 @@ class ToolManageService: auth_type = ApiProviderAuthType.value_of(credentials['auth_type']) # create provider entity - provider_controller = ApiBasedToolProviderController.from_db(db_provider, auth_type) + provider_controller = ApiToolProviderController.from_db(db_provider, auth_type) # load tools into provider entity provider_controller.load_bundled_tools(tool_bundles) @@ -602,7 +396,7 @@ class ToolManageService: provider_controller.validate_credentials_format(credentials) # get tool tool = provider_controller.get_tool(tool_name) - tool = tool.fork_tool_runtime(meta={ + tool = tool.fork_tool_runtime(runtime={ 'credentials': credentials, 'tenant_id': tenant_id, }) @@ -612,49 +406,6 @@ class ToolManageService: return { 'result': result or 'empty response' } - @staticmethod - def list_builtin_tools( - user_id: str, tenant_id: str - ) -> list[UserToolProvider]: - """ - list builtin tools - """ - # get all builtin providers - provider_controllers = ToolManager.list_builtin_providers() - - # get all user added providers - db_providers: list[BuiltinToolProvider] = db.session.query(BuiltinToolProvider).filter( - BuiltinToolProvider.tenant_id == tenant_id - ).all() or [] - - # find provider - find_provider = lambda provider: next(filter(lambda db_provider: db_provider.provider == provider, db_providers), None) - - result: list[UserToolProvider] = [] - - for provider_controller in provider_controllers: - # convert provider controller to user provider - user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider( - provider_controller=provider_controller, - db_provider=find_provider(provider_controller.identity.name), - decrypt_credentials=True - ) - - # add icon - ToolTransformService.repack_provider(user_builtin_provider) - - tools = provider_controller.get_tools() - for tool in tools: - user_builtin_provider.tools.append(ToolTransformService.tool_to_user_tool( - tenant_id=tenant_id, - tool=tool, - credentials=user_builtin_provider.original_credentials, - )) - - result.append(user_builtin_provider) - - return BuiltinToolProviderSort.sort(result) - @staticmethod def list_api_tools( user_id: str, tenant_id: str @@ -672,6 +423,7 @@ class ToolManageService: for provider in db_providers: # convert provider controller to user provider provider_controller = ToolTransformService.api_provider_to_controller(db_provider=provider) + labels = ToolLabelManager.get_tool_labels(provider_controller) user_provider = ToolTransformService.api_provider_to_user_provider( provider_controller, db_provider=provider, @@ -690,6 +442,7 @@ class ToolManageService: tenant_id=tenant_id, tool=tool, 
credentials=user_provider.original_credentials, + labels=labels )) result.append(user_provider) diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py new file mode 100644 index 0000000000..2503191b63 --- /dev/null +++ b/api/services/tools/builtin_tools_manage_service.py @@ -0,0 +1,226 @@ +import json +import logging + +from core.model_runtime.utils.encoders import jsonable_encoder +from core.tools.entities.api_entities import UserTool, UserToolProvider +from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError +from core.tools.provider.builtin._positions import BuiltinToolProviderSort +from core.tools.provider.tool_provider import ToolProviderController +from core.tools.tool_label_manager import ToolLabelManager +from core.tools.tool_manager import ToolManager +from core.tools.utils.configuration import ToolConfigurationManager +from extensions.ext_database import db +from models.tools import BuiltinToolProvider +from services.tools.tools_transform_service import ToolTransformService + +logger = logging.getLogger(__name__) + + +class BuiltinToolManageService: + @staticmethod + def list_builtin_tool_provider_tools( + user_id: str, tenant_id: str, provider: str + ) -> list[UserTool]: + """ + list builtin tool provider tools + """ + provider_controller: ToolProviderController = ToolManager.get_builtin_provider(provider) + tools = provider_controller.get_tools() + + tool_provider_configurations = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) + # check if user has added the provider + builtin_provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.provider == provider, + ).first() + + credentials = {} + if builtin_provider is not None: + # get credentials + credentials = builtin_provider.credentials + credentials = tool_provider_configurations.decrypt_tool_credentials(credentials) + + result = [] + for tool in tools: + result.append(ToolTransformService.tool_to_user_tool( + tool=tool, + credentials=credentials, + tenant_id=tenant_id, + labels=ToolLabelManager.get_tool_labels(provider_controller) + )) + + return result + + @staticmethod + def list_builtin_provider_credentials_schema( + provider_name + ): + """ + list builtin provider credentials schema + + :return: the list of tool providers + """ + provider = ToolManager.get_builtin_provider(provider_name) + return jsonable_encoder([ + v for _, v in (provider.credentials_schema or {}).items() + ]) + + @staticmethod + def update_builtin_tool_provider( + user_id: str, tenant_id: str, provider_name: str, credentials: dict + ): + """ + update builtin tool provider + """ + # get if the provider exists + provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.provider == provider_name, + ).first() + + try: + # get provider + provider_controller = ToolManager.get_builtin_provider(provider_name) + if not provider_controller.need_credentials: + raise ValueError(f'provider {provider_name} does not need credentials') + tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) + # get original credentials if exists + if provider is not None: + original_credentials = tool_configuration.decrypt_tool_credentials(provider.credentials) + masked_credentials = 
tool_configuration.mask_tool_credentials(original_credentials) + # check if the credential has changed, save the original credential + for name, value in credentials.items(): + if name in masked_credentials and value == masked_credentials[name]: + credentials[name] = original_credentials[name] + # validate credentials + provider_controller.validate_credentials(credentials) + # encrypt credentials + credentials = tool_configuration.encrypt_tool_credentials(credentials) + except (ToolProviderNotFoundError, ToolNotFoundError, ToolProviderCredentialValidationError) as e: + raise ValueError(str(e)) + + if provider is None: + # create provider + provider = BuiltinToolProvider( + tenant_id=tenant_id, + user_id=user_id, + provider=provider_name, + encrypted_credentials=json.dumps(credentials), + ) + + db.session.add(provider) + db.session.commit() + + else: + provider.encrypted_credentials = json.dumps(credentials) + db.session.add(provider) + db.session.commit() + + # delete cache + tool_configuration.delete_tool_credentials_cache() + + return { 'result': 'success' } + + @staticmethod + def get_builtin_tool_provider_credentials( + user_id: str, tenant_id: str, provider: str + ): + """ + get builtin tool provider credentials + """ + provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.provider == provider, + ).first() + + if provider is None: + return {} + + provider_controller = ToolManager.get_builtin_provider(provider.provider) + tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) + credentials = tool_configuration.decrypt_tool_credentials(provider.credentials) + credentials = tool_configuration.mask_tool_credentials(credentials) + return credentials + + @staticmethod + def delete_builtin_tool_provider( + user_id: str, tenant_id: str, provider_name: str + ): + """ + delete tool provider + """ + provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.provider == provider_name, + ).first() + + if provider is None: + raise ValueError(f'you have not added provider {provider_name}') + + db.session.delete(provider) + db.session.commit() + + # delete cache + provider_controller = ToolManager.get_builtin_provider(provider_name) + tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) + tool_configuration.delete_tool_credentials_cache() + + return { 'result': 'success' } + + @staticmethod + def get_builtin_tool_provider_icon( + provider: str + ): + """ + get tool provider icon and its mimetype + """ + icon_path, mime_type = ToolManager.get_builtin_provider_icon(provider) + with open(icon_path, 'rb') as f: + icon_bytes = f.read() + + return icon_bytes, mime_type + + @staticmethod + def list_builtin_tools( + user_id: str, tenant_id: str + ) -> list[UserToolProvider]: + """ + list builtin tools + """ + # get all builtin providers + provider_controllers = ToolManager.list_builtin_providers() + + # get all user added providers + db_providers: list[BuiltinToolProvider] = db.session.query(BuiltinToolProvider).filter( + BuiltinToolProvider.tenant_id == tenant_id + ).all() or [] + + # find provider + find_provider = lambda provider: next(filter(lambda db_provider: db_provider.provider == provider, db_providers), None) + + result: list[UserToolProvider] = [] + + for provider_controller in provider_controllers: + # convert
provider controller to user provider + user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider( + provider_controller=provider_controller, + db_provider=find_provider(provider_controller.identity.name), + decrypt_credentials=True + ) + + # add icon + ToolTransformService.repack_provider(user_builtin_provider) + + tools = provider_controller.get_tools() + for tool in tools: + user_builtin_provider.tools.append(ToolTransformService.tool_to_user_tool( + tenant_id=tenant_id, + tool=tool, + credentials=user_builtin_provider.original_credentials, + labels=ToolLabelManager.get_tool_labels(provider_controller) + )) + + result.append(user_builtin_provider) + + return BuiltinToolProviderSort.sort(result) + \ No newline at end of file diff --git a/api/services/tools/tool_labels_service.py b/api/services/tools/tool_labels_service.py new file mode 100644 index 0000000000..8a6aa025f2 --- /dev/null +++ b/api/services/tools/tool_labels_service.py @@ -0,0 +1,8 @@ +from core.tools.entities.tool_entities import ToolLabel +from core.tools.entities.values import default_tool_labels + + +class ToolLabelsService: + @classmethod + def list_tool_labels(cls) -> list[ToolLabel]: + return default_tool_labels \ No newline at end of file diff --git a/api/services/tools/tools_manage_service.py b/api/services/tools/tools_manage_service.py new file mode 100644 index 0000000000..76d2f53ae8 --- /dev/null +++ b/api/services/tools/tools_manage_service.py @@ -0,0 +1,29 @@ +import logging + +from core.tools.entities.api_entities import UserToolProviderTypeLiteral +from core.tools.tool_manager import ToolManager +from services.tools.tools_transform_service import ToolTransformService + +logger = logging.getLogger(__name__) + + +class ToolCommonService: + @staticmethod + def list_tool_providers(user_id: str, tenant_id: str, typ: UserToolProviderTypeLiteral = None): + """ + list tool providers + + :return: the list of tool providers + """ + providers = ToolManager.user_list_providers( + user_id, tenant_id, typ + ) + + # add icon + for provider in providers: + ToolTransformService.repack_provider(provider) + + result = [provider.to_dict() for provider in providers] + + return result + \ No newline at end of file diff --git a/api/services/tools_transform_service.py b/api/services/tools/tools_transform_service.py similarity index 75% rename from api/services/tools_transform_service.py rename to api/services/tools/tools_transform_service.py index 3ef9f52e62..ba8c20d79b 100644 --- a/api/services/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -5,15 +5,21 @@ from typing import Optional, Union from flask import current_app from core.model_runtime.entities.common_entities import I18nObject -from core.tools.entities.tool_bundle import ApiBasedToolBundle -from core.tools.entities.tool_entities import ApiProviderAuthType, ToolParameter, ToolProviderCredentials -from core.tools.entities.user_entities import UserTool, UserToolProvider -from core.tools.provider.api_tool_provider import ApiBasedToolProviderController +from core.tools.entities.api_entities import UserTool, UserToolProvider +from core.tools.entities.tool_bundle import ApiToolBundle +from core.tools.entities.tool_entities import ( + ApiProviderAuthType, + ToolParameter, + ToolProviderCredentials, + ToolProviderType, +) +from core.tools.provider.api_tool_provider import ApiToolProviderController from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController -from core.tools.provider.model_tool_provider 
import ModelToolProviderController +from core.tools.provider.workflow_tool_provider import WorkflowToolProviderController from core.tools.tool.tool import Tool +from core.tools.tool.workflow_tool import WorkflowTool from core.tools.utils.configuration import ToolConfigurationManager -from models.tools import ApiToolProvider, BuiltinToolProvider +from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider logger = logging.getLogger(__name__) @@ -26,11 +32,9 @@ class ToolTransformService: url_prefix = (current_app.config.get("CONSOLE_API_URL") + "/console/api/workspaces/current/tool-provider/") - if provider_type == UserToolProvider.ProviderType.BUILTIN.value: + if provider_type == ToolProviderType.BUILT_IN.value: return url_prefix + 'builtin/' + provider_name + '/icon' - elif provider_type == UserToolProvider.ProviderType.MODEL.value: - return url_prefix + 'model/' + provider_name + '/icon' - elif provider_type == UserToolProvider.ProviderType.API.value: + elif provider_type in [ToolProviderType.API.value, ToolProviderType.WORKFLOW.value]: try: return json.loads(icon) except: @@ -65,7 +69,7 @@ class ToolTransformService: def builtin_provider_to_user_provider( provider_controller: BuiltinToolProviderController, db_provider: Optional[BuiltinToolProvider], - decrypt_credentials: bool = True + decrypt_credentials: bool = True, ) -> UserToolProvider: """ convert provider controller to user provider @@ -83,10 +87,11 @@ class ToolTransformService: en_US=provider_controller.identity.label.en_US, zh_Hans=provider_controller.identity.label.zh_Hans, ), - type=UserToolProvider.ProviderType.BUILTIN, + type=ToolProviderType.BUILT_IN, masked_credentials={}, is_team_authorization=False, - tools=[] + tools=[], + labels=provider_controller.tool_labels ) # get credentials schema @@ -122,24 +127,62 @@ class ToolTransformService: @staticmethod def api_provider_to_controller( db_provider: ApiToolProvider, - ) -> ApiBasedToolProviderController: + ) -> ApiToolProviderController: """ convert provider controller to user provider """ # package tool provider controller - controller = ApiBasedToolProviderController.from_db( + controller = ApiToolProviderController.from_db( db_provider=db_provider, auth_type=ApiProviderAuthType.API_KEY if db_provider.credentials['auth_type'] == 'api_key' else ApiProviderAuthType.NONE ) return controller + + @staticmethod + def workflow_provider_to_controller( + db_provider: WorkflowToolProvider + ) -> WorkflowToolProviderController: + """ + convert provider controller to provider + """ + return WorkflowToolProviderController.from_db(db_provider) + + @staticmethod + def workflow_provider_to_user_provider( + provider_controller: WorkflowToolProviderController, + labels: list[str] = None + ): + """ + convert provider controller to user provider + """ + return UserToolProvider( + id=provider_controller.provider_id, + author=provider_controller.identity.author, + name=provider_controller.identity.name, + description=I18nObject( + en_US=provider_controller.identity.description.en_US, + zh_Hans=provider_controller.identity.description.zh_Hans, + ), + icon=provider_controller.identity.icon, + label=I18nObject( + en_US=provider_controller.identity.label.en_US, + zh_Hans=provider_controller.identity.label.zh_Hans, + ), + type=ToolProviderType.WORKFLOW, + masked_credentials={}, + is_team_authorization=True, + tools=[], + labels=labels or [] + ) @staticmethod def api_provider_to_user_provider( - provider_controller: ApiBasedToolProviderController, + provider_controller: 
ApiToolProviderController, db_provider: ApiToolProvider, - decrypt_credentials: bool = True + decrypt_credentials: bool = True, + labels: list[str] = None ) -> UserToolProvider: """ convert provider controller to user provider @@ -164,10 +207,11 @@ class ToolTransformService: en_US=db_provider.name, zh_Hans=db_provider.name, ), - type=UserToolProvider.ProviderType.API, + type=ToolProviderType.API, masked_credentials={}, is_team_authorization=True, - tools=[] + tools=[], + labels=labels or [] ) if decrypt_credentials: @@ -185,41 +229,19 @@ class ToolTransformService: return result - @staticmethod - def model_provider_to_user_provider( - db_provider: ModelToolProviderController, - ) -> UserToolProvider: - """ - convert provider controller to user provider - """ - return UserToolProvider( - id=db_provider.identity.name, - author=db_provider.identity.author, - name=db_provider.identity.name, - description=I18nObject( - en_US=db_provider.identity.description.en_US, - zh_Hans=db_provider.identity.description.zh_Hans, - ), - icon=db_provider.identity.icon, - label=I18nObject( - en_US=db_provider.identity.label.en_US, - zh_Hans=db_provider.identity.label.zh_Hans, - ), - type=UserToolProvider.ProviderType.MODEL, - masked_credentials={}, - is_team_authorization=db_provider.is_active, - ) - @staticmethod def tool_to_user_tool( - tool: Union[ApiBasedToolBundle, Tool], credentials: dict = None, tenant_id: str = None + tool: Union[ApiToolBundle, WorkflowTool, Tool], + credentials: dict = None, + tenant_id: str = None, + labels: list[str] = None ) -> UserTool: """ convert tool to user tool """ if isinstance(tool, Tool): # fork tool runtime - tool = tool.fork_tool_runtime(meta={ + tool = tool.fork_tool_runtime(runtime={ 'credentials': credentials, 'tenant_id': tenant_id, }) @@ -241,17 +263,15 @@ class ToolTransformService: if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM: current_parameters.append(runtime_parameter) - user_tool = UserTool( + return UserTool( author=tool.identity.author, name=tool.identity.name, label=tool.identity.label, description=tool.description.human, - parameters=current_parameters + parameters=current_parameters, + labels=labels ) - - return user_tool - - if isinstance(tool, ApiBasedToolBundle): + if isinstance(tool, ApiToolBundle): return UserTool( author=tool.author, name=tool.operation_id, @@ -263,5 +283,6 @@ class ToolTransformService: en_US=tool.summary or '', zh_Hans=tool.summary or '' ), - parameters=tool.parameters + parameters=tool.parameters, + labels=labels ) \ No newline at end of file diff --git a/api/services/tools/workflow_tools_manage_service.py b/api/services/tools/workflow_tools_manage_service.py new file mode 100644 index 0000000000..e89d94160c --- /dev/null +++ b/api/services/tools/workflow_tools_manage_service.py @@ -0,0 +1,326 @@ +import json +from datetime import datetime + +from sqlalchemy import or_ + +from core.model_runtime.utils.encoders import jsonable_encoder +from core.tools.entities.api_entities import UserToolProvider +from core.tools.provider.workflow_tool_provider import WorkflowToolProviderController +from core.tools.tool_label_manager import ToolLabelManager +from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils +from extensions.ext_database import db +from models.model import App +from models.tools import WorkflowToolProvider +from models.workflow import Workflow +from services.tools.tools_transform_service import ToolTransformService + + +class WorkflowToolManageService: + """ + 
Service class for managing workflow tools. + """ + @classmethod + def create_workflow_tool(cls, user_id: str, tenant_id: str, workflow_app_id: str, name: str, + label: str, icon: dict, description: str, + parameters: list[dict], privacy_policy: str = '', labels: list[str] = None) -> dict: + """ + Create a workflow tool. + :param user_id: the user id + :param tenant_id: the tenant id + :param workflow_app_id: the workflow app id + :param name: the name + :param label: the label + :param icon: the icon + :param description: the description + :param parameters: the parameters + :param privacy_policy: the privacy policy + :param labels: the labels + :return: the created tool + """ + WorkflowToolConfigurationUtils.check_parameter_configurations(parameters) + + # check if the name is unique + existing_workflow_tool_provider = db.session.query(WorkflowToolProvider).filter( + WorkflowToolProvider.tenant_id == tenant_id, + # name or app_id + or_(WorkflowToolProvider.name == name, WorkflowToolProvider.app_id == workflow_app_id) + ).first() + + if existing_workflow_tool_provider is not None: + raise ValueError(f'Tool with name {name} or app_id {workflow_app_id} already exists') + + app: App = db.session.query(App).filter( + App.id == workflow_app_id, + App.tenant_id == tenant_id + ).first() + + if app is None: + raise ValueError(f'App {workflow_app_id} not found') + + workflow: Workflow = app.workflow + if workflow is None: + raise ValueError(f'Workflow not found for app {workflow_app_id}') + + workflow_tool_provider = WorkflowToolProvider( + tenant_id=tenant_id, + user_id=user_id, + app_id=workflow_app_id, + name=name, + label=label, + icon=json.dumps(icon), + description=description, + parameter_configuration=json.dumps(parameters), + privacy_policy=privacy_policy, + version=workflow.version, + ) + + try: + WorkflowToolProviderController.from_db(workflow_tool_provider) + except Exception as e: + raise ValueError(str(e)) + + db.session.add(workflow_tool_provider) + db.session.commit() + + return { + 'result': 'success' + } + + + @classmethod + def update_workflow_tool(cls, user_id: str, tenant_id: str, workflow_tool_id: str, + name: str, label: str, icon: dict, description: str, + parameters: list[dict], privacy_policy: str = '', labels: list[str] = None) -> dict: + """ + Update a workflow tool.
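+ The new name must remain unique within the tenant; the bound workflow app cannot be changed here.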
+
+
+    @classmethod
+    def update_workflow_tool(cls, user_id: str, tenant_id: str, workflow_tool_id: str,
+                             name: str, label: str, icon: dict, description: str,
+                             parameters: list[dict], privacy_policy: str = '', labels: list[str] = None) -> dict:
+        """
+        Update a workflow tool.
+        :param user_id: the user id
+        :param tenant_id: the tenant id
+        :param workflow_tool_id: the workflow tool id
+        :param name: the name
+        :param label: the label
+        :param icon: the icon
+        :param description: the description
+        :param parameters: the parameters
+        :param privacy_policy: the privacy policy
+        :param labels: the labels
+        :return: the updated tool
+        """
+        WorkflowToolConfigurationUtils.check_parameter_configurations(parameters)
+
+        # check if the name is unique
+        existing_workflow_tool_provider = db.session.query(WorkflowToolProvider).filter(
+            WorkflowToolProvider.tenant_id == tenant_id,
+            WorkflowToolProvider.name == name,
+            WorkflowToolProvider.id != workflow_tool_id
+        ).first()
+
+        if existing_workflow_tool_provider is not None:
+            raise ValueError(f'Tool with name {name} already exists')
+
+        workflow_tool_provider: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter(
+            WorkflowToolProvider.tenant_id == tenant_id,
+            WorkflowToolProvider.id == workflow_tool_id
+        ).first()
+
+        if workflow_tool_provider is None:
+            raise ValueError(f'Tool {workflow_tool_id} not found')
+
+        app: App = db.session.query(App).filter(
+            App.id == workflow_tool_provider.app_id,
+            App.tenant_id == tenant_id
+        ).first()
+
+        if app is None:
+            raise ValueError(f'App {workflow_tool_provider.app_id} not found')
+
+        workflow: Workflow = app.workflow
+        if workflow is None:
+            raise ValueError(f'Workflow not found for app {workflow_tool_provider.app_id}')
+
+        workflow_tool_provider.name = name
+        workflow_tool_provider.label = label
+        workflow_tool_provider.icon = json.dumps(icon)
+        workflow_tool_provider.description = description
+        workflow_tool_provider.parameter_configuration = json.dumps(parameters)
+        workflow_tool_provider.privacy_policy = privacy_policy
+        workflow_tool_provider.version = workflow.version
+        workflow_tool_provider.updated_at = datetime.now()
+
+        try:
+            WorkflowToolProviderController.from_db(workflow_tool_provider)
+        except Exception as e:
+            raise ValueError(str(e))
+
+        db.session.add(workflow_tool_provider)
+        db.session.commit()
+
+        if labels is not None:
+            ToolLabelManager.update_tool_labels(
+                ToolTransformService.workflow_provider_to_controller(workflow_tool_provider),
+                labels
+            )
+
+        return {
+            'result': 'success'
+        }
+
+    @classmethod
+    def list_tenant_workflow_tools(cls, user_id: str, tenant_id: str) -> list[UserToolProvider]:
+        """
+        List workflow tools.
+        :param user_id: the user id
+        :param tenant_id: the tenant id
+        :return: the list of tools
+        """
+        db_tools = db.session.query(WorkflowToolProvider).filter(
+            WorkflowToolProvider.tenant_id == tenant_id
+        ).all()
+
+        tools = []
+        for provider in db_tools:
+            try:
+                tools.append(ToolTransformService.workflow_provider_to_controller(provider))
+            except Exception:
+                # skip deleted tools
+                pass
+
+        labels = ToolLabelManager.get_tools_labels(tools)
+
+        result = []
+
+        for tool in tools:
+            user_tool_provider = ToolTransformService.workflow_provider_to_user_provider(
+                provider_controller=tool,
+                labels=labels.get(tool.provider_id, [])
+            )
+            ToolTransformService.repack_provider(user_tool_provider)
+            user_tool_provider.tools = [
+                ToolTransformService.tool_to_user_tool(
+                    tool.get_tools(user_id, tenant_id)[0],
+                    labels=labels.get(tool.provider_id, [])
+                )
+            ]
+            result.append(user_tool_provider)
+
+        return result
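The listing flow above deliberately fetches labels once for all providers (`ToolLabelManager.get_tools_labels`) and then attaches them per provider with a `.get(..., [])` fallback, rather than querying labels per tool. A minimal sketch of that shape, with `load_controller` and `fetch_labels` as illustrative stand-ins for the service calls, not the real APIs:

```python
# Sketch: build controllers while skipping records that no longer load,
# then batch-fetch labels once and attach them by provider id.
from typing import Optional

def load_controller(record: dict) -> Optional[dict]:
    # stand-in for ToolTransformService.workflow_provider_to_controller
    if record.get('deleted'):
        raise ValueError('provider was deleted')
    return {'provider_id': record['id']}

def fetch_labels(controllers: list[dict]) -> dict[str, list[str]]:
    # stand-in for ToolLabelManager.get_tools_labels: one batched call
    # instead of one query per provider
    return {c['provider_id']: ['workflow'] for c in controllers}

records = [{'id': 'a'}, {'id': 'b', 'deleted': True}, {'id': 'c'}]

controllers = []
for record in records:
    try:
        controllers.append(load_controller(record))
    except Exception:
        # skip deleted tools, mirroring the service above
        pass

labels = fetch_labels(controllers)
for c in controllers:
    print(c['provider_id'], labels.get(c['provider_id'], []))
```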
+
+    @classmethod
+    def delete_workflow_tool(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> dict:
+        """
+        Delete a workflow tool.
+        :param user_id: the user id
+        :param tenant_id: the tenant id
+        :param workflow_tool_id: the workflow tool id
+        """
+        db.session.query(WorkflowToolProvider).filter(
+            WorkflowToolProvider.tenant_id == tenant_id,
+            WorkflowToolProvider.id == workflow_tool_id
+        ).delete()
+
+        db.session.commit()
+
+        return {
+            'result': 'success'
+        }
+
+    @classmethod
+    def get_workflow_tool_by_tool_id(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> dict:
+        """
+        Get a workflow tool.
+        :param user_id: the user id
+        :param tenant_id: the tenant id
+        :param workflow_tool_id: the workflow tool id
+        :return: the tool
+        """
+        db_tool: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter(
+            WorkflowToolProvider.tenant_id == tenant_id,
+            WorkflowToolProvider.id == workflow_tool_id
+        ).first()
+
+        if db_tool is None:
+            raise ValueError(f'Tool {workflow_tool_id} not found')
+
+        workflow_app: App = db.session.query(App).filter(
+            App.id == db_tool.app_id,
+            App.tenant_id == tenant_id
+        ).first()
+
+        if workflow_app is None:
+            raise ValueError(f'App {db_tool.app_id} not found')
+
+        tool = ToolTransformService.workflow_provider_to_controller(db_tool)
+
+        return {
+            'name': db_tool.name,
+            'label': db_tool.label,
+            'workflow_tool_id': db_tool.id,
+            'workflow_app_id': db_tool.app_id,
+            'icon': json.loads(db_tool.icon),
+            'description': db_tool.description,
+            'parameters': jsonable_encoder(db_tool.parameter_configurations),
+            'tool': ToolTransformService.tool_to_user_tool(
+                tool.get_tools(user_id, tenant_id)[0],
+                labels=ToolLabelManager.get_tool_labels(tool)
+            ),
+            'synced': workflow_app.workflow.version == db_tool.version,
+            'privacy_policy': db_tool.privacy_policy,
+        }
+
+    @classmethod
+    def get_workflow_tool_by_app_id(cls, user_id: str, tenant_id: str, workflow_app_id: str) -> dict:
+        """
+        Get a workflow tool.
+        :param user_id: the user id
+        :param tenant_id: the tenant id
+        :param workflow_app_id: the workflow app id
+        :return: the tool
+        """
+        db_tool: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter(
+            WorkflowToolProvider.tenant_id == tenant_id,
+            WorkflowToolProvider.app_id == workflow_app_id
+        ).first()
+
+        if db_tool is None:
+            raise ValueError(f'Tool {workflow_app_id} not found')
+
+        workflow_app: App = db.session.query(App).filter(
+            App.id == db_tool.app_id,
+            App.tenant_id == tenant_id
+        ).first()
+
+        if workflow_app is None:
+            raise ValueError(f'App {db_tool.app_id} not found')
+
+        tool = ToolTransformService.workflow_provider_to_controller(db_tool)
+
+        return {
+            'name': db_tool.name,
+            'label': db_tool.label,
+            'workflow_tool_id': db_tool.id,
+            'workflow_app_id': db_tool.app_id,
+            'icon': json.loads(db_tool.icon),
+            'description': db_tool.description,
+            'parameters': jsonable_encoder(db_tool.parameter_configurations),
+            'tool': ToolTransformService.tool_to_user_tool(
+                tool.get_tools(user_id, tenant_id)[0],
+                labels=ToolLabelManager.get_tool_labels(tool)
+            ),
+            'synced': workflow_app.workflow.version == db_tool.version,
+            'privacy_policy': db_tool.privacy_policy
+        }
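Both getters above expose a `synced` flag: a workflow tool snapshots the workflow version at create/update time, so the flag goes false once the app's workflow has been published past that snapshot. A minimal sketch of that comparison; the stub classes and version values are illustrative, not Dify's actual models:

```python
# Sketch of the 'synced' field: compare the app's current workflow version
# against the version the tool was created/updated with.
from dataclasses import dataclass

@dataclass
class WorkflowStub:           # stand-in for models.workflow.Workflow
    version: str

@dataclass
class WorkflowToolStub:       # stand-in for models.tools.WorkflowToolProvider
    version: str

current_workflow = WorkflowStub(version='2024-05-02')
tool = WorkflowToolStub(version='2024-04-28')

synced = current_workflow.version == tool.version
print(synced)  # False: the tool still runs the older snapshot, so the UI can prompt a re-sync
```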
+
+    @classmethod
+    def list_single_workflow_tools(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> list[dict]:
+        """
+        List workflow tool provider tools.
+        :param user_id: the user id
+        :param tenant_id: the tenant id
+        :param workflow_tool_id: the workflow tool id
+        :return: the list of tools
+        """
+        db_tool: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter(
+            WorkflowToolProvider.tenant_id == tenant_id,
+            WorkflowToolProvider.id == workflow_tool_id
+        ).first()
+
+        if db_tool is None:
+            raise ValueError(f'Tool {workflow_tool_id} not found')
+
+        tool = ToolTransformService.workflow_provider_to_controller(db_tool)
+
+        return [
+            ToolTransformService.tool_to_user_tool(
+                tool.get_tools(user_id, tenant_id)[0],
+                labels=ToolLabelManager.get_tool_labels(tool)
+            )
+        ]
\ No newline at end of file
diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py
index 138a5d5786..d76cd4c7ff 100644
--- a/api/services/workflow/workflow_converter.py
+++ b/api/services/workflow/workflow_converter.py
@@ -305,7 +305,7 @@ class WorkflowConverter:
         }
 
         request_body_json = json.dumps(request_body)
-        request_body_json = request_body_json.replace('\{\{', '{{').replace('\}\}', '}}')
+        request_body_json = request_body_json.replace(r'\{\{', '{{').replace(r'\}\}', '}}')
 
         http_request_node = {
             "id": f"http_request_{index}",
diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py
index 01fd3aa4a1..6235ecf0a3 100644
--- a/api/services/workflow_service.py
+++ b/api/services/workflow_service.py
@@ -21,6 +21,7 @@ from models.workflow import (
     WorkflowNodeExecutionTriggeredFrom,
     WorkflowType,
 )
+from services.errors.app import WorkflowHashNotEqualError
 from services.workflow.workflow_converter import WorkflowConverter
 
 
@@ -63,13 +64,20 @@ class WorkflowService:
     def sync_draft_workflow(self, app_model: App,
                             graph: dict,
                             features: dict,
+                            unique_hash: Optional[str],
                             account: Account) -> Workflow:
         """
         Sync draft workflow
+        :raises WorkflowHashNotEqualError
         """
         # fetch draft workflow by app_model
         workflow = self.get_draft_workflow(app_model=app_model)
 
+        if workflow:
+            # validate unique hash
+            if workflow.unique_hash != unique_hash:
+                raise WorkflowHashNotEqualError()
+
         # validate features structure
         self.validate_features_structure(
             app_model=app_model,
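The `unique_hash` check added to `sync_draft_workflow` is an optimistic-lock pattern: the client sends back the hash it last saw, and the save is rejected with `WorkflowHashNotEqualError` if the draft changed underneath it. A self-contained sketch of the idea; `DraftStub` and the SHA-256 hashing scheme are illustrative assumptions, not Dify's actual implementation:

```python
# Sketch of optimistic concurrency control via a content hash.
import hashlib
import json

class WorkflowHashNotEqualError(Exception):
    pass

class DraftStub:
    def __init__(self, graph: dict):
        self.graph = graph

    @property
    def unique_hash(self) -> str:
        # deterministic hash over the draft content (illustrative scheme)
        return hashlib.sha256(json.dumps(self.graph, sort_keys=True).encode()).hexdigest()

draft = DraftStub({'nodes': ['start']})
stale_hash = draft.unique_hash

# another editor saves first...
draft.graph = {'nodes': ['start', 'llm']}

def sync_draft(draft: DraftStub, unique_hash: str) -> None:
    if draft.unique_hash != unique_hash:
        # caller should reload the draft and retry instead of overwriting
        raise WorkflowHashNotEqualError()

try:
    sync_draft(draft, stale_hash)
except WorkflowHashNotEqualError:
    print('draft changed since last fetch; reload before saving')
```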
diff --git a/api/tasks/batch_create_segment_to_index_task.py b/api/tasks/batch_create_segment_to_index_task.py
index d6dc970477..67cc03bdeb 100644
--- a/api/tasks/batch_create_segment_to_index_task.py
+++ b/api/tasks/batch_create_segment_to_index_task.py
@@ -2,7 +2,6 @@ import datetime
 import logging
 import time
 import uuid
-from typing import cast
 
 import click
 from celery import shared_task
@@ -11,7 +10,6 @@ from sqlalchemy import func
 from core.indexing_runner import IndexingRunner
 from core.model_manager import ModelManager
 from core.model_runtime.entities.model_entities import ModelType
-from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from libs import helper
@@ -59,16 +57,12 @@ def batch_create_segment_to_index_task(job_id: str, content: list, dataset_id: s
             model=dataset.embedding_model
         )
 
-        model_type_instance = embedding_model.model_type_instance
-        model_type_instance = cast(TextEmbeddingModel, model_type_instance)
     for segment in content:
         content = segment['content']
         doc_id = str(uuid.uuid4())
         segment_hash = helper.generate_text_hash(content)
         # calc embedding use tokens
-        tokens = model_type_instance.get_num_tokens(
-            model=embedding_model.model,
-            credentials=embedding_model.credentials,
+        tokens = embedding_model.get_text_embedding_num_tokens(
             texts=[content]
         ) if embedding_model else 0
         max_position = db.session.query(func.max(DocumentSegment.position)).filter(
diff --git a/api/tests/integration_tests/.env.example b/api/tests/integration_tests/.env.example
index 9cd04b4764..f29e5ef4d6 100644
--- a/api/tests/integration_tests/.env.example
+++ b/api/tests/integration_tests/.env.example
@@ -73,4 +73,10 @@ MOCK_SWITCH=false
 
 # CODE EXECUTION CONFIGURATION
 CODE_EXECUTION_ENDPOINT=
-CODE_EXECUTION_API_KEY=
\ No newline at end of file
+CODE_EXECUTION_API_KEY=
+
+# Volcengine MaaS Credentials
+VOLC_API_KEY=
+VOLC_SECRET_KEY=
+VOLC_MODEL_ENDPOINT_ID=
+VOLC_EMBEDDING_ENDPOINT_ID=
\ No newline at end of file
diff --git a/api/tests/integration_tests/model_runtime/localai/test_rerank.py b/api/tests/integration_tests/model_runtime/localai/test_rerank.py
new file mode 100644
index 0000000000..a75439337e
--- /dev/null
+++ b/api/tests/integration_tests/model_runtime/localai/test_rerank.py
@@ -0,0 +1,102 @@
+import os
+
+import pytest
+from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
+
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.localai.rerank.rerank import LocalaiRerankModel
+
+
+def test_validate_credentials_for_chat_model():
+    model = LocalaiRerankModel()
+
+    with pytest.raises(CredentialsValidateFailedError):
+        model.validate_credentials(
+            model='bge-reranker-v2-m3',
+            credentials={
+                'server_url': 'hahahaha',
+                'completion_type': 'completion',
+            }
+        )
+
+    model.validate_credentials(
+        model='bge-reranker-base',
+        credentials={
+            'server_url': os.environ.get('LOCALAI_SERVER_URL'),
+            'completion_type': 'completion',
+        }
+    )
+
+def test_invoke_rerank_model():
+    model = LocalaiRerankModel()
+
+    response = model.invoke(
+        model='bge-reranker-base',
+        credentials={
+            'server_url': os.environ.get('LOCALAI_SERVER_URL')
+        },
+        query='Organic skincare products for sensitive skin',
+        docs=[
+            "Eco-friendly kitchenware for modern homes",
+            "Biodegradable cleaning supplies for eco-conscious consumers",
+            "Organic cotton baby clothes for sensitive skin",
+            "Natural organic skincare range for sensitive skin",
+            "Tech gadgets for smart homes: 2024 edition",
+            "Sustainable gardening tools and compost solutions",
+            "Sensitive skin-friendly facial cleansers and toners",
+            "Organic food wraps and storage solutions",
+            "Yoga mats made from recycled materials"
+        ],
+        top_n=3,
+        score_threshold=0.75,
+        user="abc-123"
+    )
+
+    assert isinstance(response, RerankResult)
+    assert len(response.docs) == 3
+
+def test__invoke():
+    model = LocalaiRerankModel()
+
+    # Test case 1: Empty docs
+    result = model._invoke(
+        model='bge-reranker-base',
+        credentials={
+            'server_url': 'https://example.com',
+            'api_key': '1234567890'
+        },
+        query='Organic skincare products for sensitive skin',
+        docs=[],
+        top_n=3,
+        score_threshold=0.75,
+        user="abc-123"
+    )
+    assert isinstance(result, RerankResult)
+    assert len(result.docs) == 0
+
+    # Test case 2: Valid invocation
+    result = model._invoke(
+        model='bge-reranker-base',
+        credentials={
+            'server_url': 'https://example.com',
+            'api_key': '1234567890'
+        },
+        query='Organic skincare products for sensitive skin',
+        docs=[
+            "Eco-friendly kitchenware for modern homes",
+            "Biodegradable cleaning supplies for eco-conscious consumers",
+            "Organic cotton baby clothes for sensitive skin",
+            "Natural organic skincare range for sensitive skin",
+            "Tech gadgets for smart homes: 2024 edition",
+            "Sustainable gardening tools and compost solutions",
+            "Sensitive skin-friendly facial cleansers and toners",
+            "Organic food wraps and storage solutions",
+            "Yoga mats made from recycled materials"
+        ],
+        top_n=3,
+        score_threshold=0.75,
+        user="abc-123"
+    )
+    assert isinstance(result, RerankResult)
+    assert len(result.docs) == 3
+    assert all(isinstance(doc, RerankDocument) for doc in result.docs)
\ No newline at end of file
diff --git a/api/tests/integration_tests/model_runtime/localai/test_speech2text.py b/api/tests/integration_tests/model_runtime/localai/test_speech2text.py
new file mode 100644
index 0000000000..3fd2ebed4f
--- /dev/null
+++ b/api/tests/integration_tests/model_runtime/localai/test_speech2text.py
@@ -0,0 +1,54 @@
+import os
+
+import pytest
+
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.localai.speech2text.speech2text import LocalAISpeech2text
+
+
+def test_validate_credentials():
+    model = LocalAISpeech2text()
+
+    with pytest.raises(CredentialsValidateFailedError):
+        model.validate_credentials(
+            model='whisper-1',
+            credentials={
+                'server_url': 'invalid_url'
+            }
+        )
+
+    model.validate_credentials(
+        model='whisper-1',
+        credentials={
+            'server_url': os.environ.get('LOCALAI_SERVER_URL')
+        }
+    )
+
+
+def test_invoke_model():
+    model = LocalAISpeech2text()
+
+    # Get the directory of the current file
+    current_dir = os.path.dirname(os.path.abspath(__file__))
+
+    # Get assets directory
+    assets_dir = os.path.join(os.path.dirname(current_dir), 'assets')
+
+    # Construct the path to the audio file
+    audio_file_path = os.path.join(assets_dir, 'audio.mp3')
+
+    # Open the file and get the file object
+    with open(audio_file_path, 'rb') as audio_file:
+        file = audio_file
+
+        result = model.invoke(
+            model='whisper-1',
+            credentials={
+                'server_url': os.environ.get('LOCALAI_SERVER_URL')
+            },
file=file, + user="abc-123" + ) + + assert isinstance(result, str) + assert result == '1, 2, 3, 4, 5, 6, 7, 8, 9, 10' \ No newline at end of file diff --git a/api/tests/integration_tests/model_runtime/volcengine_maas/__init__.py b/api/tests/integration_tests/model_runtime/volcengine_maas/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/integration_tests/model_runtime/volcengine_maas/test_embedding.py b/api/tests/integration_tests/model_runtime/volcengine_maas/test_embedding.py new file mode 100644 index 0000000000..3b399d604e --- /dev/null +++ b/api/tests/integration_tests/model_runtime/volcengine_maas/test_embedding.py @@ -0,0 +1,85 @@ +import os + +import pytest + +from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.volcengine_maas.text_embedding.text_embedding import ( + VolcengineMaaSTextEmbeddingModel, +) + + +def test_validate_credentials(): + model = VolcengineMaaSTextEmbeddingModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model='NOT IMPORTANT', + credentials={ + 'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com', + 'volc_region': 'cn-beijing', + 'volc_access_key_id': 'INVALID', + 'volc_secret_access_key': 'INVALID', + 'endpoint_id': 'INVALID', + 'base_model_name': 'Doubao-embedding', + } + ) + + model.validate_credentials( + model='NOT IMPORTANT', + credentials={ + 'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com', + 'volc_region': 'cn-beijing', + 'volc_access_key_id': os.environ.get('VOLC_API_KEY'), + 'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'), + 'endpoint_id': os.environ.get('VOLC_EMBEDDING_ENDPOINT_ID'), + 'base_model_name': 'Doubao-embedding', + }, + ) + + +def test_invoke_model(): + model = VolcengineMaaSTextEmbeddingModel() + + result = model.invoke( + model='NOT IMPORTANT', + credentials={ + 'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com', + 'volc_region': 'cn-beijing', + 'volc_access_key_id': os.environ.get('VOLC_API_KEY'), + 'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'), + 'endpoint_id': os.environ.get('VOLC_EMBEDDING_ENDPOINT_ID'), + 'base_model_name': 'Doubao-embedding', + }, + texts=[ + "hello", + "world" + ], + user="abc-123" + ) + + assert isinstance(result, TextEmbeddingResult) + assert len(result.embeddings) == 2 + assert result.usage.total_tokens > 0 + + +def test_get_num_tokens(): + model = VolcengineMaaSTextEmbeddingModel() + + num_tokens = model.get_num_tokens( + model='NOT IMPORTANT', + credentials={ + 'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com', + 'volc_region': 'cn-beijing', + 'volc_access_key_id': os.environ.get('VOLC_API_KEY'), + 'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'), + 'endpoint_id': os.environ.get('VOLC_EMBEDDING_ENDPOINT_ID'), + 'base_model_name': 'Doubao-embedding', + }, + texts=[ + "hello", + "world" + ] + ) + + assert num_tokens == 2 diff --git a/api/tests/integration_tests/model_runtime/volcengine_maas/test_llm.py b/api/tests/integration_tests/model_runtime/volcengine_maas/test_llm.py new file mode 100644 index 0000000000..63835d0263 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/volcengine_maas/test_llm.py @@ -0,0 +1,131 @@ +import os +from collections.abc import Generator + +import pytest + +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, 
LLMResultChunkDelta +from core.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.volcengine_maas.llm.llm import VolcengineMaaSLargeLanguageModel + + +def test_validate_credentials_for_chat_model(): + model = VolcengineMaaSLargeLanguageModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model='NOT IMPORTANT', + credentials={ + 'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com', + 'volc_region': 'cn-beijing', + 'volc_access_key_id': 'INVALID', + 'volc_secret_access_key': 'INVALID', + 'endpoint_id': 'INVALID', + } + ) + + model.validate_credentials( + model='NOT IMPORTANT', + credentials={ + 'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com', + 'volc_region': 'cn-beijing', + 'volc_access_key_id': os.environ.get('VOLC_API_KEY'), + 'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'), + 'endpoint_id': os.environ.get('VOLC_MODEL_ENDPOINT_ID'), + } + ) + + +def test_invoke_model(): + model = VolcengineMaaSLargeLanguageModel() + + response = model.invoke( + model='NOT IMPORTANT', + credentials={ + 'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com', + 'volc_region': 'cn-beijing', + 'volc_access_key_id': os.environ.get('VOLC_API_KEY'), + 'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'), + 'endpoint_id': os.environ.get('VOLC_MODEL_ENDPOINT_ID'), + 'base_model_name': 'Skylark2-pro-4k', + }, + prompt_messages=[ + UserPromptMessage( + content='Hello World!' + ) + ], + model_parameters={ + 'temperature': 0.7, + 'top_p': 1.0, + 'top_k': 1, + }, + stop=['you'], + user="abc-123", + stream=False + ) + + assert isinstance(response, LLMResult) + assert len(response.message.content) > 0 + assert response.usage.total_tokens > 0 + + +def test_invoke_stream_model(): + model = VolcengineMaaSLargeLanguageModel() + + response = model.invoke( + model='NOT IMPORTANT', + credentials={ + 'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com', + 'volc_region': 'cn-beijing', + 'volc_access_key_id': os.environ.get('VOLC_API_KEY'), + 'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'), + 'endpoint_id': os.environ.get('VOLC_MODEL_ENDPOINT_ID'), + 'base_model_name': 'Skylark2-pro-4k', + }, + prompt_messages=[ + UserPromptMessage( + content='Hello World!' + ) + ], + model_parameters={ + 'temperature': 0.7, + 'top_p': 1.0, + 'top_k': 1, + }, + stop=['you'], + stream=True, + user="abc-123" + ) + + assert isinstance(response, Generator) + for chunk in response: + assert isinstance(chunk, LLMResultChunk) + assert isinstance(chunk.delta, LLMResultChunkDelta) + assert isinstance(chunk.delta.message, AssistantPromptMessage) + assert len( + chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + + +def test_get_num_tokens(): + model = VolcengineMaaSLargeLanguageModel() + + response = model.get_num_tokens( + model='NOT IMPORTANT', + credentials={ + 'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com', + 'volc_region': 'cn-beijing', + 'volc_access_key_id': os.environ.get('VOLC_API_KEY'), + 'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'), + 'endpoint_id': os.environ.get('VOLC_MODEL_ENDPOINT_ID'), + 'base_model_name': 'Skylark2-pro-4k', + }, + prompt_messages=[ + UserPromptMessage( + content='Hello World!' 
+ ) + ], + tools=[] + ) + + assert isinstance(response, int) + assert response == 6 diff --git a/api/tests/integration_tests/utils/test_module_import_helper.py b/api/tests/integration_tests/utils/test_module_import_helper.py index 39ac41b648..256c9a911f 100644 --- a/api/tests/integration_tests/utils/test_module_import_helper.py +++ b/api/tests/integration_tests/utils/test_module_import_helper.py @@ -1,6 +1,6 @@ import os -from core.utils.module_import_helper import import_module_from_source, load_single_subclass_from_source +from core.helper.module_import_helper import import_module_from_source, load_single_subclass_from_source from tests.integration_tests.utils.parent_class import ParentClass diff --git a/api/tests/integration_tests/vdb/pgvecto_rs/test_pgvecto_rs.py b/api/tests/integration_tests/vdb/pgvecto_rs/test_pgvecto_rs.py index a2b9669b90..89a40a92be 100644 --- a/api/tests/integration_tests/vdb/pgvecto_rs/test_pgvecto_rs.py +++ b/api/tests/integration_tests/vdb/pgvecto_rs/test_pgvecto_rs.py @@ -6,7 +6,7 @@ from tests.integration_tests.vdb.test_vector_store import ( ) -class TestPgvectoRSVector(AbstractVectorTest): +class PGVectoRSVectorTest(AbstractVectorTest): def __init__(self): super().__init__() self.vector = PGVectoRS( @@ -34,4 +34,4 @@ class TestPgvectoRSVector(AbstractVectorTest): assert len(ids) == 1 def test_pgvecot_rs(setup_mock_redis): - TestPgvectoRSVector().run_all_tests() + PGVectoRSVectorTest().run_all_tests() diff --git a/api/tests/integration_tests/vdb/pgvector/__init__.py b/api/tests/integration_tests/vdb/pgvector/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/integration_tests/vdb/pgvector/test_pgvector.py b/api/tests/integration_tests/vdb/pgvector/test_pgvector.py new file mode 100644 index 0000000000..851599c7ce --- /dev/null +++ b/api/tests/integration_tests/vdb/pgvector/test_pgvector.py @@ -0,0 +1,30 @@ +from core.rag.datasource.vdb.pgvector.pgvector import PGVector, PGVectorConfig +from core.rag.models.document import Document +from tests.integration_tests.vdb.test_vector_store import ( + AbstractVectorTest, + get_example_text, + setup_mock_redis, +) + + +class PGVectorTest(AbstractVectorTest): + def __init__(self): + super().__init__() + self.vector = PGVector( + collection_name=self.collection_name, + config=PGVectorConfig( + host="localhost", + port=5433, + user="postgres", + password="difyai123456", + database="dify", + ), + ) + + def search_by_full_text(self): + hits_by_full_text: list[Document] = self.vector.search_by_full_text(query=get_example_text()) + assert len(hits_by_full_text) == 0 + + +def test_pgvector(setup_mock_redis): + PGVectorTest().run_all_tests() diff --git a/api/tests/integration_tests/vdb/tidb_vector/__init__.py b/api/tests/integration_tests/vdb/tidb_vector/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/integration_tests/vdb/tidb_vector/test_tidb_vector.py b/api/tests/integration_tests/vdb/tidb_vector/test_tidb_vector.py new file mode 100644 index 0000000000..837a228a55 --- /dev/null +++ b/api/tests/integration_tests/vdb/tidb_vector/test_tidb_vector.py @@ -0,0 +1,63 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVector, TiDBVectorConfig +from models.dataset import Document +from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text, setup_mock_redis + + +@pytest.fixture +def tidb_vector(): + return TiDBVector( + 
collection_name='test_collection', + config=TiDBVectorConfig( + host="xxx.eu-central-1.xxx.aws.tidbcloud.com", + port="4000", + user="xxx.root", + password="xxxxxx", + database="dify" + ) + ) + + +class TiDBVectorTest(AbstractVectorTest): + def __init__(self, vector): + super().__init__() + self.vector = vector + + def text_exists(self): + exist = self.vector.text_exists(self.example_doc_id) + assert exist == False + + def search_by_vector(self): + hits_by_vector: list[Document] = self.vector.search_by_vector(query_vector=self.example_embedding) + assert len(hits_by_vector) == 0 + + def search_by_full_text(self): + hits_by_full_text: list[Document] = self.vector.search_by_full_text(query=get_example_text()) + assert len(hits_by_full_text) == 0 + + def get_ids_by_metadata_field(self): + ids = self.vector.get_ids_by_metadata_field(key='document_id', value=self.example_doc_id) + assert len(ids) == 0 + + def delete_by_document_id(self): + self.vector.delete_by_document_id(document_id=self.example_doc_id) + + +def test_tidb_vector(setup_mock_redis, setup_tidbvector_mock, tidb_vector, mock_session): + TiDBVectorTest(vector=tidb_vector).run_all_tests() + + +@pytest.fixture +def mock_session(): + with patch('core.rag.datasource.vdb.tidb_vector.tidb_vector.Session', new_callable=MagicMock) as mock_session: + yield mock_session + + +@pytest.fixture +def setup_tidbvector_mock(tidb_vector, mock_session): + with patch('core.rag.datasource.vdb.tidb_vector.tidb_vector.create_engine'): + with patch.object(tidb_vector._engine, 'connect'): + yield tidb_vector diff --git a/api/tests/integration_tests/workflow/nodes/__mock/code_executor.py b/api/tests/integration_tests/workflow/nodes/__mock/code_executor.py index 38517cf448..13f992136e 100644 --- a/api/tests/integration_tests/workflow/nodes/__mock/code_executor.py +++ b/api/tests/integration_tests/workflow/nodes/__mock/code_executor.py @@ -1,25 +1,29 @@ import os -from typing import Literal +from typing import Literal, Optional import pytest from _pytest.monkeypatch import MonkeyPatch +from jinja2 import Template -from core.helper.code_executor.code_executor import CodeExecutor +from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage +from core.helper.code_executor.entities import CodeDependency MOCK = os.getenv('MOCK_SWITCH', 'false') == 'true' class MockedCodeExecutor: @classmethod - def invoke(cls, language: Literal['python3', 'javascript', 'jinja2'], code: str, inputs: dict) -> dict: + def invoke(cls, language: Literal['python3', 'javascript', 'jinja2'], + code: str, inputs: dict, dependencies: Optional[list[CodeDependency]] = None) -> dict: # invoke directly - if language == 'python3': - return { - "result": 3 - } - elif language == 'jinja2': - return { - "result": "3" - } + match language: + case CodeLanguage.PYTHON3: + return { + "result": 3 + } + case CodeLanguage.JINJA2: + return { + "result": Template(code).render(inputs) + } @pytest.fixture def setup_code_executor_mock(request, monkeypatch: MonkeyPatch): diff --git a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_executor.py b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_executor.py new file mode 100644 index 0000000000..ae6e7ceaa7 --- /dev/null +++ b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_executor.py @@ -0,0 +1,11 @@ +import pytest + +from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor + +CODE_LANGUAGE = 'unsupported_language' + + +def 
test_unsupported_with_code_template(): + with pytest.raises(CodeExecutionException) as e: + CodeExecutor.execute_workflow_code_template(language=CODE_LANGUAGE, code='', inputs={}) + assert str(e.value) == f'Unsupported language {CODE_LANGUAGE}' diff --git a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_javascript.py b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_javascript.py index c794ae8e4b..2d798eb9c2 100644 --- a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_javascript.py +++ b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_javascript.py @@ -1,6 +1,10 @@ -from core.helper.code_executor.code_executor import CodeExecutor +from textwrap import dedent -CODE_LANGUAGE = 'javascript' +from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage +from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider +from core.helper.code_executor.javascript.javascript_transformer import NodeJsTemplateTransformer + +CODE_LANGUAGE = CodeLanguage.JAVASCRIPT def test_javascript_plain(): @@ -10,9 +14,30 @@ def test_javascript_plain(): def test_javascript_json(): - code = """ -obj = {'Hello': 'World'} -console.log(JSON.stringify(obj)) - """ + code = dedent(""" + obj = {'Hello': 'World'} + console.log(JSON.stringify(obj)) + """) result = CodeExecutor.execute_code(language=CODE_LANGUAGE, preload='', code=code) assert result == '{"Hello":"World"}\n' + + +def test_javascript_with_code_template(): + result = CodeExecutor.execute_workflow_code_template( + language=CODE_LANGUAGE, code=JavascriptCodeProvider.get_default_code(), + inputs={'arg1': 'Hello', 'arg2': 'World'}) + assert result == {'result': 'HelloWorld'} + + +def test_javascript_list_default_available_packages(): + packages = JavascriptCodeProvider.get_default_available_packages() + + # no default packages available for javascript + assert len(packages) == 0 + + +def test_javascript_get_runner_script(): + runner_script = NodeJsTemplateTransformer.get_runner_script() + assert runner_script.count(NodeJsTemplateTransformer._code_placeholder) == 1 + assert runner_script.count(NodeJsTemplateTransformer._inputs_placeholder) == 1 + assert runner_script.count(NodeJsTemplateTransformer._result_tag) == 2 diff --git a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_jina2.py b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_jina2.py deleted file mode 100644 index aae3c7acec..0000000000 --- a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_jina2.py +++ /dev/null @@ -1,14 +0,0 @@ -import base64 - -from core.helper.code_executor.code_executor import CodeExecutor -from core.helper.code_executor.jinja2_transformer import JINJA2_PRELOAD, PYTHON_RUNNER - -CODE_LANGUAGE = 'jinja2' - - -def test_jinja2(): - template = 'Hello {{template}}' - inputs = base64.b64encode(b'{"template": "World"}').decode('utf-8') - code = PYTHON_RUNNER.replace('{{code}}', template).replace('{{inputs}}', inputs) - result = CodeExecutor.execute_code(language=CODE_LANGUAGE, preload=JINJA2_PRELOAD, code=code) - assert result == '<>Hello World<>\n' diff --git a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_jinja2.py b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_jinja2.py new file mode 100644 index 0000000000..425f4cbdd4 --- /dev/null +++ b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_jinja2.py @@ -0,0 +1,31 @@ +import base64 + +from 
core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage +from core.helper.code_executor.jinja2.jinja2_transformer import Jinja2TemplateTransformer + +CODE_LANGUAGE = CodeLanguage.JINJA2 + + +def test_jinja2(): + template = 'Hello {{template}}' + inputs = base64.b64encode(b'{"template": "World"}').decode('utf-8') + code = (Jinja2TemplateTransformer.get_runner_script() + .replace(Jinja2TemplateTransformer._code_placeholder, template) + .replace(Jinja2TemplateTransformer._inputs_placeholder, inputs)) + result = CodeExecutor.execute_code(language=CODE_LANGUAGE, + preload=Jinja2TemplateTransformer.get_preload_script(), + code=code) + assert result == '<>Hello World<>\n' + + +def test_jinja2_with_code_template(): + result = CodeExecutor.execute_workflow_code_template( + language=CODE_LANGUAGE, code='Hello {{template}}', inputs={'template': 'World'}) + assert result == {'result': 'Hello World'} + + +def test_jinja2_get_runner_script(): + runner_script = Jinja2TemplateTransformer.get_runner_script() + assert runner_script.count(Jinja2TemplateTransformer._code_placeholder) == 1 + assert runner_script.count(Jinja2TemplateTransformer._inputs_placeholder) == 1 + assert runner_script.count(Jinja2TemplateTransformer._result_tag) == 2 diff --git a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_python3.py b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_python3.py index 1983bc5e6b..d265011d4c 100644 --- a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_python3.py +++ b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_python3.py @@ -1,6 +1,11 @@ -from core.helper.code_executor.code_executor import CodeExecutor +import json +from textwrap import dedent -CODE_LANGUAGE = 'python3' +from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage +from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider +from core.helper.code_executor.python3.python3_transformer import Python3TemplateTransformer + +CODE_LANGUAGE = CodeLanguage.PYTHON3 def test_python3_plain(): @@ -10,9 +15,31 @@ def test_python3_plain(): def test_python3_json(): - code = """ -import json -print(json.dumps({'Hello': 'World'})) - """ + code = dedent(""" + import json + print(json.dumps({'Hello': 'World'})) + """) result = CodeExecutor.execute_code(language=CODE_LANGUAGE, preload='', code=code) assert result == '{"Hello": "World"}\n' + + +def test_python3_with_code_template(): + result = CodeExecutor.execute_workflow_code_template( + language=CODE_LANGUAGE, code=Python3CodeProvider.get_default_code(), inputs={'arg1': 'Hello', 'arg2': 'World'}) + assert result == {'result': 'HelloWorld'} + + +def test_python3_list_default_available_packages(): + packages = Python3CodeProvider.get_default_available_packages() + assert len(packages) > 0 + assert {'requests', 'httpx'}.issubset(p['name'] for p in packages) + + # check JSON serializable + assert len(str(json.dumps(packages))) > 0 + + +def test_python3_get_runner_script(): + runner_script = Python3TemplateTransformer.get_runner_script() + assert runner_script.count(Python3TemplateTransformer._code_placeholder) == 1 + assert runner_script.count(Python3TemplateTransformer._inputs_placeholder) == 1 + assert runner_script.count(Python3TemplateTransformer._result_tag) == 2 diff --git a/api/tests/integration_tests/workflow/nodes/test_code.py b/api/tests/integration_tests/workflow/nodes/test_code.py index c41d51caf7..15cf5367d3 100644 --- 
a/api/tests/integration_tests/workflow/nodes/test_code.py +++ b/api/tests/integration_tests/workflow/nodes/test_code.py @@ -4,6 +4,7 @@ import pytest from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.entities.variable_pool import VariablePool +from core.workflow.nodes.base_node import UserFrom from core.workflow.nodes.code.code_node import CodeNode from models.workflow import WorkflowNodeExecutionStatus from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock @@ -25,7 +26,8 @@ def test_execute_code(setup_code_executor_mock): app_id='1', workflow_id='1', user_id='1', - user_from=InvokeFrom.WEB_APP, + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.WEB_APP, config={ 'id': '1', 'data': { @@ -78,7 +80,8 @@ def test_execute_code_output_validator(setup_code_executor_mock): app_id='1', workflow_id='1', user_id='1', - user_from=InvokeFrom.WEB_APP, + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.WEB_APP, config={ 'id': '1', 'data': { @@ -132,7 +135,8 @@ def test_execute_code_output_validator_depth(): app_id='1', workflow_id='1', user_id='1', - user_from=InvokeFrom.WEB_APP, + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.WEB_APP, config={ 'id': '1', 'data': { @@ -285,7 +289,8 @@ def test_execute_code_output_object_list(): app_id='1', workflow_id='1', user_id='1', - user_from=InvokeFrom.WEB_APP, + invoke_from=InvokeFrom.WEB_APP, + user_from=UserFrom.ACCOUNT, config={ 'id': '1', 'data': { diff --git a/api/tests/integration_tests/workflow/nodes/test_http.py b/api/tests/integration_tests/workflow/nodes/test_http.py index 63b6b7d962..ffa2741e55 100644 --- a/api/tests/integration_tests/workflow/nodes/test_http.py +++ b/api/tests/integration_tests/workflow/nodes/test_http.py @@ -2,6 +2,7 @@ import pytest from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.entities.variable_pool import VariablePool +from core.workflow.nodes.base_node import UserFrom from core.workflow.nodes.http_request.http_request_node import HttpRequestNode from tests.integration_tests.workflow.nodes.__mock.http import setup_http_mock @@ -10,7 +11,8 @@ BASIC_NODE_DATA = { 'app_id': '1', 'workflow_id': '1', 'user_id': '1', - 'user_from': InvokeFrom.WEB_APP, + 'user_from': UserFrom.ACCOUNT, + 'invoke_from': InvokeFrom.WEB_APP, } # construct variable pool @@ -38,6 +40,7 @@ def test_get(setup_http_mock): 'headers': 'X-Header:123', 'params': 'A:b', 'body': None, + 'mask_authorization_header': False, } }, **BASIC_NODE_DATA) @@ -95,6 +98,7 @@ def test_custom_authorization_header(setup_http_mock): 'headers': 'X-Header:123', 'params': 'A:b', 'body': None, + 'mask_authorization_header': False, } }, **BASIC_NODE_DATA) @@ -126,6 +130,7 @@ def test_template(setup_http_mock): 'headers': 'X-Header:123\nX-Header2:{{#a.b123.args2#}}', 'params': 'A:b\nTemplate:{{#a.b123.args2#}}', 'body': None, + 'mask_authorization_header': False, } }, **BASIC_NODE_DATA) @@ -161,6 +166,7 @@ def test_json(setup_http_mock): 'type': 'json', 'data': '{"a": "{{#a.b123.args1#}}"}' }, + 'mask_authorization_header': False, } }, **BASIC_NODE_DATA) @@ -193,6 +199,7 @@ def test_x_www_form_urlencoded(setup_http_mock): 'type': 'x-www-form-urlencoded', 'data': 'a:{{#a.b123.args1#}}\nb:{{#a.b123.args2#}}' }, + 'mask_authorization_header': False, } }, **BASIC_NODE_DATA) @@ -225,6 +232,7 @@ def test_form_data(setup_http_mock): 'type': 'form-data', 'data': 'a:{{#a.b123.args1#}}\nb:{{#a.b123.args2#}}' }, + 'mask_authorization_header': False, } }, **BASIC_NODE_DATA) @@ -260,6 
+268,7 @@ def test_none_data(setup_http_mock): 'type': 'none', 'data': '123123123' }, + 'mask_authorization_header': False, } }, **BASIC_NODE_DATA) diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py index 8a8a58d59f..a150be3c00 100644 --- a/api/tests/integration_tests/workflow/nodes/test_llm.py +++ b/api/tests/integration_tests/workflow/nodes/test_llm.py @@ -1,9 +1,10 @@ +import json import os from unittest.mock import MagicMock import pytest -from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity +from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle from core.entities.provider_entities import CustomConfiguration, CustomProviderConfiguration, SystemConfiguration from core.model_manager import ModelInstance @@ -19,6 +20,7 @@ from models.workflow import WorkflowNodeExecutionStatus """FOR MOCK FIXTURES, DO NOT REMOVE""" from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock +from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock @pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) @@ -28,6 +30,7 @@ def test_execute_llm(setup_openai_mock): app_id='1', workflow_id='1', user_id='1', + invoke_from=InvokeFrom.WEB_APP, user_from=UserFrom.ACCOUNT, config={ 'id': 'llm', @@ -89,7 +92,8 @@ def test_execute_llm(setup_openai_mock): provider=CustomProviderConfiguration( credentials=credentials ) - ) + ), + model_settings=[] ), provider_instance=provider_instance, model_type_instance=model_type_instance @@ -116,3 +120,120 @@ def test_execute_llm(setup_openai_mock): assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs['text'] is not None assert result.outputs['usage']['total_tokens'] > 0 + +@pytest.mark.parametrize('setup_code_executor_mock', [['none']], indirect=True) +@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) +def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_openai_mock): + """ + Test execute LLM node with jinja2 + """ + node = LLMNode( + tenant_id='1', + app_id='1', + workflow_id='1', + user_id='1', + invoke_from=InvokeFrom.WEB_APP, + user_from=UserFrom.ACCOUNT, + config={ + 'id': 'llm', + 'data': { + 'title': '123', + 'type': 'llm', + 'model': { + 'provider': 'openai', + 'name': 'gpt-3.5-turbo', + 'mode': 'chat', + 'completion_params': {} + }, + 'prompt_config': { + 'jinja2_variables': [{ + 'variable': 'sys_query', + 'value_selector': ['sys', 'query'] + }, { + 'variable': 'output', + 'value_selector': ['abc', 'output'] + }] + }, + 'prompt_template': [ + { + 'role': 'system', + 'text': 'you are a helpful assistant.\ntoday\'s weather is {{#abc.output#}}', + 'jinja2_text': 'you are a helpful assistant.\ntoday\'s weather is {{output}}.', + 'edition_type': 'jinja2' + }, + { + 'role': 'user', + 'text': '{{#sys.query#}}', + 'jinja2_text': '{{sys_query}}', + 'edition_type': 'basic' + } + ], + 'memory': None, + 'context': { + 'enabled': False + }, + 'vision': { + 'enabled': False + } + } + } + ) + + # construct variable pool + pool = VariablePool(system_variables={ + SystemVariable.QUERY: 'what\'s the weather today?', + SystemVariable.FILES: [], + SystemVariable.CONVERSATION_ID: 'abababa', + SystemVariable.USER_ID: 'aaa' + }, user_inputs={}) + pool.append_variable(node_id='abc', variable_key_list=['output'], 
value='sunny') + + credentials = { + 'openai_api_key': os.environ.get('OPENAI_API_KEY') + } + + provider_instance = ModelProviderFactory().get_provider_instance('openai') + model_type_instance = provider_instance.get_model_instance(ModelType.LLM) + provider_model_bundle = ProviderModelBundle( + configuration=ProviderConfiguration( + tenant_id='1', + provider=provider_instance.get_provider_schema(), + preferred_provider_type=ProviderType.CUSTOM, + using_provider_type=ProviderType.CUSTOM, + system_configuration=SystemConfiguration( + enabled=False + ), + custom_configuration=CustomConfiguration( + provider=CustomProviderConfiguration( + credentials=credentials + ) + ), + model_settings=[] + ), + provider_instance=provider_instance, + model_type_instance=model_type_instance, + ) + + model_instance = ModelInstance(provider_model_bundle=provider_model_bundle, model='gpt-3.5-turbo') + + model_config = ModelConfigWithCredentialsEntity( + model='gpt-3.5-turbo', + provider='openai', + mode='chat', + credentials=credentials, + parameters={}, + model_schema=model_type_instance.get_model_schema('gpt-3.5-turbo'), + provider_model_bundle=provider_model_bundle + ) + + # Mock db.session.close() + db.session.close = MagicMock() + + node._fetch_model_config = MagicMock(return_value=tuple([model_instance, model_config])) + + # execute node + result = node.run(pool) + + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert 'sunny' in json.dumps(result.process_data) + assert 'what\'s the weather today?' in json.dumps(result.process_data) \ No newline at end of file diff --git a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py new file mode 100644 index 0000000000..056c78441d --- /dev/null +++ b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py @@ -0,0 +1,357 @@ +import json +import os +from unittest.mock import MagicMock + +import pytest + +from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity +from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle +from core.entities.provider_entities import CustomConfiguration, CustomProviderConfiguration, SystemConfiguration +from core.model_manager import ModelInstance +from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory +from core.workflow.entities.node_entities import SystemVariable +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.nodes.base_node import UserFrom +from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode +from extensions.ext_database import db +from models.provider import ProviderType + +"""FOR MOCK FIXTURES, DO NOT REMOVE""" +from models.workflow import WorkflowNodeExecutionStatus +from tests.integration_tests.model_runtime.__mock.anthropic import setup_anthropic_mock +from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock + + +def get_mocked_fetch_model_config( + provider: str, model: str, mode: str, + credentials: dict, +): + provider_instance = ModelProviderFactory().get_provider_instance(provider) + model_type_instance = provider_instance.get_model_instance(ModelType.LLM) + provider_model_bundle = ProviderModelBundle( + configuration=ProviderConfiguration( + tenant_id='1', + provider=provider_instance.get_provider_schema(), + 
preferred_provider_type=ProviderType.CUSTOM,
+            using_provider_type=ProviderType.CUSTOM,
+            system_configuration=SystemConfiguration(
+                enabled=False
+            ),
+            custom_configuration=CustomConfiguration(
+                provider=CustomProviderConfiguration(
+                    credentials=credentials
+                )
+            ),
+            model_settings=[]
+        ),
+        provider_instance=provider_instance,
+        model_type_instance=model_type_instance
+    )
+    model_instance = ModelInstance(provider_model_bundle=provider_model_bundle, model=model)
+    model_config = ModelConfigWithCredentialsEntity(
+        model=model,
+        provider=provider,
+        mode=mode,
+        credentials=credentials,
+        parameters={},
+        model_schema=model_type_instance.get_model_schema(model),
+        provider_model_bundle=provider_model_bundle
+    )
+
+    return MagicMock(return_value=tuple([model_instance, model_config]))
+
+@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
+def test_function_calling_parameter_extractor(setup_openai_mock):
+    """
+    Test function calling for parameter extractor.
+    """
+    node = ParameterExtractorNode(
+        tenant_id='1',
+        app_id='1',
+        workflow_id='1',
+        user_id='1',
+        invoke_from=InvokeFrom.WEB_APP,
+        user_from=UserFrom.ACCOUNT,
+        config={
+            'id': 'llm',
+            'data': {
+                'title': '123',
+                'type': 'parameter-extractor',
+                'model': {
+                    'provider': 'openai',
+                    'name': 'gpt-3.5-turbo',
+                    'mode': 'chat',
+                    'completion_params': {}
+                },
+                'query': ['sys', 'query'],
+                'parameters': [{
+                    'name': 'location',
+                    'type': 'string',
+                    'description': 'location',
+                    'required': True
+                }],
+                'instruction': '',
+                'reasoning_mode': 'function_call',
+                'memory': None,
+            }
+        }
+    )
+
+    node._fetch_model_config = get_mocked_fetch_model_config(
+        provider='openai', model='gpt-3.5-turbo', mode='chat', credentials={
+            'openai_api_key': os.environ.get('OPENAI_API_KEY')
+        }
+    )
+    db.session.close = MagicMock()
+
+    # construct variable pool
+    pool = VariablePool(system_variables={
+        SystemVariable.QUERY: 'what\'s the weather in SF',
+        SystemVariable.FILES: [],
+        SystemVariable.CONVERSATION_ID: 'abababa',
+        SystemVariable.USER_ID: 'aaa'
+    }, user_inputs={})
+
+    result = node.run(pool)
+
+    assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
+    assert result.outputs.get('location') == 'kawaii'
+    assert result.outputs.get('__reason') is None
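The tests in this file all stub `ParameterExtractorNode._fetch_model_config` with a `MagicMock` returning a `(model_instance, model_config)` tuple, so no real provider credentials are exercised by the node itself. A minimal, self-contained sketch of that stubbing pattern; the `Node` class here is a hypothetical stand-in, not Dify's node:

```python
# Sketch: replace an expensive instance method with a MagicMock that
# returns the tuple the caller expects to unpack.
from unittest.mock import MagicMock

class Node:
    def _fetch_model_config(self):
        raise RuntimeError('would hit a real model provider')

    def run(self):
        model_instance, model_config = self._fetch_model_config()
        return model_config['mode']

node = Node()
node._fetch_model_config = MagicMock(return_value=('fake-instance', {'mode': 'chat'}))

assert node.run() == 'chat'          # the stubbed tuple flows through run()
assert node._fetch_model_config.called
```

Assigning the mock on the instance (rather than the class) keeps the patch scoped to the one node under test, which is the same choice the tests above make.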
+ """ + node = ParameterExtractorNode( + tenant_id='1', + app_id='1', + workflow_id='1', + user_id='1', + invoke_from=InvokeFrom.WEB_APP, + user_from=UserFrom.ACCOUNT, + config={ + 'id': 'llm', + 'data': { + 'title': '123', + 'type': 'parameter-extractor', + 'model': { + 'provider': 'openai', + 'name': 'gpt-3.5-turbo', + 'mode': 'chat', + 'completion_params': {} + }, + 'query': ['sys', 'query'], + 'parameters': [{ + 'name': 'location', + 'type': 'string', + 'description': 'location', + 'required': True + }], + 'reasoning_mode': 'function_call', + 'instruction': '{{#sys.query#}}', + 'memory': None, + } + } + ) + + node._fetch_model_config = get_mocked_fetch_model_config( + provider='openai', model='gpt-3.5-turbo', mode='chat', credentials={ + 'openai_api_key': os.environ.get('OPENAI_API_KEY') + } + ) + db.session.close = MagicMock() + + # construct variable pool + pool = VariablePool(system_variables={ + SystemVariable.QUERY: 'what\'s the weather in SF', + SystemVariable.FILES: [], + SystemVariable.CONVERSATION_ID: 'abababa', + SystemVariable.USER_ID: 'aaa' + }, user_inputs={}) + + result = node.run(pool) + + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs.get('location') == 'kawaii' + assert result.outputs.get('__reason') == None + + process_data = result.process_data + + process_data.get('prompts') + + for prompt in process_data.get('prompts'): + if prompt.get('role') == 'system': + assert 'what\'s the weather in SF' in prompt.get('text') + +@pytest.mark.parametrize('setup_anthropic_mock', [['none']], indirect=True) +def test_chat_parameter_extractor(setup_anthropic_mock): + """ + Test chat parameter extractor. + """ + node = ParameterExtractorNode( + tenant_id='1', + app_id='1', + workflow_id='1', + user_id='1', + invoke_from=InvokeFrom.WEB_APP, + user_from=UserFrom.ACCOUNT, + config={ + 'id': 'llm', + 'data': { + 'title': '123', + 'type': 'parameter-extractor', + 'model': { + 'provider': 'anthropic', + 'name': 'claude-2', + 'mode': 'chat', + 'completion_params': {} + }, + 'query': ['sys', 'query'], + 'parameters': [{ + 'name': 'location', + 'type': 'string', + 'description': 'location', + 'required': True + }], + 'reasoning_mode': 'prompt', + 'instruction': '', + 'memory': None, + } + } + ) + + node._fetch_model_config = get_mocked_fetch_model_config( + provider='anthropic', model='claude-2', mode='chat', credentials={ + 'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY') + } + ) + db.session.close = MagicMock() + + # construct variable pool + pool = VariablePool(system_variables={ + SystemVariable.QUERY: 'what\'s the weather in SF', + SystemVariable.FILES: [], + SystemVariable.CONVERSATION_ID: 'abababa', + SystemVariable.USER_ID: 'aaa' + }, user_inputs={}) + + result = node.run(pool) + + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs.get('location') == '' + assert result.outputs.get('__reason') == 'Failed to extract result from function call or text response, using empty result.' + prompts = result.process_data.get('prompts') + + for prompt in prompts: + if prompt.get('role') == 'user': + if '' in prompt.get('text'): + assert '\n{"type": "object"' in prompt.get('text') + +@pytest.mark.parametrize('setup_openai_mock', [['completion']], indirect=True) +def test_completion_parameter_extractor(setup_openai_mock): + """ + Test completion parameter extractor. 
+ """ + node = ParameterExtractorNode( + tenant_id='1', + app_id='1', + workflow_id='1', + user_id='1', + invoke_from=InvokeFrom.WEB_APP, + user_from=UserFrom.ACCOUNT, + config={ + 'id': 'llm', + 'data': { + 'title': '123', + 'type': 'parameter-extractor', + 'model': { + 'provider': 'openai', + 'name': 'gpt-3.5-turbo-instruct', + 'mode': 'completion', + 'completion_params': {} + }, + 'query': ['sys', 'query'], + 'parameters': [{ + 'name': 'location', + 'type': 'string', + 'description': 'location', + 'required': True + }], + 'reasoning_mode': 'prompt', + 'instruction': '{{#sys.query#}}', + 'memory': None, + } + } + ) + + node._fetch_model_config = get_mocked_fetch_model_config( + provider='openai', model='gpt-3.5-turbo-instruct', mode='completion', credentials={ + 'openai_api_key': os.environ.get('OPENAI_API_KEY') + } + ) + db.session.close = MagicMock() + + # construct variable pool + pool = VariablePool(system_variables={ + SystemVariable.QUERY: 'what\'s the weather in SF', + SystemVariable.FILES: [], + SystemVariable.CONVERSATION_ID: 'abababa', + SystemVariable.USER_ID: 'aaa' + }, user_inputs={}) + + result = node.run(pool) + + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs.get('location') == '' + assert result.outputs.get('__reason') == 'Failed to extract result from function call or text response, using empty result.' + assert len(result.process_data.get('prompts')) == 1 + assert 'SF' in result.process_data.get('prompts')[0].get('text') + +def test_extract_json_response(): + """ + Test extract json response. + """ + + node = ParameterExtractorNode( + tenant_id='1', + app_id='1', + workflow_id='1', + user_id='1', + invoke_from=InvokeFrom.WEB_APP, + user_from=UserFrom.ACCOUNT, + config={ + 'id': 'llm', + 'data': { + 'title': '123', + 'type': 'parameter-extractor', + 'model': { + 'provider': 'openai', + 'name': 'gpt-3.5-turbo-instruct', + 'mode': 'completion', + 'completion_params': {} + }, + 'query': ['sys', 'query'], + 'parameters': [{ + 'name': 'location', + 'type': 'string', + 'description': 'location', + 'required': True + }], + 'reasoning_mode': 'prompt', + 'instruction': '{{#sys.query#}}', + 'memory': None, + } + } + ) + + result = node._extract_complete_json_response(""" + uwu{ovo} + { + "location": "kawaii" + } + hello world. 
+ """) + + assert result['location'] == 'kawaii' \ No newline at end of file diff --git a/api/tests/integration_tests/workflow/nodes/test_template_transform.py b/api/tests/integration_tests/workflow/nodes/test_template_transform.py index 4a31334056..02999bf0a2 100644 --- a/api/tests/integration_tests/workflow/nodes/test_template_transform.py +++ b/api/tests/integration_tests/workflow/nodes/test_template_transform.py @@ -1,5 +1,6 @@ import pytest +from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import UserFrom from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode @@ -15,6 +16,7 @@ def test_execute_code(setup_code_executor_mock): app_id='1', workflow_id='1', user_id='1', + invoke_from=InvokeFrom.WEB_APP, user_from=UserFrom.END_USER, config={ 'id': '1', diff --git a/api/tests/integration_tests/workflow/nodes/test_tool.py b/api/tests/integration_tests/workflow/nodes/test_tool.py index 4bbd4ccee7..fffd074457 100644 --- a/api/tests/integration_tests/workflow/nodes/test_tool.py +++ b/api/tests/integration_tests/workflow/nodes/test_tool.py @@ -1,5 +1,6 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.entities.variable_pool import VariablePool +from core.workflow.nodes.base_node import UserFrom from core.workflow.nodes.tool.tool_node import ToolNode from models.workflow import WorkflowNodeExecutionStatus @@ -13,7 +14,8 @@ def test_tool_variable_invoke(): app_id='1', workflow_id='1', user_id='1', - user_from=InvokeFrom.WEB_APP, + invoke_from=InvokeFrom.WEB_APP, + user_from=UserFrom.ACCOUNT, config={ 'id': '1', 'data': { @@ -51,7 +53,8 @@ def test_tool_mixed_invoke(): app_id='1', workflow_id='1', user_id='1', - user_from=InvokeFrom.WEB_APP, + invoke_from=InvokeFrom.WEB_APP, + user_from=UserFrom.ACCOUNT, config={ 'id': '1', 'data': { diff --git a/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py new file mode 100644 index 0000000000..9de268d762 --- /dev/null +++ b/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py @@ -0,0 +1,77 @@ +from unittest.mock import MagicMock + +from core.app.entities.app_invoke_entities import ( + ModelConfigWithCredentialsEntity, +) +from core.entities.provider_configuration import ProviderModelBundle +from core.memory.token_buffer_memory import TokenBufferMemory +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + SystemPromptMessage, + ToolPromptMessage, + UserPromptMessage, +) +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform +from models.model import Conversation + + +def test_get_prompt(): + prompt_messages = [ + SystemPromptMessage(content='System Template'), + UserPromptMessage(content='User Query'), + ] + history_messages = [ + SystemPromptMessage(content='System Prompt 1'), + UserPromptMessage(content='User Prompt 1'), + AssistantPromptMessage(content='Assistant Thought 1'), + ToolPromptMessage(content='Tool 1-1', name='Tool 1-1', tool_call_id='1'), + ToolPromptMessage(content='Tool 1-2', name='Tool 1-2', tool_call_id='2'), + SystemPromptMessage(content='System Prompt 2'), + UserPromptMessage(content='User Prompt 2'), + AssistantPromptMessage(content='Assistant Thought 2'), + ToolPromptMessage(content='Tool 
2-1', name='Tool 2-1', tool_call_id='3'), + ToolPromptMessage(content='Tool 2-2', name='Tool 2-2', tool_call_id='4'), + UserPromptMessage(content='User Prompt 3'), + AssistantPromptMessage(content='Assistant Thought 3'), + ] + + # use message number instead of token for testing + def side_effect_get_num_tokens(*args): + return len(args[2]) + large_language_model_mock = MagicMock(spec=LargeLanguageModel) + large_language_model_mock.get_num_tokens = MagicMock(side_effect=side_effect_get_num_tokens) + + provider_model_bundle_mock = MagicMock(spec=ProviderModelBundle) + provider_model_bundle_mock.model_type_instance = large_language_model_mock + + model_config_mock = MagicMock(spec=ModelConfigWithCredentialsEntity) + model_config_mock.model = 'openai' + model_config_mock.credentials = {} + model_config_mock.provider_model_bundle = provider_model_bundle_mock + + memory = TokenBufferMemory( + conversation=Conversation(), + model_instance=model_config_mock + ) + + transform = AgentHistoryPromptTransform( + model_config=model_config_mock, + prompt_messages=prompt_messages, + history_messages=history_messages, + memory=memory + ) + + max_token_limit = 5 + transform._calculate_rest_token = MagicMock(return_value=max_token_limit) + result = transform.get_prompt() + + assert len(result) <= max_token_limit + assert len(result) == 4 + + max_token_limit = 20 + transform._calculate_rest_token = MagicMock(return_value=max_token_limit) + result = transform.get_prompt() + + assert len(result) <= max_token_limit + assert len(result) == 12 diff --git a/api/tests/unit_tests/core/prompt/test_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_prompt_transform.py index 40f5be8af9..2bcc6f4292 100644 --- a/api/tests/unit_tests/core/prompt/test_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_prompt_transform.py @@ -1,9 +1,10 @@ from unittest.mock import MagicMock from core.app.app_config.entities import ModelConfigEntity -from core.entities.provider_configuration import ProviderModelBundle +from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle from core.model_runtime.entities.message_entities import UserPromptMessage from core.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey, ParameterRule +from core.model_runtime.entities.provider_entities import ProviderEntity from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.prompt.prompt_transform import PromptTransform @@ -22,8 +23,16 @@ def test__calculate_rest_token(): large_language_model_mock = MagicMock(spec=LargeLanguageModel) large_language_model_mock.get_num_tokens.return_value = 6 + provider_mock = MagicMock(spec=ProviderEntity) + provider_mock.provider = 'openai' + + provider_configuration_mock = MagicMock(spec=ProviderConfiguration) + provider_configuration_mock.provider = provider_mock + provider_configuration_mock.model_settings = None + provider_model_bundle_mock = MagicMock(spec=ProviderModelBundle) provider_model_bundle_mock.model_type_instance = large_language_model_mock + provider_model_bundle_mock.configuration = provider_configuration_mock model_config_mock = MagicMock(spec=ModelConfigEntity) model_config_mock.model = 'gpt-4' diff --git a/api/tests/unit_tests/core/rag/extractor/__init__.py b/api/tests/unit_tests/core/rag/extractor/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/rag/extractor/test_notion_extractor.py 
b/api/tests/unit_tests/core/rag/extractor/test_notion_extractor.py new file mode 100644 index 0000000000..b231fe479b --- /dev/null +++ b/api/tests/unit_tests/core/rag/extractor/test_notion_extractor.py @@ -0,0 +1,102 @@ +from unittest import mock + +from core.rag.extractor import notion_extractor + +user_id = "user1" +database_id = "database1" +page_id = "page1" + + +extractor = notion_extractor.NotionExtractor( + notion_workspace_id='x', + notion_obj_id='x', + notion_page_type='page', + tenant_id='x', + notion_access_token='x') + + +def _generate_page(page_title: str): + return { + "object": "page", + "id": page_id, + "properties": { + "Page": { + "type": "title", + "title": [ + { + "type": "text", + "text": {"content": page_title}, + "plain_text": page_title + } + ] + } + } + } + + +def _generate_block(block_id: str, block_type: str, block_text: str): + return { + "object": "block", + "id": block_id, + "parent": { + "type": "page_id", + "page_id": page_id + }, + "type": block_type, + "has_children": False, + block_type: { + "rich_text": [ + { + "type": "text", + "text": {"content": block_text}, + "plain_text": block_text, + }] + } + } + + +def _mock_response(data): + response = mock.Mock() + response.status_code = 200 + response.json.return_value = data + return response + + +def _remove_multiple_new_lines(text): + while '\n\n' in text: + text = text.replace("\n\n", "\n") + return text.strip() + + +def test_notion_page(mocker): + texts = ["Head 1", "1.1", "paragraph 1", "1.1.1"] + mocked_notion_page = { + "object": "list", + "results": [ + _generate_block("b1", "heading_1", texts[0]), + _generate_block("b2", "heading_2", texts[1]), + _generate_block("b3", "paragraph", texts[2]), + _generate_block("b4", "heading_3", texts[3]) + ], + "next_cursor": None + } + mocker.patch("requests.request", return_value=_mock_response(mocked_notion_page)) + + page_docs = extractor._load_data_as_documents(page_id, "page") + assert len(page_docs) == 1 + content = _remove_multiple_new_lines(page_docs[0].page_content) + assert content == '# Head 1\n## 1.1\nparagraph 1\n### 1.1.1' + + +def test_notion_database(mocker): + page_title_list = ["page1", "page2", "page3"] + mocked_notion_database = { + "object": "list", + "results": [_generate_page(i) for i in page_title_list], + "next_cursor": None + } + mocker.patch("requests.post", return_value=_mock_response(mocked_notion_database)) + database_docs = extractor._load_data_as_documents(database_id, "database") + assert len(database_docs) == 1 + content = _remove_multiple_new_lines(database_docs[0].page_content) + assert content == '\n'.join([f'Page:{i}' for i in page_title_list]) diff --git a/api/tests/unit_tests/core/test_model_manager.py b/api/tests/unit_tests/core/test_model_manager.py new file mode 100644 index 0000000000..3024a54a4d --- /dev/null +++ b/api/tests/unit_tests/core/test_model_manager.py @@ -0,0 +1,77 @@ +from unittest.mock import MagicMock + +import pytest + +from core.entities.provider_entities import ModelLoadBalancingConfiguration +from core.model_manager import LBModelManager +from core.model_runtime.entities.model_entities import ModelType + + +@pytest.fixture +def lb_model_manager(): + load_balancing_configs = [ + ModelLoadBalancingConfiguration( + id='id1', + name='__inherit__', + credentials={} + ), + ModelLoadBalancingConfiguration( + id='id2', + name='first', + credentials={"openai_api_key": "fake_key"} + ), + ModelLoadBalancingConfiguration( + id='id3', + name='second', + credentials={"openai_api_key": "fake_key"} + ) + ] + + 
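+    # Assumption: LBModelManager rotates through the configs above, and the
+    # '__inherit__' entry falls back to the managed provider credentials
+    # passed in below as managed_credentials.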
lb_model_manager = LBModelManager( + tenant_id='tenant_id', + provider='openai', + model_type=ModelType.LLM, + model='gpt-4', + load_balancing_configs=load_balancing_configs, + managed_credentials={"openai_api_key": "fake_key"} + ) + + lb_model_manager.cooldown = MagicMock(return_value=None) + + def is_cooldown(config: ModelLoadBalancingConfiguration): + if config.id == 'id1': + return True + + return False + + lb_model_manager.in_cooldown = MagicMock(side_effect=is_cooldown) + + return lb_model_manager + + +def test_lb_model_manager_fetch_next(mocker, lb_model_manager): + assert len(lb_model_manager._load_balancing_configs) == 3 + + config1 = lb_model_manager._load_balancing_configs[0] + config2 = lb_model_manager._load_balancing_configs[1] + config3 = lb_model_manager._load_balancing_configs[2] + + assert lb_model_manager.in_cooldown(config1) is True + assert lb_model_manager.in_cooldown(config2) is False + assert lb_model_manager.in_cooldown(config3) is False + + start_index = 0 + def incr(key): + nonlocal start_index + start_index += 1 + return start_index + + mocker.patch('redis.Redis.incr', side_effect=incr) + mocker.patch('redis.Redis.set', return_value=None) + mocker.patch('redis.Redis.expire', return_value=None) + + config = lb_model_manager.fetch_next() + assert config == config2 + + config = lb_model_manager.fetch_next() + assert config == config3 diff --git a/api/tests/unit_tests/core/test_provider_manager.py b/api/tests/unit_tests/core/test_provider_manager.py new file mode 100644 index 0000000000..072b6f100f --- /dev/null +++ b/api/tests/unit_tests/core/test_provider_manager.py @@ -0,0 +1,183 @@ +from core.entities.provider_entities import ModelSettings +from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.model_providers import model_provider_factory +from core.provider_manager import ProviderManager +from models.provider import LoadBalancingModelConfig, ProviderModelSetting + + +def test__to_model_settings(mocker): + # Get all provider entities + provider_entities = model_provider_factory.get_providers() + + provider_entity = None + for provider in provider_entities: + if provider.provider == 'openai': + provider_entity = provider + + # Mocking the inputs + provider_model_settings = [ProviderModelSetting( + id='id', + tenant_id='tenant_id', + provider_name='openai', + model_name='gpt-4', + model_type='text-generation', + enabled=True, + load_balancing_enabled=True + )] + load_balancing_model_configs = [ + LoadBalancingModelConfig( + id='id1', + tenant_id='tenant_id', + provider_name='openai', + model_name='gpt-4', + model_type='text-generation', + name='__inherit__', + encrypted_config=None, + enabled=True + ), + LoadBalancingModelConfig( + id='id2', + tenant_id='tenant_id', + provider_name='openai', + model_name='gpt-4', + model_type='text-generation', + name='first', + encrypted_config='{"openai_api_key": "fake_key"}', + enabled=True + ) + ] + + mocker.patch('core.helper.model_provider_cache.ProviderCredentialsCache.get', return_value={"openai_api_key": "fake_key"}) + + provider_manager = ProviderManager() + + # Running the method + result = provider_manager._to_model_settings( + provider_entity, + provider_model_settings, + load_balancing_model_configs + ) + + # Asserting that the result is as expected + assert len(result) == 1 + assert isinstance(result[0], ModelSettings) + assert result[0].model == 'gpt-4' + assert result[0].model_type == ModelType.LLM + assert result[0].enabled is True + assert len(result[0].load_balancing_configs) 
== 2 + assert result[0].load_balancing_configs[0].name == '__inherit__' + assert result[0].load_balancing_configs[1].name == 'first' + + +def test__to_model_settings_only_one_lb(mocker): + # Get all provider entities + provider_entities = model_provider_factory.get_providers() + + provider_entity = None + for provider in provider_entities: + if provider.provider == 'openai': + provider_entity = provider + + # Mocking the inputs + provider_model_settings = [ProviderModelSetting( + id='id', + tenant_id='tenant_id', + provider_name='openai', + model_name='gpt-4', + model_type='text-generation', + enabled=True, + load_balancing_enabled=True + )] + load_balancing_model_configs = [ + LoadBalancingModelConfig( + id='id1', + tenant_id='tenant_id', + provider_name='openai', + model_name='gpt-4', + model_type='text-generation', + name='__inherit__', + encrypted_config=None, + enabled=True + ) + ] + + mocker.patch('core.helper.model_provider_cache.ProviderCredentialsCache.get', return_value={"openai_api_key": "fake_key"}) + + provider_manager = ProviderManager() + + # Running the method + result = provider_manager._to_model_settings( + provider_entity, + provider_model_settings, + load_balancing_model_configs + ) + + # Asserting that the result is as expected + assert len(result) == 1 + assert isinstance(result[0], ModelSettings) + assert result[0].model == 'gpt-4' + assert result[0].model_type == ModelType.LLM + assert result[0].enabled is True + assert len(result[0].load_balancing_configs) == 0 + + +def test__to_model_settings_lb_disabled(mocker): + # Get all provider entities + provider_entities = model_provider_factory.get_providers() + + provider_entity = None + for provider in provider_entities: + if provider.provider == 'openai': + provider_entity = provider + + # Mocking the inputs + provider_model_settings = [ProviderModelSetting( + id='id', + tenant_id='tenant_id', + provider_name='openai', + model_name='gpt-4', + model_type='text-generation', + enabled=True, + load_balancing_enabled=False + )] + load_balancing_model_configs = [ + LoadBalancingModelConfig( + id='id1', + tenant_id='tenant_id', + provider_name='openai', + model_name='gpt-4', + model_type='text-generation', + name='__inherit__', + encrypted_config=None, + enabled=True + ), + LoadBalancingModelConfig( + id='id2', + tenant_id='tenant_id', + provider_name='openai', + model_name='gpt-4', + model_type='text-generation', + name='first', + encrypted_config='{"openai_api_key": "fake_key"}', + enabled=True + ) + ] + + mocker.patch('core.helper.model_provider_cache.ProviderCredentialsCache.get', return_value={"openai_api_key": "fake_key"}) + + provider_manager = ProviderManager() + + # Running the method + result = provider_manager._to_model_settings( + provider_entity, + provider_model_settings, + load_balancing_model_configs + ) + + # Asserting that the result is as expected + assert len(result) == 1 + assert isinstance(result[0], ModelSettings) + assert result[0].model == 'gpt-4' + assert result[0].model_type == ModelType.LLM + assert result[0].enabled is True + assert len(result[0].load_balancing_configs) == 0 diff --git a/api/tests/unit_tests/core/tools/test_tool_parameter_converter.py b/api/tests/unit_tests/core/tools/test_tool_parameter_converter.py new file mode 100644 index 0000000000..9addeeadca --- /dev/null +++ b/api/tests/unit_tests/core/tools/test_tool_parameter_converter.py @@ -0,0 +1,56 @@ +import pytest + +from core.tools.entities.tool_entities import ToolParameter +from core.tools.utils.tool_parameter_converter import 
ToolParameterConverter + + +def test_get_parameter_type(): + assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.STRING) == 'string' + assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.SELECT) == 'string' + assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.BOOLEAN) == 'boolean' + assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.NUMBER) == 'number' + with pytest.raises(ValueError): + ToolParameterConverter.get_parameter_type('unsupported_type') + + +def test_cast_parameter_by_type(): + # string + assert ToolParameterConverter.cast_parameter_by_type('test', ToolParameter.ToolParameterType.STRING) == 'test' + assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.STRING) == '1' + assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.STRING) == '1.0' + assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.STRING) == '' + + # secret input + assert ToolParameterConverter.cast_parameter_by_type('test', ToolParameter.ToolParameterType.SECRET_INPUT) == 'test' + assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.SECRET_INPUT) == '1' + assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.SECRET_INPUT) == '1.0' + assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.SECRET_INPUT) == '' + + # select + assert ToolParameterConverter.cast_parameter_by_type('test', ToolParameter.ToolParameterType.SELECT) == 'test' + assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.SELECT) == '1' + assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.SELECT) == '1.0' + assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.SELECT) == '' + + # boolean + true_values = [True, 'True', 'true', '1', 'YES', 'Yes', 'yes', 'y', 'something'] + for value in true_values: + assert ToolParameterConverter.cast_parameter_by_type(value, ToolParameter.ToolParameterType.BOOLEAN) is True + + false_values = [False, 'False', 'false', '0', 'NO', 'No', 'no', 'n', None, ''] + for value in false_values: + assert ToolParameterConverter.cast_parameter_by_type(value, ToolParameter.ToolParameterType.BOOLEAN) is False + + # number + assert ToolParameterConverter.cast_parameter_by_type('1', ToolParameter.ToolParameterType.NUMBER) == 1 + assert ToolParameterConverter.cast_parameter_by_type('1.0', ToolParameter.ToolParameterType.NUMBER) == 1.0 + assert ToolParameterConverter.cast_parameter_by_type('-1.0', ToolParameter.ToolParameterType.NUMBER) == -1.0 + assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.NUMBER) == 1 + assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.NUMBER) == 1.0 + assert ToolParameterConverter.cast_parameter_by_type(-1.0, ToolParameter.ToolParameterType.NUMBER) == -1.0 + assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.NUMBER) is None + + # unknown + assert ToolParameterConverter.cast_parameter_by_type('1', 'unknown_type') == '1' + assert ToolParameterConverter.cast_parameter_by_type(1, 'unknown_type') == '1' + assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.NUMBER) is None diff --git 
a/api/tests/unit_tests/core/workflow/nodes/test_answer.py b/api/tests/unit_tests/core/workflow/nodes/test_answer.py index cf21401eb2..102711b4b6 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_answer.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_answer.py @@ -1,5 +1,6 @@ from unittest.mock import MagicMock +from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.entities.node_entities import SystemVariable from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.answer.answer_node import AnswerNode @@ -15,6 +16,7 @@ def test_execute_answer(): workflow_id='1', user_id='1', user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, config={ 'id': 'answer', 'data': { diff --git a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py index 99413540c5..6860b2fd97 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py @@ -1,5 +1,6 @@ from unittest.mock import MagicMock +from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.entities.node_entities import SystemVariable from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import UserFrom @@ -15,6 +16,7 @@ def test_execute_if_else_result_true(): workflow_id='1', user_id='1', user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, config={ 'id': 'if-else', 'data': { @@ -155,6 +157,7 @@ def test_execute_if_else_result_false(): workflow_id='1', user_id='1', user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, config={ 'id': 'if-else', 'data': { diff --git a/api/tests/unit_tests/libs/test_pandas.py b/api/tests/unit_tests/libs/test_pandas.py new file mode 100644 index 0000000000..bbc372ed61 --- /dev/null +++ b/api/tests/unit_tests/libs/test_pandas.py @@ -0,0 +1,62 @@ +import pandas as pd + + +def test_pandas_csv(tmp_path, monkeypatch): + monkeypatch.chdir(tmp_path) + data = {'col1': [1, 2.2, -3.3, 4.0, 5], + 'col2': ['A', 'B', 'C', 'D', 'E']} + df1 = pd.DataFrame(data) + + # write to csv file + csv_file_path = tmp_path.joinpath('example.csv') + df1.to_csv(csv_file_path, index=False) + + # read from csv file + df2 = pd.read_csv(csv_file_path, on_bad_lines='skip') + assert df2[df2.columns[0]].to_list() == data['col1'] + assert df2[df2.columns[1]].to_list() == data['col2'] + + +def test_pandas_xlsx(tmp_path, monkeypatch): + monkeypatch.chdir(tmp_path) + data = {'col1': [1, 2.2, -3.3, 4.0, 5], + 'col2': ['A', 'B', 'C', 'D', 'E']} + df1 = pd.DataFrame(data) + + # write to xlsx file + xlsx_file_path = tmp_path.joinpath('example.xlsx') + df1.to_excel(xlsx_file_path, index=False) + + # read from xlsx file + df2 = pd.read_excel(xlsx_file_path) + assert df2[df2.columns[0]].to_list() == data['col1'] + assert df2[df2.columns[1]].to_list() == data['col2'] + + +def test_pandas_xlsx_with_sheets(tmp_path, monkeypatch): + monkeypatch.chdir(tmp_path) + data1 = {'col1': [1, 2, 3, 4, 5], + 'col2': ['A', 'B', 'C', 'D', 'E']} + df1 = pd.DataFrame(data1) + + data2 = {'col1': [6, 7, 8, 9, 10], + 'col2': ['F', 'G', 'H', 'I', 'J']} + df2 = pd.DataFrame(data2) + + # write to xlsx file with sheets + xlsx_file_path = tmp_path.joinpath('example_with_sheets.xlsx') + sheet1 = 'Sheet1' + sheet2 = 'Sheet2' + with pd.ExcelWriter(xlsx_file_path) as excel_writer: + df1.to_excel(excel_writer, sheet_name=sheet1, index=False) + df2.to_excel(excel_writer, sheet_name=sheet2, index=False) + + # 
read from xlsx file with sheets + with pd.ExcelFile(xlsx_file_path) as excel_file: + df1 = pd.read_excel(excel_file, sheet_name=sheet1) + assert df1[df1.columns[0]].to_list() == data1['col1'] + assert df1[df1.columns[1]].to_list() == data1['col2'] + + df2 = pd.read_excel(excel_file, sheet_name=sheet2) + assert df2[df2.columns[0]].to_list() == data2['col1'] + assert df2[df2.columns[1]].to_list() == data2['col2'] diff --git a/api/tests/unit_tests/libs/test_yarl.py b/api/tests/unit_tests/libs/test_yarl.py new file mode 100644 index 0000000000..75a5344126 --- /dev/null +++ b/api/tests/unit_tests/libs/test_yarl.py @@ -0,0 +1,23 @@ +import pytest +from yarl import URL + + +def test_yarl_urls(): + expected_1 = 'https://dify.ai/api' + assert str(URL('https://dify.ai') / 'api') == expected_1 + assert str(URL('https://dify.ai/') / 'api') == expected_1 + + expected_2 = 'http://dify.ai:12345/api' + assert str(URL('http://dify.ai:12345') / 'api') == expected_2 + assert str(URL('http://dify.ai:12345/') / 'api') == expected_2 + + expected_3 = 'https://dify.ai/api/v1' + assert str(URL('https://dify.ai') / 'api' / 'v1') == expected_3 + assert str(URL('https://dify.ai') / 'api/v1') == expected_3 + assert str(URL('https://dify.ai/') / 'api/v1') == expected_3 + assert str(URL('https://dify.ai/api') / 'v1') == expected_3 + assert str(URL('https://dify.ai/api/') / 'v1') == expected_3 + + with pytest.raises(ValueError) as e1: + str(URL('https://dify.ai') / '/api') + assert str(e1.value) == "Appending path '/api' starting from slash is forbidden" diff --git a/api/tests/unit_tests/utils/__init__.py b/api/tests/unit_tests/utils/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/utils/position_helper/test_position_helper.py b/api/tests/unit_tests/utils/position_helper/test_position_helper.py new file mode 100644 index 0000000000..c389461454 --- /dev/null +++ b/api/tests/unit_tests/utils/position_helper/test_position_helper.py @@ -0,0 +1,34 @@ +from textwrap import dedent + +import pytest + +from core.helper.position_helper import get_position_map + + +@pytest.fixture +def prepare_example_positions_yaml(tmp_path, monkeypatch) -> str: + monkeypatch.chdir(tmp_path) + tmp_path.joinpath("example_positions.yaml").write_text(dedent( + """\ + - first + - second + # - commented + - third + + - 9999999999999 + - forth + """)) + return str(tmp_path) + + +def test_position_helper(prepare_example_positions_yaml): + position_map = get_position_map( + folder_path=prepare_example_positions_yaml, + file_name='example_positions.yaml') + assert len(position_map) == 4 + assert position_map == { + 'first': 0, + 'second': 1, + 'third': 2, + 'forth': 3, + } diff --git a/api/tests/unit_tests/utils/yaml/__init__.py b/api/tests/unit_tests/utils/yaml/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/utils/yaml/test_yaml_utils.py b/api/tests/unit_tests/utils/yaml/test_yaml_utils.py new file mode 100644 index 0000000000..446588cde1 --- /dev/null +++ b/api/tests/unit_tests/utils/yaml/test_yaml_utils.py @@ -0,0 +1,74 @@ +from textwrap import dedent + +import pytest +from yaml import YAMLError + +from core.tools.utils.yaml_utils import load_yaml_file + +EXAMPLE_YAML_FILE = 'example_yaml.yaml' +INVALID_YAML_FILE = 'invalid_yaml.yaml' +NON_EXISTING_YAML_FILE = 'non_existing_file.yaml' + + +@pytest.fixture +def prepare_example_yaml_file(tmp_path, monkeypatch) -> str: + monkeypatch.chdir(tmp_path) + file_path = tmp_path.joinpath(EXAMPLE_YAML_FILE) + 
file_path.write_text(dedent(
+        """\
+        address:
+          city: Example City
+          country: Example Country
+        age: 30
+        gender: male
+        languages:
+          - Python
+          - Java
+          - C++
+        empty_key:
+        """))
+    return str(file_path)
+
+
+@pytest.fixture
+def prepare_invalid_yaml_file(tmp_path, monkeypatch) -> str:
+    monkeypatch.chdir(tmp_path)
+    file_path = tmp_path.joinpath(INVALID_YAML_FILE)
+    file_path.write_text(dedent(
+        """\
+        address:
+          city: Example City
+            country: Example Country
+        age: 30
+        gender: male
+        languages:
+          - Python
+          - Java
+          - C++
+        """))
+    return str(file_path)
+
+
+def test_load_yaml_non_existing_file():
+    assert load_yaml_file(file_path=NON_EXISTING_YAML_FILE) == {}
+    assert load_yaml_file(file_path='') == {}
+
+
+def test_load_valid_yaml_file(prepare_example_yaml_file):
+    yaml_data = load_yaml_file(file_path=prepare_example_yaml_file)
+    assert len(yaml_data) > 0
+    assert yaml_data['age'] == 30
+    assert yaml_data['gender'] == 'male'
+    assert yaml_data['address']['city'] == 'Example City'
+    assert set(yaml_data['languages']) == {'Python', 'Java', 'C++'}
+    assert yaml_data.get('empty_key') is None
+    assert yaml_data.get('non_existed_key') is None
+
+
+def test_load_invalid_yaml_file(prepare_invalid_yaml_file):
+    # yaml syntax error
+    with pytest.raises(YAMLError):
+        load_yaml_file(file_path=prepare_invalid_yaml_file)
+
+    # ignore error
+    assert load_yaml_file(file_path=prepare_invalid_yaml_file, ignore_error=True) == {}
diff --git a/dev/reformat b/dev/reformat
index ebee1efb40..844bfbef58 100755
--- a/dev/reformat
+++ b/dev/reformat
@@ -9,7 +9,7 @@ if ! command -v ruff &> /dev/null; then
 fi

 # run ruff linter
-ruff check --fix ./api
+ruff check --fix --preview ./api

 # env files linting relies on `dotenv-linter` in path
 if ! command -v dotenv-linter &> /dev/null; then
diff --git a/docker/docker-compose.middleware.yaml b/docker/docker-compose.middleware.yaml
index 6bf45da9e0..38760901b1 100644
--- a/docker/docker-compose.middleware.yaml
+++ b/docker/docker-compose.middleware.yaml
@@ -53,20 +53,39 @@ services:

   # The DifySandbox
   sandbox:
-    image: langgenius/dify-sandbox:0.1.0
+    image: langgenius/dify-sandbox:0.2.1
     restart: always
-    cap_add:
-      # Why is sys_admin permission needed?
-      # https://docs.dify.ai/getting-started/install-self-hosted/install-faq#id-16.-why-is-sys_admin-permission-needed
-      - SYS_ADMIN
     environment:
       # The DifySandbox configurations
+      # Be sure to change this key to a strong one for your deployment.
+      # You can generate a strong key using `openssl rand -base64 42`.
       API_KEY: dify-sandbox
       GIN_MODE: 'release'
       WORKER_TIMEOUT: 15
-    ports:
-      - "8194:8194"
+      ENABLE_NETWORK: 'true'
+      HTTP_PROXY: 'http://ssrf_proxy:3128'
+      HTTPS_PROXY: 'http://ssrf_proxy:3128'
+      SANDBOX_PORT: 8194
+    volumes:
+      - ./volumes/sandbox/dependencies:/dependencies
+    networks:
+      - ssrf_proxy_network
+  # ssrf_proxy server
+  # for more information, please refer to
+  # https://docs.dify.ai/getting-started/install-self-hosted/install-faq#id-16.-why-is-ssrf_proxy-needed
+  ssrf_proxy:
+    image: ubuntu/squid:latest
+    restart: always
+    ports:
+      - "3128:3128"
+      - "8194:8194"
+    volumes:
+      # please modify the squid.conf file to fit your network environment.
+      - ./volumes/ssrf_proxy/squid.conf:/etc/squid/squid.conf
+    networks:
+      - ssrf_proxy_network
+      - default
   # Qdrant vector store.
   # uncomment to use qdrant as vector store.
  # (if uncommented, you need to comment out the weaviate service above,
@@ -81,3 +100,10 @@ services:
 #    ports:
 #      - "6333:6333"
 #      - "6334:6334"
+
+
+networks:
+  # create a network between sandbox, api and ssrf_proxy, with no access to the outside world.
+  ssrf_proxy_network:
+    driver: bridge
+    internal: true
diff --git a/docker/docker-compose.pgvector.yaml b/docker/docker-compose.pgvector.yaml
new file mode 100644
index 0000000000..b584880abf
--- /dev/null
+++ b/docker/docker-compose.pgvector.yaml
@@ -0,0 +1,24 @@
+version: '3'
+services:
+  # The pgvector vector store.
+  pgvector:
+    image: pgvector/pgvector:pg16
+    restart: always
+    environment:
+      PGUSER: postgres
+      # The password for the default postgres user.
+      POSTGRES_PASSWORD: difyai123456
+      # The name of the default postgres database.
+      POSTGRES_DB: dify
+      # postgres data directory
+      PGDATA: /var/lib/postgresql/data/pgdata
+    volumes:
+      - ./volumes/pgvector/data:/var/lib/postgresql/data
+    # expose db(postgresql) port to host
+    ports:
+      - "5433:5432"
+    healthcheck:
+      test: [ "CMD", "pg_isready" ]
+      interval: 1s
+      timeout: 3s
+      retries: 30
diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml
index 706f7d70e1..c3b5430514 100644
--- a/docker/docker-compose.yaml
+++ b/docker/docker-compose.yaml
@@ -2,13 +2,15 @@ version: '3'
 services:
   # API service
   api:
-    image: langgenius/dify-api:0.6.6
+    image: langgenius/dify-api:0.6.10
     restart: always
     environment:
       # Startup mode, 'api' starts the API server.
       MODE: api
       # The log level for the application. Supported values are `DEBUG`, `INFO`, `WARNING`, `ERROR`, `CRITICAL`
       LOG_LEVEL: INFO
+      # enable DEBUG mode to output more logs
+      # DEBUG : true
       # A secret key that is used for securely signing the session cookie and encrypting sensitive information on the database. You can generate a strong key using `openssl rand -base64 42`.
       SECRET_KEY: sk-9f73s3ljTXVcMT3Blb3ljTqtsKiGHXVcMT3BlbkFJLK7U
       # The base URL of console application web frontend, refers to the Console base URL of WEB service if console domain is
@@ -34,6 +36,9 @@ services:
       # used to display File preview or download Url to the front-end or as Multi-model inputs;
       # Url is signed and has expiration time.
       FILES_URL: ''
+      # File Access Time specifies a time interval in seconds for the file to be accessed.
+      # The default value is 300 seconds.
+      FILES_ACCESS_TIMEOUT: 300
       # When enabled, migrations will be executed prior to application startup and the application will start after the migrations have completed.
       MIGRATION_ENABLED: 'true'
       # The configurations of postgres database connection.
@@ -88,6 +93,7 @@ services:
       AZURE_BLOB_ACCOUNT_URL: 'https://<your_account_name>.blob.core.windows.net'
       # The Google storage configurations, only available when STORAGE_TYPE is `google-storage`.
       GOOGLE_STORAGE_BUCKET_NAME: 'yout-bucket-name'
+      # if you want to use Application Default Credentials, you can leave GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64 empty.
       GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64: 'your-google-service-account-json-base64-string'

       # The type of vector store to use. Supported values are `weaviate`, `qdrant`, `milvus`, `relyt`.
       VECTOR_STORE: weaviate
@@ -122,15 +128,28 @@ services:
       RELYT_USER: postgres
       RELYT_PASSWORD: difyai123456
       RELYT_DATABASE: postgres
+      # pgvector configurations
+      PGVECTOR_HOST: pgvector
+      PGVECTOR_PORT: 5432
+      PGVECTOR_USER: postgres
+      PGVECTOR_PASSWORD: difyai123456
+      PGVECTOR_DATABASE: dify
+      # tidb vector configurations
+      TIDB_VECTOR_HOST: tidb
+      TIDB_VECTOR_PORT: 4000
+      TIDB_VECTOR_USER: xxx.root
+      TIDB_VECTOR_PASSWORD: xxxxxx
+      TIDB_VECTOR_DATABASE: dify
       # Mail configuration, support: resend, smtp
       MAIL_TYPE: ''
       # default send from email address, if not specified
       MAIL_DEFAULT_SEND_FROM: 'YOUR EMAIL FROM (eg: no-reply <no-reply@dify.ai>)'
       SMTP_SERVER: ''
-      SMTP_PORT: 587
+      SMTP_PORT: 465
       SMTP_USERNAME: ''
       SMTP_PASSWORD: ''
       SMTP_USE_TLS: 'true'
+      SMTP_OPPORTUNISTIC_TLS: 'false'
       # the api-key for resend (https://resend.com)
       RESEND_API_KEY: ''
       RESEND_API_URL: https://api.resend.com
@@ -155,6 +174,11 @@ services:
       CODE_MAX_STRING_ARRAY_LENGTH: 30
       CODE_MAX_OBJECT_ARRAY_LENGTH: 30
       CODE_MAX_NUMBER_ARRAY_LENGTH: 1000
+      # SSRF Proxy server
+      SSRF_PROXY_HTTP_URL: 'http://ssrf_proxy:3128'
+      SSRF_PROXY_HTTPS_URL: 'http://ssrf_proxy:3128'
+      # Indexing configuration
+      INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH: 1000
     depends_on:
       - db
       - redis
@@ -164,13 +188,17 @@ services:
     # uncomment to expose dify-api port to host
     # ports:
     #   - "5001:5001"
+    networks:
+      - ssrf_proxy_network
+      - default

   # worker service
   # The Celery worker for processing the queue.
   worker:
-    image: langgenius/dify-api:0.6.6
+    image: langgenius/dify-api:0.6.10
     restart: always
     environment:
+      CONSOLE_WEB_URL: ''
       # Startup mode, 'worker' starts the Celery worker for processing the queue.
       MODE: worker

@@ -197,7 +225,7 @@ services:
       REDIS_USE_SSL: 'false'
       # The configurations of celery broker.
       CELERY_BROKER_URL: redis://:difyai123456@redis:6379/1
-      # The type of storage to use for storing user files. Supported values are `local` and `s3` and `azure-blob`, Default: `local`
+      # The type of storage to use for storing user files. Supported values are `local`, `s3`, `azure-blob` and `google-storage`. Default: `local`
       STORAGE_TYPE: local
       STORAGE_LOCAL_PATH: storage
       # The S3 storage configurations, only available when STORAGE_TYPE is `s3`.
@@ -211,7 +239,11 @@ services:
       AZURE_BLOB_ACCOUNT_KEY: 'difyai'
       AZURE_BLOB_CONTAINER_NAME: 'difyai-container'
       AZURE_BLOB_ACCOUNT_URL: 'https://<your_account_name>.blob.core.windows.net'
-      # The type of vector store to use. Supported values are `weaviate`, `qdrant`, `milvus`, `relyt`.
+      # The Google storage configurations, only available when STORAGE_TYPE is `google-storage`.
+      GOOGLE_STORAGE_BUCKET_NAME: 'your-bucket-name'
+      # if you want to use Application Default Credentials, you can leave GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64 empty.
+      GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64: 'your-google-service-account-json-base64-string'
+      # The type of vector store to use. Supported values are `weaviate`, `qdrant`, `milvus`, `relyt`, `pgvector`.
       VECTOR_STORE: weaviate
       # The Weaviate endpoint URL. Only available when VECTOR_STORE is `weaviate`.
       WEAVIATE_ENDPOINT: http://weaviate:8080
@@ -221,7 +253,7 @@ services:
       QDRANT_URL: http://qdrant:6333
       # The Qdrant API key.
       QDRANT_API_KEY: difyai123456
-      # The Qdrant clinet timeout setting.
+      # The Qdrant client timeout setting.
       QDRANT_CLIENT_TIMEOUT: 20
       # The Qdrant client enable gRPC mode.
       QDRANT_GRPC_ENABLED: 'false'
@@ -242,6 +274,12 @@ services:
       MAIL_TYPE: ''
       # default send from email address, if not specified
       MAIL_DEFAULT_SEND_FROM: 'YOUR EMAIL FROM (eg: no-reply <no-reply@dify.ai>)'
+      SMTP_SERVER: ''
+      SMTP_PORT: 465
+      SMTP_USERNAME: ''
+      SMTP_PASSWORD: ''
+      SMTP_USE_TLS: 'true'
+      SMTP_OPPORTUNISTIC_TLS: 'false'
       # the api-key for resend (https://resend.com)
       RESEND_API_KEY: ''
       RESEND_API_URL: https://api.resend.com
@@ -251,21 +289,38 @@ services:
       RELYT_USER: postgres
       RELYT_PASSWORD: difyai123456
       RELYT_DATABASE: postgres
+      # pgvector configurations
+      PGVECTOR_HOST: pgvector
+      PGVECTOR_PORT: 5432
+      PGVECTOR_USER: postgres
+      PGVECTOR_PASSWORD: difyai123456
+      PGVECTOR_DATABASE: dify
+      # tidb vector configurations
+      TIDB_VECTOR_HOST: tidb
+      TIDB_VECTOR_PORT: 4000
+      TIDB_VECTOR_USER: xxx.root
+      TIDB_VECTOR_PASSWORD: xxxxxx
+      TIDB_VECTOR_DATABASE: dify
       # Notion import configuration, support public and internal
       NOTION_INTEGRATION_TYPE: public
       NOTION_CLIENT_SECRET: you-client-secret
       NOTION_CLIENT_ID: you-client-id
       NOTION_INTERNAL_SECRET: you-internal-secret
+      # Indexing configuration
+      INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH: 1000
     depends_on:
       - db
       - redis
     volumes:
       # Mount the storage directory to the container, for storing user files.
       - ./volumes/app/storage:/app/api/storage
+    networks:
+      - ssrf_proxy_network
+      - default

   # Frontend web application.
   web:
-    image: langgenius/dify-web:0.6.6
+    image: langgenius/dify-web:0.6.10
     restart: always
     environment:
       # The base URL of console application api server, refers to the Console base URL of WEB service if console domain is
@@ -346,18 +401,36 @@ services:

   # The DifySandbox
   sandbox:
-    image: langgenius/dify-sandbox:0.1.0
+    image: langgenius/dify-sandbox:0.2.1
     restart: always
-    cap_add:
-      # Why is sys_admin permission needed?
-      # https://docs.dify.ai/getting-started/install-self-hosted/install-faq#id-16.-why-is-sys_admin-permission-needed
-      - SYS_ADMIN
     environment:
       # The DifySandbox configurations
+      # Be sure to change this key to a strong one for your deployment.
+      # You can generate a strong key using `openssl rand -base64 42`.
       API_KEY: dify-sandbox
-      GIN_MODE: release
+      GIN_MODE: 'release'
       WORKER_TIMEOUT: 15
+      ENABLE_NETWORK: 'true'
+      HTTP_PROXY: 'http://ssrf_proxy:3128'
+      HTTPS_PROXY: 'http://ssrf_proxy:3128'
+      SANDBOX_PORT: 8194
+    volumes:
+      - ./volumes/sandbox/dependencies:/dependencies
+    networks:
+      - ssrf_proxy_network
+  # ssrf_proxy server
+  # for more information, please refer to
+  # https://docs.dify.ai/getting-started/install-self-hosted/install-faq#id-16.-why-is-ssrf_proxy-needed
+  ssrf_proxy:
+    image: ubuntu/squid:latest
+    restart: always
+    volumes:
+      # please modify the squid.conf file to fit your network environment.
+      - ./volumes/ssrf_proxy/squid.conf:/etc/squid/squid.conf
+    networks:
+      - ssrf_proxy_network
+      - default
   # Qdrant vector store.
   # uncomment to use qdrant as vector store.
   # (if uncommented, you need to comment out the weaviate service above,
@@ -374,6 +447,31 @@ services:
 #  #    - "6333:6333"
 #  #    - "6334:6334"

+  # The pgvector vector database.
+  # Uncomment to use pgvector as the vector store.
+  # pgvector:
+  #   image: pgvector/pgvector:pg16
+  #   restart: always
+  #   environment:
+  #     PGUSER: postgres
+  #     # The password for the default postgres user.
+  #     POSTGRES_PASSWORD: difyai123456
+  #     # The name of the default postgres database.
+  #     POSTGRES_DB: dify
+  #     # postgres data directory
+  #     PGDATA: /var/lib/postgresql/data/pgdata
+  #   volumes:
+  #     - ./volumes/pgvector/data:/var/lib/postgresql/data
+  #   # uncomment to expose db(postgresql) port to host
+  #   # ports:
+  #   #   - "5433:5432"
+  #   healthcheck:
+  #     test: [ "CMD", "pg_isready" ]
+  #     interval: 1s
+  #     timeout: 3s
+  #     retries: 30
+
+
   # The nginx reverse proxy.
   # used for reverse proxying the API service and Web service.
   nginx:
@@ -390,3 +488,8 @@ services:
     ports:
       - "80:80"
       #- "443:443"
+networks:
+  # create a network between sandbox, api and ssrf_proxy, with no access to the outside world.
+  ssrf_proxy_network:
+    driver: bridge
+    internal: true
diff --git a/docker/volumes/sandbox/dependencies/python-requirements.txt b/docker/volumes/sandbox/dependencies/python-requirements.txt
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/docker/volumes/ssrf_proxy/squid.conf b/docker/volumes/ssrf_proxy/squid.conf
new file mode 100644
index 0000000000..3028bf35c6
--- /dev/null
+++ b/docker/volumes/ssrf_proxy/squid.conf
@@ -0,0 +1,50 @@
+acl localnet src 0.0.0.1-0.255.255.255  # RFC 1122 "this" network (LAN)
+acl localnet src 10.0.0.0/8             # RFC 1918 local private network (LAN)
+acl localnet src 100.64.0.0/10          # RFC 6598 shared address space (CGN)
+acl localnet src 169.254.0.0/16         # RFC 3927 link-local (directly plugged) machines
+acl localnet src 172.16.0.0/12          # RFC 1918 local private network (LAN)
+acl localnet src 192.168.0.0/16         # RFC 1918 local private network (LAN)
+acl localnet src fc00::/7               # RFC 4193 local private network range
+acl localnet src fe80::/10              # RFC 4291 link-local (directly plugged) machines
+acl SSL_ports port 443
+acl Safe_ports port 80          # http
+acl Safe_ports port 21          # ftp
+acl Safe_ports port 443         # https
+acl Safe_ports port 70          # gopher
+acl Safe_ports port 210         # wais
+acl Safe_ports port 1025-65535  # unregistered ports
+acl Safe_ports port 280         # http-mgmt
+acl Safe_ports port 488         # gss-http
+acl Safe_ports port 591         # filemaker
+acl Safe_ports port 777         # multiling http
+acl CONNECT method CONNECT
+http_access deny !Safe_ports
+http_access deny CONNECT !SSL_ports
+http_access allow localhost manager
+http_access deny manager
+http_access allow localhost
+http_access allow localnet
+http_access deny all
+
+################################## Proxy Server ################################
+http_port 3128
+coredump_dir /var/spool/squid
+refresh_pattern ^ftp:           1440    20%     10080
+refresh_pattern ^gopher:        1440    0%      1440
+refresh_pattern -i (/cgi-bin/|\?) 0     0%      0
+refresh_pattern \/(Packages|Sources)(|\.bz2|\.gz|\.xz)$ 0 0% 0 refresh-ims
+refresh_pattern \/Release(|\.gpg)$ 0 0% 0 refresh-ims
+refresh_pattern \/InRelease$ 0 0% 0 refresh-ims
+refresh_pattern \/(Translation-.*)(|\.bz2|\.gz|\.xz)$ 0 0% 0 refresh-ims
+refresh_pattern .
0 20% 4320 +logfile_rotate 0 + +# upstream proxy, set to your own upstream proxy IP to avoid SSRF attacks +# cache_peer 172.1.1.1 parent 3128 0 no-query no-digest no-netdb-exchange default + + +################################## Reverse Proxy To Sandbox ################################ +http_port 8194 accel vhost +cache_peer sandbox parent 8194 0 no-query originserver +acl all src all +http_access allow all \ No newline at end of file diff --git a/sdks/nodejs-client/index.d.ts b/sdks/nodejs-client/index.d.ts index cf1d825221..7fdd943f63 100644 --- a/sdks/nodejs-client/index.d.ts +++ b/sdks/nodejs-client/index.d.ts @@ -14,15 +14,6 @@ interface HeaderParams { interface User { } -interface ChatMessageConfig { - inputs: any; - query: string; - user: User; - stream?: boolean; - conversation_id?: string | null; - files?: File[] | null; -} - export declare class DifyClient { constructor(apiKey: string, baseUrl?: string); @@ -54,7 +45,14 @@ export declare class CompletionClient extends DifyClient { } export declare class ChatClient extends DifyClient { - createChatMessage(config: ChatMessageConfig): Promise; + createChatMessage( + inputs: any, + query: string, + user: User, + stream?: boolean, + conversation_id?: string | null, + files?: File[] | null + ): Promise; getConversationMessages( user: User, diff --git a/sdks/nodejs-client/index.js b/sdks/nodejs-client/index.js index 127d62cf87..93491fa16b 100644 --- a/sdks/nodejs-client/index.js +++ b/sdks/nodejs-client/index.js @@ -85,7 +85,7 @@ export class DifyClient { response = await axios({ method, url, - data, + ...(method !== "GET" && { data }), params, headers, responseType: "json", diff --git a/sdks/nodejs-client/index.test.js b/sdks/nodejs-client/index.test.js index e08b8e82af..f300b16fc9 100644 --- a/sdks/nodejs-client/index.test.js +++ b/sdks/nodejs-client/index.test.js @@ -42,7 +42,6 @@ describe('Send Requests', () => { expect(axios).toHaveBeenCalledWith({ method, url: `${BASE_URL}${endpoint}`, - data: null, params: null, headers: { Authorization: `Bearer ${difyClient.apiKey}`, diff --git a/sdks/nodejs-client/package.json b/sdks/nodejs-client/package.json index 83b2f8a4c0..cc27c5e0c0 100644 --- a/sdks/nodejs-client/package.json +++ b/sdks/nodejs-client/package.json @@ -1,6 +1,6 @@ { "name": "dify-client", - "version": "2.3.1", + "version": "2.3.2", "description": "This is the Node.js SDK for the Dify.AI API, which allows you to easily integrate Dify.AI into your Node.js applications.", "main": "index.js", "type": "module", diff --git a/web/.husky/pre-commit b/web/.husky/pre-commit index 4bc7fb77ab..8d1ad1d09f 100755 --- a/web/.husky/pre-commit +++ b/web/.husky/pre-commit @@ -31,7 +31,7 @@ if $api_modified; then pip install ruff fi - ruff check ./api || status=$? + ruff check --preview ./api || status=$? 
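+  # --preview enables ruff's preview rules, matching the flag added to dev/reformat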
status=${status:-0} diff --git a/web/.vscode/settings.example.json b/web/.vscode/settings.example.json index 34b49e2708..6162d021d0 100644 --- a/web/.vscode/settings.example.json +++ b/web/.vscode/settings.example.json @@ -2,7 +2,7 @@ "prettier.enable": false, "editor.formatOnSave": true, "editor.codeActionsOnSave": { - "source.fixAll.eslint": true + "source.fixAll.eslint": "explicit" }, "eslint.format.enable": true, "[python]": { diff --git a/web/app/(commonLayout)/datasets/NewDatasetCard.tsx b/web/app/(commonLayout)/datasets/NewDatasetCard.tsx index 5ce79adadc..f3e34ff7e2 100644 --- a/web/app/(commonLayout)/datasets/NewDatasetCard.tsx +++ b/web/app/(commonLayout)/datasets/NewDatasetCard.tsx @@ -9,7 +9,7 @@ const CreateAppCard = forwardRef((_, ref) => { return ( -
    +
    diff --git a/web/app/(commonLayout)/tools/custom/page.tsx b/web/app/(commonLayout)/tools/custom/page.tsx deleted file mode 100644 index c666bf347c..0000000000 --- a/web/app/(commonLayout)/tools/custom/page.tsx +++ /dev/null @@ -1,10 +0,0 @@ -import React from 'react' - -const Custom = () => { - return ( -
    - Custom -
    - ) -} -export default Custom diff --git a/web/app/(commonLayout)/tools/page.tsx b/web/app/(commonLayout)/tools/page.tsx index 36a76d60d0..066550b3a2 100644 --- a/web/app/(commonLayout)/tools/page.tsx +++ b/web/app/(commonLayout)/tools/page.tsx @@ -2,8 +2,7 @@ import type { FC } from 'react' import { useTranslation } from 'react-i18next' import React, { useEffect } from 'react' -import Tools from '@/app/components/tools' -import { LOC } from '@/app/components/tools/types' +import ToolProviderList from '@/app/components/tools/provider-list' const Layout: FC = () => { const { t } = useTranslation() @@ -12,12 +11,6 @@ const Layout: FC = () => { document.title = `${t('tools.title')} - Dify` }, []) - return ( -
    - -
    - ) + return } export default React.memo(Layout) diff --git a/web/app/(commonLayout)/tools/third-part/page.tsx b/web/app/(commonLayout)/tools/third-part/page.tsx deleted file mode 100644 index d2ae810609..0000000000 --- a/web/app/(commonLayout)/tools/third-part/page.tsx +++ /dev/null @@ -1,10 +0,0 @@ -import React from 'react' - -const ThirdPart = () => { - return ( -
    - Third part -
    - ) -} -export default ThirdPart diff --git a/web/app/(shareLayout)/chat/[token]/page.tsx b/web/app/(shareLayout)/chat/[token]/page.tsx index 6c3fe2b4a4..56b2e0da7d 100644 --- a/web/app/(shareLayout)/chat/[token]/page.tsx +++ b/web/app/(shareLayout)/chat/[token]/page.tsx @@ -1,5 +1,4 @@ 'use client' - import type { FC } from 'react' import React from 'react' diff --git a/web/app/(shareLayout)/chatbot/[token]/page.tsx b/web/app/(shareLayout)/chatbot/[token]/page.tsx index 8aa182893a..0dc7b07169 100644 --- a/web/app/(shareLayout)/chatbot/[token]/page.tsx +++ b/web/app/(shareLayout)/chatbot/[token]/page.tsx @@ -1,12 +1,87 @@ +'use client' import type { FC } from 'react' -import React from 'react' - +import React, { useEffect } from 'react' +import cn from 'classnames' import type { IMainProps } from '@/app/components/share/chat' import Main from '@/app/components/share/chatbot' +import Loading from '@/app/components/base/loading' +import { fetchSystemFeatures } from '@/service/share' +import LogoSite from '@/app/components/base/logo/logo-site' const Chatbot: FC = () => { + const [isSSOEnforced, setIsSSOEnforced] = React.useState(true) + const [loading, setLoading] = React.useState(true) + + useEffect(() => { + fetchSystemFeatures().then((res) => { + setIsSSOEnforced(res.sso_enforced_for_web) + setLoading(false) + }) + }, []) + return ( -
    + <> + { + loading + ? ( +
    +
    + +
    +
    + ) + : ( + <> + {isSSOEnforced + ? ( +
    +
    +
    + +
    + +
    +
    +
    +

    + Warning: Chatbot is not available +

    +

+ Because SSO is enforced, please contact your administrator.

    +
    +
    +
    +
    +
    + ) + :
+          }
+
+        )}
+      )
}
diff --git a/web/app/(shareLayout)/webapp-signin/page.tsx b/web/app/(shareLayout)/webapp-signin/page.tsx
new file mode 100644
index 0000000000..5306854638
--- /dev/null
+++ b/web/app/(shareLayout)/webapp-signin/page.tsx
@@ -0,0 +1,154 @@
+'use client'
+import cn from 'classnames'
+import { useRouter, useSearchParams } from 'next/navigation'
+import type { FC } from 'react'
+import React, { useEffect, useState } from 'react'
+import { useTranslation } from 'react-i18next'
+import Toast from '@/app/components/base/toast'
+import Button from '@/app/components/base/button'
+import { fetchSystemFeatures, fetchWebOAuth2SSOUrl, fetchWebOIDCSSOUrl, fetchWebSAMLSSOUrl } from '@/service/share'
+import LogoSite from '@/app/components/base/logo/logo-site'
+import { setAccessToken } from '@/app/components/share/utils'
+
+const WebSSOForm: FC = () => {
+  const searchParams = useSearchParams()
+
+  const redirectUrl = searchParams.get('redirect_url')
+  const tokenFromUrl = searchParams.get('web_sso_token')
+  const message = searchParams.get('message')
+
+  const router = useRouter()
+  const { t } = useTranslation()
+
+  const [isLoading, setIsLoading] = useState(false)
+  const [protocol, setProtocol] = useState('')
+
+  useEffect(() => {
+    const fetchFeaturesAndSetToken = async () => {
+      await fetchSystemFeatures().then((res) => {
+        setProtocol(res.sso_enforced_for_web_protocol)
+      })
+
+      // Callback from SSO, process token and redirect
+      if (tokenFromUrl && redirectUrl) {
+        const appCode = redirectUrl.split('/').pop()
+        if (!appCode) {
+          Toast.notify({
+            type: 'error',
+            message: 'redirect url is invalid. App code is not found.',
+          })
+          return
+        }
+
+        await setAccessToken(appCode, tokenFromUrl)
+        router.push(redirectUrl)
+      }
+    }
+
+    fetchFeaturesAndSetToken()
+
+    if (message) {
+      Toast.notify({
+        type: 'error',
+        message,
+      })
+    }
+  }, [])
+
+  const handleSSOLogin = () => {
+    setIsLoading(true)
+
+    if (!redirectUrl) {
+      Toast.notify({
+        type: 'error',
+        message: 'redirect url is not found.',
+      })
+      setIsLoading(false)
+      return
+    }
+
+    const appCode = redirectUrl.split('/').pop()
+    if (!appCode) {
+      Toast.notify({
+        type: 'error',
+        message: 'redirect url is invalid. App code is not found.',
+      })
+      setIsLoading(false)
+      return
+    }
+
+    if (protocol === 'saml') {
+      fetchWebSAMLSSOUrl(appCode, redirectUrl).then((res) => {
+        router.push(res.url)
+      }).finally(() => {
+        setIsLoading(false)
+      })
+    }
+    else if (protocol === 'oidc') {
+      fetchWebOIDCSSOUrl(appCode, redirectUrl).then((res) => {
+        router.push(res.url)
+      }).finally(() => {
+        setIsLoading(false)
+      })
+    }
+    else if (protocol === 'oauth2') {
+      fetchWebOAuth2SSOUrl(appCode, redirectUrl).then((res) => {
+        router.push(res.url)
+      }).finally(() => {
+        setIsLoading(false)
+      })
+    }
+    else {
+      Toast.notify({
+        type: 'error',
+        message: 'sso protocol is not supported.',
+      })
+      setIsLoading(false)
+    }
+  }
+
+  return (
+
    +
    +
    + +
    + +
    +
    +
    +

    {t('login.pageTitle')}

    +
    +
    + +
    +
    +
    +
    +
    + ) +} + +export default React.memo(WebSSOForm) diff --git a/web/app/components/app/annotation/header-opts/index.tsx b/web/app/components/app/annotation/header-opts/index.tsx index aba3b6324c..18cc0a0f9c 100644 --- a/web/app/components/app/annotation/header-opts/index.tsx +++ b/web/app/components/app/annotation/header-opts/index.tsx @@ -150,7 +150,7 @@ const HeaderOptions: FC = ({ s.actionIconWrapper, ) } - className={'!w-[154px] h-fit !z-20'} + className={'!w-[155px] h-fit !z-20'} popupClassName='!w-full !overflow-visible' manualClose /> diff --git a/web/app/components/app/app-publisher/index.tsx b/web/app/components/app/app-publisher/index.tsx index fa7b9ed344..7e4c584257 100644 --- a/web/app/components/app/app-publisher/index.tsx +++ b/web/app/components/app/app-publisher/index.tsx @@ -23,6 +23,8 @@ import { PlayCircle } from '@/app/components/base/icons/src/vender/line/mediaAnd import { CodeBrowser } from '@/app/components/base/icons/src/vender/line/development' import { LeftIndent02 } from '@/app/components/base/icons/src/vender/line/editor' import { FileText } from '@/app/components/base/icons/src/vender/line/files' +import WorkflowToolConfigureButton from '@/app/components/tools/workflow-tool/configure-button' +import type { InputVar } from '@/app/components/workflow/types' export type AppPublisherProps = { disabled?: boolean @@ -37,6 +39,9 @@ export type AppPublisherProps = { onRestore?: () => Promise | any onToggle?: (state: boolean) => void crossAxisOffset?: number + toolPublished?: boolean + inputs?: InputVar[] + onRefreshData?: () => void } const AppPublisher = ({ @@ -50,6 +55,9 @@ const AppPublisher = ({ onRestore, onToggle, crossAxisOffset = 0, + toolPublished, + inputs, + onRefreshData, }: AppPublisherProps) => { const { t } = useTranslation() const [published, setPublished] = useState(false) @@ -122,7 +130,7 @@ const AppPublisher = ({ -
    +
    {publishedAt ? t('workflow.common.latestPublished') : t('workflow.common.currentDraftUnpublished')} @@ -202,6 +210,23 @@ const AppPublisher = ({ )} }>{t('workflow.common.accessAPIReference')} + {appDetail?.mode === 'workflow' && ( + + )}
    diff --git a/web/app/components/app/chat/answer/index.tsx b/web/app/components/app/chat/answer/index.tsx index 49dd817255..1ba033911a 100644 --- a/web/app/components/app/chat/answer/index.tsx +++ b/web/app/components/app/chat/answer/index.tsx @@ -362,7 +362,7 @@ const Answer: FC = ({ {!item.isOpeningStatement && ( )} {((isShowPromptLog && !isResponding) || (!item.isOpeningStatement && isShowTextToSpeech)) && ( @@ -375,6 +375,7 @@ const Answer: FC = ({
    diff --git a/web/app/components/app/chat/index.tsx b/web/app/components/app/chat/index.tsx index e1fe5589ed..d861ddb2de 100644 --- a/web/app/components/app/chat/index.tsx +++ b/web/app/components/app/chat/index.tsx @@ -67,6 +67,7 @@ export type IChatProps = { visionConfig?: VisionSettings supportAnnotation?: boolean allToolIcons?: Record + customDisclaimer?: string } const Chat: FC = ({ @@ -102,6 +103,7 @@ const Chat: FC = ({ supportAnnotation, onChatListChange, allToolIcons, + customDisclaimer, }) => { const { t } = useTranslation() const { notify } = useContext(ToastContext) @@ -260,11 +262,7 @@ const Chat: FC = ({ return { ...item, content: item.content, - annotation: { - ...(item.annotation || {}), - id: '', - logAnnotation: undefined, // remove log - } as Annotation, + annotation: undefined, } } return item @@ -362,44 +360,46 @@ const Chat: FC = ({
    )} -
    - {visionConfig?.enabled && ( - <> -
    - = visionConfig.number_limits} - /> -
    -
    -
    - -
    - - )} -