diff --git a/.github/actions/setup-uv/action.yml b/.github/actions/setup-uv/action.yml deleted file mode 100644 index 6990f6becf..0000000000 --- a/.github/actions/setup-uv/action.yml +++ /dev/null @@ -1,34 +0,0 @@ -name: Setup UV and Python - -inputs: - python-version: - description: Python version to use and the UV installed with - required: true - default: '3.12' - uv-version: - description: UV version to set up - required: true - default: '0.8.9' - uv-lockfile: - description: Path to the UV lockfile to restore cache from - required: true - default: '' - enable-cache: - required: true - default: true - -runs: - using: composite - steps: - - name: Set up Python ${{ inputs.python-version }} - uses: actions/setup-python@v5 - with: - python-version: ${{ inputs.python-version }} - - - name: Install uv - uses: astral-sh/setup-uv@v6 - with: - version: ${{ inputs.uv-version }} - python-version: ${{ inputs.python-version }} - enable-cache: ${{ inputs.enable-cache }} - cache-dependency-glob: ${{ inputs.uv-lockfile }} diff --git a/.github/workflows/api-tests.yml b/.github/workflows/api-tests.yml index 63d681e7ed..28ef67a133 100644 --- a/.github/workflows/api-tests.yml +++ b/.github/workflows/api-tests.yml @@ -33,10 +33,11 @@ jobs: persist-credentials: false - name: Setup UV and Python - uses: ./.github/actions/setup-uv + uses: astral-sh/setup-uv@v6 with: + enable-cache: true python-version: ${{ matrix.python-version }} - uv-lockfile: api/uv.lock + cache-dependency-glob: api/uv.lock - name: Check UV lockfile run: uv lock --project api --check diff --git a/.github/workflows/autofix.yml b/.github/workflows/autofix.yml index dada6229db..716c29957c 100644 --- a/.github/workflows/autofix.yml +++ b/.github/workflows/autofix.yml @@ -15,7 +15,9 @@ jobs: - uses: actions/checkout@v4 # Use uv to ensure we have the same ruff version in CI and locally. - - uses: astral-sh/setup-uv@7edac99f961f18b581bbd960d59d049f04c0002f + - uses: astral-sh/setup-uv@v6 + with: + python-version: "3.12" - run: | cd api uv sync --dev diff --git a/.github/workflows/db-migration-test.yml b/.github/workflows/db-migration-test.yml index 5181546b4a..e8ff85e95c 100644 --- a/.github/workflows/db-migration-test.yml +++ b/.github/workflows/db-migration-test.yml @@ -25,9 +25,11 @@ jobs: persist-credentials: false - name: Setup UV and Python - uses: ./.github/actions/setup-uv + uses: astral-sh/setup-uv@v6 with: - uv-lockfile: api/uv.lock + enable-cache: true + python-version: "3.12" + cache-dependency-glob: api/uv.lock - name: Install dependencies run: uv sync --project api diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml index 9aad9558b0..8d0ec35ca1 100644 --- a/.github/workflows/style.yml +++ b/.github/workflows/style.yml @@ -36,10 +36,11 @@ jobs: - name: Setup UV and Python if: steps.changed-files.outputs.any_changed == 'true' - uses: ./.github/actions/setup-uv + uses: astral-sh/setup-uv@v6 with: - uv-lockfile: api/uv.lock enable-cache: false + python-version: "3.12" + cache-dependency-glob: api/uv.lock - name: Install dependencies if: steps.changed-files.outputs.any_changed == 'true' diff --git a/.github/workflows/vdb-tests.yml b/.github/workflows/vdb-tests.yml index 912267094b..f2ca09fba2 100644 --- a/.github/workflows/vdb-tests.yml +++ b/.github/workflows/vdb-tests.yml @@ -39,10 +39,11 @@ jobs: remove_tool_cache: true - name: Setup UV and Python - uses: ./.github/actions/setup-uv + uses: astral-sh/setup-uv@v6 with: + enable-cache: true python-version: ${{ matrix.python-version }} - uv-lockfile: api/uv.lock + cache-dependency-glob: api/uv.lock - name: Check UV lockfile run: uv lock --project api --check diff --git a/README_CN.md b/README_CN.md index 2949b38867..9aaebf4037 100644 --- a/README_CN.md +++ b/README_CN.md @@ -180,7 +180,7 @@ docker compose up -d ## Contributing -对于那些想要贡献代码的人,请参阅我们的[贡献指南](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)。 +对于那些想要贡献代码的人,请参阅我们的[贡献指南](https://github.com/langgenius/dify/blob/main/CONTRIBUTING_CN.md)。 同时,请考虑通过社交媒体、活动和会议来支持 Dify 的分享。 > 我们正在寻找贡献者来帮助将 Dify 翻译成除了中文和英文之外的其他语言。如果您有兴趣帮助,请参阅我们的[i18n README](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md)获取更多信息,并在我们的[Discord 社区服务器](https://discord.gg/8Tpq4AcN9c)的`global-users`频道中留言。 diff --git a/README_DE.md b/README_DE.md index a593a12abf..a08fe63d4f 100644 --- a/README_DE.md +++ b/README_DE.md @@ -173,7 +173,7 @@ Stellen Sie Dify mit einem Klick in AKS bereit, indem Sie [Azure Devops Pipeline ## Contributing -Falls Sie Code beitragen möchten, lesen Sie bitte unseren [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). Gleichzeitig bitten wir Sie, Dify zu unterstützen, indem Sie es in den sozialen Medien teilen und auf Veranstaltungen und Konferenzen präsentieren. +Falls Sie Code beitragen möchten, lesen Sie bitte unseren [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING_DE.md). Gleichzeitig bitten wir Sie, Dify zu unterstützen, indem Sie es in den sozialen Medien teilen und auf Veranstaltungen und Konferenzen präsentieren. > Wir suchen Mitwirkende, die dabei helfen, Dify in weitere Sprachen zu übersetzen – außer Mandarin oder Englisch. Wenn Sie Interesse an einer Mitarbeit haben, lesen Sie bitte die [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md) für weitere Informationen und hinterlassen Sie einen Kommentar im `global-users`-Kanal unseres [Discord Community Servers](https://discord.gg/8Tpq4AcN9c). diff --git a/README_ES.md b/README_ES.md index c7a18dc675..d8fdbf54e6 100644 --- a/README_ES.md +++ b/README_ES.md @@ -170,7 +170,7 @@ Implementa Dify en AKS con un clic usando [Azure Devops Pipeline Helm Chart by @ ## Contribuir -Para aquellos que deseen contribuir con código, consulten nuestra [Guía de contribución](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). +Para aquellos que deseen contribuir con código, consulten nuestra [Guía de contribución](https://github.com/langgenius/dify/blob/main/CONTRIBUTING_ES.md). Al mismo tiempo, considera apoyar a Dify compartiéndolo en redes sociales y en eventos y conferencias. > Estamos buscando colaboradores para ayudar con la traducción de Dify a idiomas que no sean el mandarín o el inglés. Si estás interesado en ayudar, consulta el [README de i18n](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md) para obtener más información y déjanos un comentario en el canal `global-users` de nuestro [Servidor de Comunidad en Discord](https://discord.gg/8Tpq4AcN9c). diff --git a/README_FR.md b/README_FR.md index 316d50c929..7474ea50c2 100644 --- a/README_FR.md +++ b/README_FR.md @@ -168,7 +168,7 @@ Déployez Dify sur AKS en un clic en utilisant [Azure Devops Pipeline Helm Chart ## Contribuer -Pour ceux qui souhaitent contribuer du code, consultez notre [Guide de contribution](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). +Pour ceux qui souhaitent contribuer du code, consultez notre [Guide de contribution](https://github.com/langgenius/dify/blob/main/CONTRIBUTING_FR.md). Dans le même temps, veuillez envisager de soutenir Dify en le partageant sur les réseaux sociaux et lors d'événements et de conférences. > Nous recherchons des contributeurs pour aider à traduire Dify dans des langues autres que le mandarin ou l'anglais. Si vous êtes intéressé à aider, veuillez consulter le [README i18n](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md) pour plus d'informations, et laissez-nous un commentaire dans le canal `global-users` de notre [Serveur communautaire Discord](https://discord.gg/8Tpq4AcN9c). diff --git a/README_JA.md b/README_JA.md index 785706a88a..a782849f6e 100644 --- a/README_JA.md +++ b/README_JA.md @@ -169,7 +169,7 @@ docker compose up -d ## 貢献 -コードに貢献したい方は、[Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)を参照してください。 +コードに貢献したい方は、[Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING_JA.md)を参照してください。 同時に、DifyをSNSやイベント、カンファレンスで共有してサポートしていただけると幸いです。 > Difyを英語または中国語以外の言語に翻訳してくれる貢献者を募集しています。興味がある場合は、詳細については[i18n README](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md)を参照してください。また、[Discordコミュニティサーバー](https://discord.gg/8Tpq4AcN9c)の`global-users`チャンネルにコメントを残してください。 diff --git a/README_KR.md b/README_KR.md index 3b58339e12..ec28cc0f61 100644 --- a/README_KR.md +++ b/README_KR.md @@ -162,7 +162,7 @@ Dify를 Kubernetes에 배포하고 프리미엄 스케일링 설정을 구성했 ## 기여 -코드에 기여하고 싶은 분들은 [기여 가이드](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)를 참조하세요. +코드에 기여하고 싶은 분들은 [기여 가이드](https://github.com/langgenius/dify/blob/main/CONTRIBUTING_KR.md)를 참조하세요. 동시에 Dify를 소셜 미디어와 행사 및 컨퍼런스에 공유하여 지원하는 것을 고려해 주시기 바랍니다. > 우리는 Dify를 중국어나 영어 이외의 언어로 번역하는 데 도움을 줄 수 있는 기여자를 찾고 있습니다. 도움을 주고 싶으시다면 [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md)에서 더 많은 정보를 확인하시고 [Discord 커뮤니티 서버](https://discord.gg/8Tpq4AcN9c)의 `global-users` 채널에 댓글을 남겨주세요. diff --git a/README_PT.md b/README_PT.md index ec2e4245f6..da8f354a49 100644 --- a/README_PT.md +++ b/README_PT.md @@ -168,7 +168,7 @@ Implante o Dify no AKS com um clique usando [Azure Devops Pipeline Helm Chart by ## Contribuindo -Para aqueles que desejam contribuir com código, veja nosso [Guia de Contribuição](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). +Para aqueles que desejam contribuir com código, veja nosso [Guia de Contribuição](https://github.com/langgenius/dify/blob/main/CONTRIBUTING_PT.md). Ao mesmo tempo, considere apoiar o Dify compartilhando-o nas redes sociais e em eventos e conferências. > Estamos buscando contribuidores para ajudar na tradução do Dify para idiomas além de Mandarim e Inglês. Se você tiver interesse em ajudar, consulte o [README i18n](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md) para mais informações e deixe-nos um comentário no canal `global-users` em nosso [Servidor da Comunidade no Discord](https://discord.gg/8Tpq4AcN9c). diff --git a/README_TR.md b/README_TR.md index 510b112e68..21df0d1605 100644 --- a/README_TR.md +++ b/README_TR.md @@ -161,7 +161,7 @@ Dify'ı bulut platformuna tek tıklamayla dağıtın [terraform](https://www.ter ## Katkıda Bulunma -Kod katkısında bulunmak isteyenler için [Katkı Kılavuzumuza](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) bakabilirsiniz. +Kod katkısında bulunmak isteyenler için [Katkı Kılavuzumuza](https://github.com/langgenius/dify/blob/main/CONTRIBUTING_TR.md) bakabilirsiniz. Aynı zamanda, lütfen Dify'ı sosyal medyada, etkinliklerde ve konferanslarda paylaşarak desteklemeyi düşünün. > Dify'ı Mandarin veya İngilizce dışındaki dillere çevirmemize yardımcı olacak katkıda bulunanlara ihtiyacımız var. Yardımcı olmakla ilgileniyorsanız, lütfen daha fazla bilgi için [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md) dosyasına bakın ve [Discord Topluluk Sunucumuzdaki](https://discord.gg/8Tpq4AcN9c) `global-users` kanalında bize bir yorum bırakın. diff --git a/README_TW.md b/README_TW.md index 35a01fa16a..18d0724784 100644 --- a/README_TW.md +++ b/README_TW.md @@ -173,7 +173,7 @@ Dify 的所有功能都提供相應的 API,因此您可以輕鬆地將 Dify ## 貢獻 -對於想要貢獻程式碼的開發者,請參閱我們的[貢獻指南](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)。 +對於想要貢獻程式碼的開發者,請參閱我們的[貢獻指南](https://github.com/langgenius/dify/blob/main/CONTRIBUTING_TW.md)。 同時,也請考慮透過在社群媒體和各種活動與會議上分享 Dify 來支持我們。 > 我們正在尋找貢獻者協助將 Dify 翻譯成中文和英文以外的語言。如果您有興趣幫忙,請查看 [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md) 獲取更多資訊,並在我們的 [Discord 社群伺服器](https://discord.gg/8Tpq4AcN9c) 的 `global-users` 頻道留言給我們。 diff --git a/README_VI.md b/README_VI.md index f161b20f9d..6d5305fb75 100644 --- a/README_VI.md +++ b/README_VI.md @@ -162,7 +162,7 @@ Triển khai Dify lên AKS chỉ với một cú nhấp chuột bằng [Azure De ## Đóng góp -Đối với những người muốn đóng góp mã, xem [Hướng dẫn Đóng góp](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) của chúng tôi. +Đối với những người muốn đóng góp mã, xem [Hướng dẫn Đóng góp](https://github.com/langgenius/dify/blob/main/CONTRIBUTING_VI.md) của chúng tôi. Đồng thời, vui lòng xem xét hỗ trợ Dify bằng cách chia sẻ nó trên mạng xã hội và tại các sự kiện và hội nghị. > Chúng tôi đang tìm kiếm người đóng góp để giúp dịch Dify sang các ngôn ngữ khác ngoài tiếng Trung hoặc tiếng Anh. Nếu bạn quan tâm đến việc giúp đỡ, vui lòng xem [README i18n](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md) để biết thêm thông tin và để lại bình luận cho chúng tôi trong kênh `global-users` của [Máy chủ Cộng đồng Discord](https://discord.gg/8Tpq4AcN9c) của chúng tôi. diff --git a/api/.env.example b/api/.env.example index 3052dbfe2b..e947c5584b 100644 --- a/api/.env.example +++ b/api/.env.example @@ -564,3 +564,7 @@ QUEUE_MONITOR_THRESHOLD=200 QUEUE_MONITOR_ALERT_EMAILS= # Monitor interval in minutes, default is 30 minutes QUEUE_MONITOR_INTERVAL=30 + +# Swagger UI configuration +SWAGGER_UI_ENABLED=true +SWAGGER_UI_PATH=/swagger-ui.html diff --git a/api/README.md b/api/README.md index 8309a0e69b..d322963ffc 100644 --- a/api/README.md +++ b/api/README.md @@ -99,14 +99,14 @@ uv run celery -A app.celery beat 1. Run the tests locally with mocked system environment variables in `tool.pytest_env` section in `pyproject.toml`, more can check [Claude.md](../CLAUDE.md) - ```cli - uv run --project api pytest # Run all tests - uv run --project api pytest tests/unit_tests/ # Unit tests only - uv run --project api pytest tests/integration_tests/ # Integration tests + ```bash + uv run pytest # Run all tests + uv run pytest tests/unit_tests/ # Unit tests only + uv run pytest tests/integration_tests/ # Integration tests # Code quality - ./dev/reformat # Run all formatters and linters - uv run --project api ruff check --fix ./ # Fix linting issues - uv run --project api ruff format ./ # Format code - uv run --project api mypy . # Type checking + ../dev/reformat # Run all formatters and linters + uv run ruff check --fix ./ # Fix linting issues + uv run ruff format ./ # Format code + uv run mypy . # Type checking ``` diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index 2bccc4b7a0..7638cd1899 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -1,4 +1,4 @@ -from typing import Annotated, Literal, Optional +from typing import Literal, Optional from pydantic import ( AliasChoices, @@ -976,6 +976,18 @@ class WorkflowLogConfig(BaseSettings): ) +class SwaggerUIConfig(BaseSettings): + SWAGGER_UI_ENABLED: bool = Field( + description="Whether to enable Swagger UI in api module", + default=True, + ) + + SWAGGER_UI_PATH: str = Field( + description="Swagger UI page path in api module", + default="/swagger-ui.html", + ) + + class FeatureConfig( # place the configs in alphabet order AppExecutionConfig, @@ -1007,6 +1019,7 @@ class FeatureConfig( WorkspaceConfig, LoginConfig, AccountConfig, + SwaggerUIConfig, # hosted services config HostedServiceConfig, CeleryBeatConfig, diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index 57dbc8da64..e25f92399c 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -84,7 +84,6 @@ from .datasets import ( external, hit_testing, metadata, - upload_file, website, ) diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index 8dcffb1666..e840c00283 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -72,6 +72,7 @@ class DraftWorkflowApi(Resource): Get draft workflow """ # The role of the current user in the ta table must be admin, owner, or editor + assert isinstance(current_user, Account) if not current_user.is_editor: raise Forbidden() @@ -94,6 +95,7 @@ class DraftWorkflowApi(Resource): Sync draft workflow """ # The role of the current user in the ta table must be admin, owner, or editor + assert isinstance(current_user, Account) if not current_user.is_editor: raise Forbidden() @@ -171,6 +173,7 @@ class AdvancedChatDraftWorkflowRunApi(Resource): Run draft workflow """ # The role of the current user in the ta table must be admin, owner, or editor + assert isinstance(current_user, Account) if not current_user.is_editor: raise Forbidden() @@ -218,13 +221,12 @@ class AdvancedChatDraftRunIterationNodeApi(Resource): """ Run draft workflow iteration node """ + if not isinstance(current_user, Account): + raise Forbidden() # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - if not isinstance(current_user, Account): - raise Forbidden() - parser = reqparse.RequestParser() parser.add_argument("inputs", type=dict, location="json") args = parser.parse_args() @@ -256,11 +258,10 @@ class WorkflowDraftRunIterationNodeApi(Resource): Run draft workflow iteration node """ # The role of the current user in the ta table must be admin, owner, or editor - if not current_user.is_editor: - raise Forbidden() - if not isinstance(current_user, Account): raise Forbidden() + if not current_user.is_editor: + raise Forbidden() parser = reqparse.RequestParser() parser.add_argument("inputs", type=dict, location="json") @@ -292,12 +293,12 @@ class AdvancedChatDraftRunLoopNodeApi(Resource): """ Run draft workflow loop node """ - # The role of the current user in the ta table must be admin, owner, or editor - if not current_user.is_editor: - raise Forbidden() if not isinstance(current_user, Account): raise Forbidden() + # The role of the current user in the ta table must be admin, owner, or editor + if not current_user.is_editor: + raise Forbidden() parser = reqparse.RequestParser() parser.add_argument("inputs", type=dict, location="json") @@ -329,12 +330,12 @@ class WorkflowDraftRunLoopNodeApi(Resource): """ Run draft workflow loop node """ - # The role of the current user in the ta table must be admin, owner, or editor - if not current_user.is_editor: - raise Forbidden() if not isinstance(current_user, Account): raise Forbidden() + # The role of the current user in the ta table must be admin, owner, or editor + if not current_user.is_editor: + raise Forbidden() parser = reqparse.RequestParser() parser.add_argument("inputs", type=dict, location="json") @@ -366,12 +367,12 @@ class DraftWorkflowRunApi(Resource): """ Run draft workflow """ - # The role of the current user in the ta table must be admin, owner, or editor - if not current_user.is_editor: - raise Forbidden() if not isinstance(current_user, Account): raise Forbidden() + # The role of the current user in the ta table must be admin, owner, or editor + if not current_user.is_editor: + raise Forbidden() parser = reqparse.RequestParser() parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") @@ -405,6 +406,9 @@ class WorkflowTaskStopApi(Resource): """ Stop workflow task """ + + if not isinstance(current_user, Account): + raise Forbidden() # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() @@ -424,12 +428,12 @@ class DraftWorkflowNodeRunApi(Resource): """ Run draft workflow node """ - # The role of the current user in the ta table must be admin, owner, or editor - if not current_user.is_editor: - raise Forbidden() if not isinstance(current_user, Account): raise Forbidden() + # The role of the current user in the ta table must be admin, owner, or editor + if not current_user.is_editor: + raise Forbidden() parser = reqparse.RequestParser() parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") @@ -472,6 +476,9 @@ class PublishedWorkflowApi(Resource): """ Get published workflow """ + + if not isinstance(current_user, Account): + raise Forbidden() # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() @@ -491,13 +498,12 @@ class PublishedWorkflowApi(Resource): """ Publish workflow """ + if not isinstance(current_user, Account): + raise Forbidden() # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - if not isinstance(current_user, Account): - raise Forbidden() - parser = reqparse.RequestParser() parser.add_argument("marked_name", type=str, required=False, default="", location="json") parser.add_argument("marked_comment", type=str, required=False, default="", location="json") @@ -541,6 +547,9 @@ class DefaultBlockConfigsApi(Resource): """ Get default block config """ + + if not isinstance(current_user, Account): + raise Forbidden() # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() @@ -559,13 +568,12 @@ class DefaultBlockConfigApi(Resource): """ Get default block config """ + if not isinstance(current_user, Account): + raise Forbidden() # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - if not isinstance(current_user, Account): - raise Forbidden() - parser = reqparse.RequestParser() parser.add_argument("q", type=str, location="args") args = parser.parse_args() @@ -595,13 +603,12 @@ class ConvertToWorkflowApi(Resource): Convert expert mode of chatbot app to workflow mode Convert Completion App to Workflow App """ + if not isinstance(current_user, Account): + raise Forbidden() # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - if not isinstance(current_user, Account): - raise Forbidden() - if request.data: parser = reqparse.RequestParser() parser.add_argument("name", type=str, required=False, nullable=True, location="json") @@ -645,6 +652,9 @@ class PublishedAllWorkflowApi(Resource): """ Get published workflows """ + + if not isinstance(current_user, Account): + raise Forbidden() if not current_user.is_editor: raise Forbidden() @@ -693,13 +703,12 @@ class WorkflowByIdApi(Resource): """ Update workflow attributes """ + if not isinstance(current_user, Account): + raise Forbidden() # Check permission if not current_user.is_editor: raise Forbidden() - if not isinstance(current_user, Account): - raise Forbidden() - parser = reqparse.RequestParser() parser.add_argument("marked_name", type=str, required=False, location="json") parser.add_argument("marked_comment", type=str, required=False, location="json") @@ -750,13 +759,12 @@ class WorkflowByIdApi(Resource): """ Delete workflow """ + if not isinstance(current_user, Account): + raise Forbidden() # Check permission if not current_user.is_editor: raise Forbidden() - if not isinstance(current_user, Account): - raise Forbidden() - workflow_service = WorkflowService() # Create a session and manage the transaction diff --git a/api/controllers/console/app/workflow_draft_variable.py b/api/controllers/console/app/workflow_draft_variable.py index 4e625db24d..a0b73f7e07 100644 --- a/api/controllers/console/app/workflow_draft_variable.py +++ b/api/controllers/console/app/workflow_draft_variable.py @@ -21,6 +21,7 @@ from factories.file_factory import build_from_mapping, build_from_mappings from factories.variable_factory import build_segment_with_type from libs.login import current_user, login_required from models import App, AppMode, db +from models.account import Account from models.workflow import WorkflowDraftVariable from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService from services.workflow_service import WorkflowService @@ -135,6 +136,7 @@ def _api_prerequisite(f): @account_initialization_required @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) def wrapper(*args, **kwargs): + assert isinstance(current_user, Account) if not current_user.is_editor: raise Forbidden() return f(*args, **kwargs) diff --git a/api/controllers/console/app/wraps.py b/api/controllers/console/app/wraps.py index 132dc1f96b..c7e300279a 100644 --- a/api/controllers/console/app/wraps.py +++ b/api/controllers/console/app/wraps.py @@ -6,9 +6,11 @@ from controllers.console.app.error import AppNotFoundError from extensions.ext_database import db from libs.login import current_user from models import App, AppMode +from models.account import Account def _load_app_model(app_id: str) -> Optional[App]: + assert isinstance(current_user, Account) app_model = ( db.session.query(App) .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") diff --git a/api/controllers/console/auth/error.py b/api/controllers/console/auth/error.py index 8c5e23de58..7853bef917 100644 --- a/api/controllers/console/auth/error.py +++ b/api/controllers/console/auth/error.py @@ -55,6 +55,12 @@ class EmailOrPasswordMismatchError(BaseHTTPException): code = 400 +class AuthenticationFailedError(BaseHTTPException): + error_code = "authentication_failed" + description = "Invalid email or password." + code = 401 + + class EmailPasswordLoginLimitError(BaseHTTPException): error_code = "email_code_login_limit" description = "Too many incorrect password attempts. Please try again later." diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py index a5ad6a1cd7..6ed49f48ff 100644 --- a/api/controllers/console/auth/login.py +++ b/api/controllers/console/auth/login.py @@ -9,8 +9,8 @@ from configs import dify_config from constants.languages import languages from controllers.console import api from controllers.console.auth.error import ( + AuthenticationFailedError, EmailCodeError, - EmailOrPasswordMismatchError, EmailPasswordLoginLimitError, InvalidEmailError, InvalidTokenError, @@ -79,7 +79,7 @@ class LoginApi(Resource): raise AccountBannedError() except services.errors.account.AccountPasswordError: AccountService.add_login_error_rate_limit(args["email"]) - raise EmailOrPasswordMismatchError() + raise AuthenticationFailedError() except services.errors.account.AccountNotFoundError: if FeatureService.get_system_features().is_allow_register: token = AccountService.send_reset_password_email(email=args["email"], language=language) @@ -132,6 +132,7 @@ class ResetPasswordSendEmailApi(Resource): account = AccountService.get_user_through_email(args["email"]) except AccountRegisterError as are: raise AccountInFreezeError() + if account is None: if FeatureService.get_system_features().is_allow_register: token = AccountService.send_reset_password_email(email=args["email"], language=language) diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index a23536f82e..a5a18e7f33 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -553,7 +553,7 @@ class DatasetIndexingStatusApi(Resource): } documents_status.append(marshal(document_dict, document_status_fields)) data = {"data": documents_status} - return data + return data, 200 class DatasetApiKeyApi(Resource): diff --git a/api/controllers/console/datasets/upload_file.py b/api/controllers/console/datasets/upload_file.py deleted file mode 100644 index 617dbcaff2..0000000000 --- a/api/controllers/console/datasets/upload_file.py +++ /dev/null @@ -1,62 +0,0 @@ -from flask_login import current_user -from flask_restx import Resource -from werkzeug.exceptions import NotFound - -from controllers.console import api -from controllers.console.wraps import ( - account_initialization_required, - setup_required, -) -from core.file import helpers as file_helpers -from extensions.ext_database import db -from models.dataset import Dataset -from models.model import UploadFile -from services.dataset_service import DocumentService - - -class UploadFileApi(Resource): - @setup_required - @account_initialization_required - def get(self, dataset_id, document_id): - """Get upload file.""" - # check dataset - dataset_id = str(dataset_id) - dataset = ( - db.session.query(Dataset) - .filter(Dataset.tenant_id == current_user.current_tenant_id, Dataset.id == dataset_id) - .first() - ) - if not dataset: - raise NotFound("Dataset not found.") - # check document - document_id = str(document_id) - document = DocumentService.get_document(dataset.id, document_id) - if not document: - raise NotFound("Document not found.") - # check upload file - if document.data_source_type != "upload_file": - raise ValueError(f"Document data source type ({document.data_source_type}) is not upload_file.") - data_source_info = document.data_source_info_dict - if data_source_info and "upload_file_id" in data_source_info: - file_id = data_source_info["upload_file_id"] - upload_file = db.session.query(UploadFile).where(UploadFile.id == file_id).first() - if not upload_file: - raise NotFound("UploadFile not found.") - else: - raise ValueError("Upload file id not found in document data source info.") - - url = file_helpers.get_signed_file_url(upload_file_id=upload_file.id) - return { - "id": upload_file.id, - "name": upload_file.name, - "size": upload_file.size, - "extension": upload_file.extension, - "url": url, - "download_url": f"{url}&as_attachment=true", - "mime_type": upload_file.mime_type, - "created_by": upload_file.created_by, - "created_at": upload_file.created_at.timestamp(), - }, 200 - - -api.add_resource(UploadFileApi, "/datasets//documents//upload-file") diff --git a/api/controllers/console/explore/workflow.py b/api/controllers/console/explore/workflow.py index 3d872fc1fc..c1848ceed1 100644 --- a/api/controllers/console/explore/workflow.py +++ b/api/controllers/console/explore/workflow.py @@ -43,7 +43,7 @@ class InstalledAppWorkflowRunApi(InstalledAppResource): parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") parser.add_argument("files", type=list, required=False, location="json") args = parser.parse_args() - + assert current_user is not None try: response = AppGenerateService.generate( app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True @@ -76,6 +76,7 @@ class InstalledAppWorkflowTaskStopApi(InstalledAppResource): app_mode = AppMode.value_of(app_model.mode) if app_mode != AppMode.WORKFLOW: raise NotWorkflowAppError() + assert current_user is not None AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) diff --git a/api/controllers/console/workspace/load_balancing_config.py b/api/controllers/console/workspace/load_balancing_config.py index 2a54511bf0..7c1bc7c075 100644 --- a/api/controllers/console/workspace/load_balancing_config.py +++ b/api/controllers/console/workspace/load_balancing_config.py @@ -6,7 +6,7 @@ from controllers.console.wraps import account_initialization_required, setup_req from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.validate import CredentialsValidateFailedError from libs.login import current_user, login_required -from models.account import TenantAccountRole +from models.account import Account, TenantAccountRole from services.model_load_balancing_service import ModelLoadBalancingService @@ -15,10 +15,12 @@ class LoadBalancingCredentialsValidateApi(Resource): @login_required @account_initialization_required def post(self, provider: str): + assert isinstance(current_user, Account) if not TenantAccountRole.is_privileged_role(current_user.current_role): raise Forbidden() tenant_id = current_user.current_tenant_id + assert tenant_id is not None parser = reqparse.RequestParser() parser.add_argument("model", type=str, required=True, nullable=False, location="json") @@ -64,10 +66,12 @@ class LoadBalancingConfigCredentialsValidateApi(Resource): @login_required @account_initialization_required def post(self, provider: str, config_id: str): + assert isinstance(current_user, Account) if not TenantAccountRole.is_privileged_role(current_user.current_role): raise Forbidden() tenant_id = current_user.current_tenant_id + assert tenant_id is not None parser = reqparse.RequestParser() parser.add_argument("model", type=str, required=True, nullable=False, location="json") diff --git a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py index f018fada3a..cf2a10f453 100644 --- a/api/controllers/console/workspace/members.py +++ b/api/controllers/console/workspace/members.py @@ -54,7 +54,7 @@ class MemberInviteEmailApi(Resource): @cloud_edition_billing_resource_check("members") def post(self): parser = reqparse.RequestParser() - parser.add_argument("emails", type=str, required=True, location="json", action="append") + parser.add_argument("emails", type=list, required=True, location="json") parser.add_argument("role", type=str, required=True, default="admin", location="json") parser.add_argument("language", type=str, required=False, location="json") args = parser.parse_args() diff --git a/api/controllers/console/workspace/model_providers.py b/api/controllers/console/workspace/model_providers.py index 281783b3d7..3861fb8e99 100644 --- a/api/controllers/console/workspace/model_providers.py +++ b/api/controllers/console/workspace/model_providers.py @@ -10,6 +10,7 @@ from controllers.console.wraps import account_initialization_required, setup_req from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.utils.encoders import jsonable_encoder +from libs.helper import StrLen, uuid_value from libs.login import login_required from services.billing_service import BillingService from services.model_provider_service import ModelProviderService @@ -45,12 +46,109 @@ class ModelProviderCredentialApi(Resource): @account_initialization_required def get(self, provider: str): tenant_id = current_user.current_tenant_id + # if credential_id is not provided, return current used credential + parser = reqparse.RequestParser() + parser.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="args") + args = parser.parse_args() model_provider_service = ModelProviderService() - credentials = model_provider_service.get_provider_credentials(tenant_id=tenant_id, provider=provider) + credentials = model_provider_service.get_provider_credential( + tenant_id=tenant_id, provider=provider, credential_id=args.get("credential_id") + ) return {"credentials": credentials} + @setup_required + @login_required + @account_initialization_required + def post(self, provider: str): + if not current_user.is_admin_or_owner: + raise Forbidden() + + parser = reqparse.RequestParser() + parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") + parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json") + args = parser.parse_args() + + model_provider_service = ModelProviderService() + + try: + model_provider_service.create_provider_credential( + tenant_id=current_user.current_tenant_id, + provider=provider, + credentials=args["credentials"], + credential_name=args["name"], + ) + except CredentialsValidateFailedError as ex: + raise ValueError(str(ex)) + + return {"result": "success"}, 201 + + @setup_required + @login_required + @account_initialization_required + def put(self, provider: str): + if not current_user.is_admin_or_owner: + raise Forbidden() + + parser = reqparse.RequestParser() + parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json") + parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") + parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json") + args = parser.parse_args() + + model_provider_service = ModelProviderService() + + try: + model_provider_service.update_provider_credential( + tenant_id=current_user.current_tenant_id, + provider=provider, + credentials=args["credentials"], + credential_id=args["credential_id"], + credential_name=args["name"], + ) + except CredentialsValidateFailedError as ex: + raise ValueError(str(ex)) + + return {"result": "success"} + + @setup_required + @login_required + @account_initialization_required + def delete(self, provider: str): + if not current_user.is_admin_or_owner: + raise Forbidden() + parser = reqparse.RequestParser() + parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json") + args = parser.parse_args() + + model_provider_service = ModelProviderService() + model_provider_service.remove_provider_credential( + tenant_id=current_user.current_tenant_id, provider=provider, credential_id=args["credential_id"] + ) + + return {"result": "success"}, 204 + + +class ModelProviderCredentialSwitchApi(Resource): + @setup_required + @login_required + @account_initialization_required + def post(self, provider: str): + if not current_user.is_admin_or_owner: + raise Forbidden() + parser = reqparse.RequestParser() + parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json") + args = parser.parse_args() + + service = ModelProviderService() + service.switch_active_provider_credential( + tenant_id=current_user.current_tenant_id, + provider=provider, + credential_id=args["credential_id"], + ) + return {"result": "success"} + class ModelProviderValidateApi(Resource): @setup_required @@ -69,7 +167,7 @@ class ModelProviderValidateApi(Resource): error = "" try: - model_provider_service.provider_credentials_validate( + model_provider_service.validate_provider_credentials( tenant_id=tenant_id, provider=provider, credentials=args["credentials"] ) except CredentialsValidateFailedError as ex: @@ -84,42 +182,6 @@ class ModelProviderValidateApi(Resource): return response -class ModelProviderApi(Resource): - @setup_required - @login_required - @account_initialization_required - def post(self, provider: str): - if not current_user.is_admin_or_owner: - raise Forbidden() - - parser = reqparse.RequestParser() - parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") - args = parser.parse_args() - - model_provider_service = ModelProviderService() - - try: - model_provider_service.save_provider_credentials( - tenant_id=current_user.current_tenant_id, provider=provider, credentials=args["credentials"] - ) - except CredentialsValidateFailedError as ex: - raise ValueError(str(ex)) - - return {"result": "success"}, 201 - - @setup_required - @login_required - @account_initialization_required - def delete(self, provider: str): - if not current_user.is_admin_or_owner: - raise Forbidden() - - model_provider_service = ModelProviderService() - model_provider_service.remove_provider_credentials(tenant_id=current_user.current_tenant_id, provider=provider) - - return {"result": "success"}, 204 - - class ModelProviderIconApi(Resource): """ Get model provider icon @@ -187,8 +249,10 @@ class ModelProviderPaymentCheckoutUrlApi(Resource): api.add_resource(ModelProviderListApi, "/workspaces/current/model-providers") api.add_resource(ModelProviderCredentialApi, "/workspaces/current/model-providers//credentials") +api.add_resource( + ModelProviderCredentialSwitchApi, "/workspaces/current/model-providers//credentials/switch" +) api.add_resource(ModelProviderValidateApi, "/workspaces/current/model-providers//credentials/validate") -api.add_resource(ModelProviderApi, "/workspaces/current/model-providers/") api.add_resource( PreferredProviderTypeUpdateApi, "/workspaces/current/model-providers//preferred-provider-type" diff --git a/api/controllers/console/workspace/models.py b/api/controllers/console/workspace/models.py index b8dddb91dd..98702dd3bc 100644 --- a/api/controllers/console/workspace/models.py +++ b/api/controllers/console/workspace/models.py @@ -9,6 +9,7 @@ from controllers.console.wraps import account_initialization_required, setup_req from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.utils.encoders import jsonable_encoder +from libs.helper import StrLen, uuid_value from libs.login import login_required from services.model_load_balancing_service import ModelLoadBalancingService from services.model_provider_service import ModelProviderService @@ -98,6 +99,7 @@ class ModelProviderModelApi(Resource): @login_required @account_initialization_required def post(self, provider: str): + # To save the model's load balance configs if not current_user.is_admin_or_owner: raise Forbidden() @@ -113,22 +115,26 @@ class ModelProviderModelApi(Resource): choices=[mt.value for mt in ModelType], location="json", ) - parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json") parser.add_argument("load_balancing", type=dict, required=False, nullable=True, location="json") parser.add_argument("config_from", type=str, required=False, nullable=True, location="json") + parser.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="json") args = parser.parse_args() + if args.get("config_from", "") == "custom-model": + if not args.get("credential_id"): + raise ValueError("credential_id is required when configuring a custom-model") + service = ModelProviderService() + service.switch_active_custom_model_credential( + tenant_id=current_user.current_tenant_id, + provider=provider, + model_type=args["model_type"], + model=args["model"], + credential_id=args["credential_id"], + ) + model_load_balancing_service = ModelLoadBalancingService() - if ( - "load_balancing" in args - and args["load_balancing"] - and "enabled" in args["load_balancing"] - and args["load_balancing"]["enabled"] - ): - if "configs" not in args["load_balancing"]: - raise ValueError("invalid load balancing configs") - + if "load_balancing" in args and args["load_balancing"] and "configs" in args["load_balancing"]: # save load balancing configs model_load_balancing_service.update_load_balancing_configs( tenant_id=tenant_id, @@ -136,37 +142,17 @@ class ModelProviderModelApi(Resource): model=args["model"], model_type=args["model_type"], configs=args["load_balancing"]["configs"], + config_from=args.get("config_from", ""), ) - # enable load balancing - model_load_balancing_service.enable_model_load_balancing( - tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] - ) - else: - # disable load balancing - model_load_balancing_service.disable_model_load_balancing( - tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] - ) - - if args.get("config_from", "") != "predefined-model": - model_provider_service = ModelProviderService() - - try: - model_provider_service.save_model_credentials( - tenant_id=tenant_id, - provider=provider, - model=args["model"], - model_type=args["model_type"], - credentials=args["credentials"], - ) - except CredentialsValidateFailedError as ex: - logging.exception( - "Failed to save model credentials, tenant_id: %s, model: %s, model_type: %s", - tenant_id, - args.get("model"), - args.get("model_type"), - ) - raise ValueError(str(ex)) + if args.get("load_balancing", {}).get("enabled"): + model_load_balancing_service.enable_model_load_balancing( + tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] + ) + else: + model_load_balancing_service.disable_model_load_balancing( + tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] + ) return {"result": "success"}, 200 @@ -192,7 +178,7 @@ class ModelProviderModelApi(Resource): args = parser.parse_args() model_provider_service = ModelProviderService() - model_provider_service.remove_model_credentials( + model_provider_service.remove_model( tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] ) @@ -216,11 +202,17 @@ class ModelProviderModelCredentialApi(Resource): choices=[mt.value for mt in ModelType], location="args", ) + parser.add_argument("config_from", type=str, required=False, nullable=True, location="args") + parser.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="args") args = parser.parse_args() model_provider_service = ModelProviderService() - credentials = model_provider_service.get_model_credentials( - tenant_id=tenant_id, provider=provider, model_type=args["model_type"], model=args["model"] + current_credential = model_provider_service.get_model_credential( + tenant_id=tenant_id, + provider=provider, + model_type=args["model_type"], + model=args["model"], + credential_id=args.get("credential_id"), ) model_load_balancing_service = ModelLoadBalancingService() @@ -228,10 +220,173 @@ class ModelProviderModelCredentialApi(Resource): tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] ) - return { - "credentials": credentials, - "load_balancing": {"enabled": is_load_balancing_enabled, "configs": load_balancing_configs}, - } + if args.get("config_from", "") == "predefined-model": + available_credentials = model_provider_service.provider_manager.get_provider_available_credentials( + tenant_id=tenant_id, provider_name=provider + ) + else: + model_type = ModelType.value_of(args["model_type"]).to_origin_model_type() + available_credentials = model_provider_service.provider_manager.get_provider_model_available_credentials( + tenant_id=tenant_id, provider_name=provider, model_type=model_type, model_name=args["model"] + ) + + return jsonable_encoder( + { + "credentials": current_credential.get("credentials") if current_credential else {}, + "current_credential_id": current_credential.get("current_credential_id") + if current_credential + else None, + "current_credential_name": current_credential.get("current_credential_name") + if current_credential + else None, + "load_balancing": {"enabled": is_load_balancing_enabled, "configs": load_balancing_configs}, + "available_credentials": available_credentials, + } + ) + + @setup_required + @login_required + @account_initialization_required + def post(self, provider: str): + if not current_user.is_admin_or_owner: + raise Forbidden() + + parser = reqparse.RequestParser() + parser.add_argument("model", type=str, required=True, nullable=False, location="json") + parser.add_argument( + "model_type", + type=str, + required=True, + nullable=False, + choices=[mt.value for mt in ModelType], + location="json", + ) + parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json") + parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") + args = parser.parse_args() + + tenant_id = current_user.current_tenant_id + model_provider_service = ModelProviderService() + + try: + model_provider_service.create_model_credential( + tenant_id=tenant_id, + provider=provider, + model=args["model"], + model_type=args["model_type"], + credentials=args["credentials"], + credential_name=args["name"], + ) + except CredentialsValidateFailedError as ex: + logging.exception( + "Failed to save model credentials, tenant_id: %s, model: %s, model_type: %s", + tenant_id, + args.get("model"), + args.get("model_type"), + ) + raise ValueError(str(ex)) + + return {"result": "success"}, 201 + + @setup_required + @login_required + @account_initialization_required + def put(self, provider: str): + if not current_user.is_admin_or_owner: + raise Forbidden() + + parser = reqparse.RequestParser() + parser.add_argument("model", type=str, required=True, nullable=False, location="json") + parser.add_argument( + "model_type", + type=str, + required=True, + nullable=False, + choices=[mt.value for mt in ModelType], + location="json", + ) + parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json") + parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") + parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json") + args = parser.parse_args() + + model_provider_service = ModelProviderService() + + try: + model_provider_service.update_model_credential( + tenant_id=current_user.current_tenant_id, + provider=provider, + model_type=args["model_type"], + model=args["model"], + credentials=args["credentials"], + credential_id=args["credential_id"], + credential_name=args["name"], + ) + except CredentialsValidateFailedError as ex: + raise ValueError(str(ex)) + + return {"result": "success"} + + @setup_required + @login_required + @account_initialization_required + def delete(self, provider: str): + if not current_user.is_admin_or_owner: + raise Forbidden() + parser = reqparse.RequestParser() + parser.add_argument("model", type=str, required=True, nullable=False, location="json") + parser.add_argument( + "model_type", + type=str, + required=True, + nullable=False, + choices=[mt.value for mt in ModelType], + location="json", + ) + parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json") + args = parser.parse_args() + + model_provider_service = ModelProviderService() + model_provider_service.remove_model_credential( + tenant_id=current_user.current_tenant_id, + provider=provider, + model_type=args["model_type"], + model=args["model"], + credential_id=args["credential_id"], + ) + + return {"result": "success"}, 204 + + +class ModelProviderModelCredentialSwitchApi(Resource): + @setup_required + @login_required + @account_initialization_required + def post(self, provider: str): + if not current_user.is_admin_or_owner: + raise Forbidden() + parser = reqparse.RequestParser() + parser.add_argument("model", type=str, required=True, nullable=False, location="json") + parser.add_argument( + "model_type", + type=str, + required=True, + nullable=False, + choices=[mt.value for mt in ModelType], + location="json", + ) + parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json") + args = parser.parse_args() + + service = ModelProviderService() + service.add_model_credential_to_model_list( + tenant_id=current_user.current_tenant_id, + provider=provider, + model_type=args["model_type"], + model=args["model"], + credential_id=args["credential_id"], + ) + return {"result": "success"} class ModelProviderModelEnableApi(Resource): @@ -314,7 +469,7 @@ class ModelProviderModelValidateApi(Resource): error = "" try: - model_provider_service.model_credentials_validate( + model_provider_service.validate_model_credentials( tenant_id=tenant_id, provider=provider, model=args["model"], @@ -379,6 +534,10 @@ api.add_resource( api.add_resource( ModelProviderModelCredentialApi, "/workspaces/current/model-providers//models/credentials" ) +api.add_resource( + ModelProviderModelCredentialSwitchApi, + "/workspaces/current/model-providers//models/credentials/switch", +) api.add_resource( ModelProviderModelValidateApi, "/workspaces/current/model-providers//models/credentials/validate" ) diff --git a/api/controllers/files/__init__.py b/api/controllers/files/__init__.py index 282a181997..821ad220a2 100644 --- a/api/controllers/files/__init__.py +++ b/api/controllers/files/__init__.py @@ -13,7 +13,7 @@ api = ExternalApi( doc="/docs", # Enable Swagger UI at /files/docs ) -files_ns = Namespace("files", description="File operations") +files_ns = Namespace("files", description="File operations", path="/") from . import image_preview, tool_files, upload diff --git a/api/controllers/inner_api/__init__.py b/api/controllers/inner_api/__init__.py index d51db4322a..d29a7be139 100644 --- a/api/controllers/inner_api/__init__.py +++ b/api/controllers/inner_api/__init__.py @@ -1,10 +1,23 @@ from flask import Blueprint +from flask_restx import Namespace from libs.external_api import ExternalApi bp = Blueprint("inner_api", __name__, url_prefix="/inner/api") -api = ExternalApi(bp) + +api = ExternalApi( + bp, + version="1.0", + title="Inner API", + description="Internal APIs for enterprise features, billing, and plugin communication", + doc="/docs", # Enable Swagger UI at /inner/api/docs +) + +# Create namespace +inner_api_ns = Namespace("inner_api", description="Internal API operations", path="/") from . import mail from .plugin import plugin from .workspace import workspace + +api.add_namespace(inner_api_ns) diff --git a/api/controllers/inner_api/mail.py b/api/controllers/inner_api/mail.py index 80bbc360de..0b2be03e43 100644 --- a/api/controllers/inner_api/mail.py +++ b/api/controllers/inner_api/mail.py @@ -1,7 +1,7 @@ from flask_restx import Resource, reqparse from controllers.console.wraps import setup_required -from controllers.inner_api import api +from controllers.inner_api import inner_api_ns from controllers.inner_api.wraps import billing_inner_api_only, enterprise_inner_api_only from tasks.mail_inner_task import send_inner_email_task @@ -26,13 +26,45 @@ class BaseMail(Resource): return {"message": "success"}, 200 +@inner_api_ns.route("/enterprise/mail") class EnterpriseMail(BaseMail): method_decorators = [setup_required, enterprise_inner_api_only] + @inner_api_ns.doc("send_enterprise_mail") + @inner_api_ns.doc(description="Send internal email for enterprise features") + @inner_api_ns.expect(_mail_parser) + @inner_api_ns.doc( + responses={200: "Email sent successfully", 401: "Unauthorized - invalid API key", 404: "Service not available"} + ) + def post(self): + """Send internal email for enterprise features. + This endpoint allows sending internal emails for enterprise-specific + notifications and communications. + + Returns: + dict: Success message with status code 200 + """ + return super().post() + + +@inner_api_ns.route("/billing/mail") class BillingMail(BaseMail): method_decorators = [setup_required, billing_inner_api_only] + @inner_api_ns.doc("send_billing_mail") + @inner_api_ns.doc(description="Send internal email for billing notifications") + @inner_api_ns.expect(_mail_parser) + @inner_api_ns.doc( + responses={200: "Email sent successfully", 401: "Unauthorized - invalid API key", 404: "Service not available"} + ) + def post(self): + """Send internal email for billing notifications. -api.add_resource(EnterpriseMail, "/enterprise/mail") -api.add_resource(BillingMail, "/billing/mail") + This endpoint allows sending internal emails for billing-related + notifications and alerts. + + Returns: + dict: Success message with status code 200 + """ + return super().post() diff --git a/api/controllers/inner_api/plugin/plugin.py b/api/controllers/inner_api/plugin/plugin.py index 9b8d9457f0..170a794d89 100644 --- a/api/controllers/inner_api/plugin/plugin.py +++ b/api/controllers/inner_api/plugin/plugin.py @@ -1,7 +1,7 @@ from flask_restx import Resource from controllers.console.wraps import setup_required -from controllers.inner_api import api +from controllers.inner_api import inner_api_ns from controllers.inner_api.plugin.wraps import get_user_tenant, plugin_data from controllers.inner_api.wraps import plugin_inner_api_only from core.file.helpers import get_signed_file_url_for_plugin @@ -35,11 +35,21 @@ from models.account import Account, Tenant from models.model import EndUser +@inner_api_ns.route("/invoke/llm") class PluginInvokeLLMApi(Resource): @setup_required @plugin_inner_api_only @get_user_tenant @plugin_data(payload_type=RequestInvokeLLM) + @inner_api_ns.doc("plugin_invoke_llm") + @inner_api_ns.doc(description="Invoke LLM models through plugin interface") + @inner_api_ns.doc( + responses={ + 200: "LLM invocation successful (streaming response)", + 401: "Unauthorized - invalid API key", + 404: "Service not available", + } + ) def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeLLM): def generator(): response = PluginModelBackwardsInvocation.invoke_llm(user_model.id, tenant_model, payload) @@ -48,11 +58,21 @@ class PluginInvokeLLMApi(Resource): return length_prefixed_response(0xF, generator()) +@inner_api_ns.route("/invoke/llm/structured-output") class PluginInvokeLLMWithStructuredOutputApi(Resource): @setup_required @plugin_inner_api_only @get_user_tenant @plugin_data(payload_type=RequestInvokeLLMWithStructuredOutput) + @inner_api_ns.doc("plugin_invoke_llm_structured") + @inner_api_ns.doc(description="Invoke LLM models with structured output through plugin interface") + @inner_api_ns.doc( + responses={ + 200: "LLM structured output invocation successful (streaming response)", + 401: "Unauthorized - invalid API key", + 404: "Service not available", + } + ) def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeLLMWithStructuredOutput): def generator(): response = PluginModelBackwardsInvocation.invoke_llm_with_structured_output( @@ -63,11 +83,21 @@ class PluginInvokeLLMWithStructuredOutputApi(Resource): return length_prefixed_response(0xF, generator()) +@inner_api_ns.route("/invoke/text-embedding") class PluginInvokeTextEmbeddingApi(Resource): @setup_required @plugin_inner_api_only @get_user_tenant @plugin_data(payload_type=RequestInvokeTextEmbedding) + @inner_api_ns.doc("plugin_invoke_text_embedding") + @inner_api_ns.doc(description="Invoke text embedding models through plugin interface") + @inner_api_ns.doc( + responses={ + 200: "Text embedding successful", + 401: "Unauthorized - invalid API key", + 404: "Service not available", + } + ) def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeTextEmbedding): try: return jsonable_encoder( @@ -83,11 +113,17 @@ class PluginInvokeTextEmbeddingApi(Resource): return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e))) +@inner_api_ns.route("/invoke/rerank") class PluginInvokeRerankApi(Resource): @setup_required @plugin_inner_api_only @get_user_tenant @plugin_data(payload_type=RequestInvokeRerank) + @inner_api_ns.doc("plugin_invoke_rerank") + @inner_api_ns.doc(description="Invoke rerank models through plugin interface") + @inner_api_ns.doc( + responses={200: "Rerank successful", 401: "Unauthorized - invalid API key", 404: "Service not available"} + ) def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeRerank): try: return jsonable_encoder( @@ -103,11 +139,21 @@ class PluginInvokeRerankApi(Resource): return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e))) +@inner_api_ns.route("/invoke/tts") class PluginInvokeTTSApi(Resource): @setup_required @plugin_inner_api_only @get_user_tenant @plugin_data(payload_type=RequestInvokeTTS) + @inner_api_ns.doc("plugin_invoke_tts") + @inner_api_ns.doc(description="Invoke text-to-speech models through plugin interface") + @inner_api_ns.doc( + responses={ + 200: "TTS invocation successful (streaming response)", + 401: "Unauthorized - invalid API key", + 404: "Service not available", + } + ) def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeTTS): def generator(): response = PluginModelBackwardsInvocation.invoke_tts( @@ -120,11 +166,17 @@ class PluginInvokeTTSApi(Resource): return length_prefixed_response(0xF, generator()) +@inner_api_ns.route("/invoke/speech2text") class PluginInvokeSpeech2TextApi(Resource): @setup_required @plugin_inner_api_only @get_user_tenant @plugin_data(payload_type=RequestInvokeSpeech2Text) + @inner_api_ns.doc("plugin_invoke_speech2text") + @inner_api_ns.doc(description="Invoke speech-to-text models through plugin interface") + @inner_api_ns.doc( + responses={200: "Speech2Text successful", 401: "Unauthorized - invalid API key", 404: "Service not available"} + ) def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeSpeech2Text): try: return jsonable_encoder( @@ -140,11 +192,17 @@ class PluginInvokeSpeech2TextApi(Resource): return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e))) +@inner_api_ns.route("/invoke/moderation") class PluginInvokeModerationApi(Resource): @setup_required @plugin_inner_api_only @get_user_tenant @plugin_data(payload_type=RequestInvokeModeration) + @inner_api_ns.doc("plugin_invoke_moderation") + @inner_api_ns.doc(description="Invoke moderation models through plugin interface") + @inner_api_ns.doc( + responses={200: "Moderation successful", 401: "Unauthorized - invalid API key", 404: "Service not available"} + ) def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeModeration): try: return jsonable_encoder( @@ -160,11 +218,21 @@ class PluginInvokeModerationApi(Resource): return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e))) +@inner_api_ns.route("/invoke/tool") class PluginInvokeToolApi(Resource): @setup_required @plugin_inner_api_only @get_user_tenant @plugin_data(payload_type=RequestInvokeTool) + @inner_api_ns.doc("plugin_invoke_tool") + @inner_api_ns.doc(description="Invoke tools through plugin interface") + @inner_api_ns.doc( + responses={ + 200: "Tool invocation successful (streaming response)", + 401: "Unauthorized - invalid API key", + 404: "Service not available", + } + ) def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeTool): def generator(): return PluginToolBackwardsInvocation.convert_to_event_stream( @@ -182,11 +250,21 @@ class PluginInvokeToolApi(Resource): return length_prefixed_response(0xF, generator()) +@inner_api_ns.route("/invoke/parameter-extractor") class PluginInvokeParameterExtractorNodeApi(Resource): @setup_required @plugin_inner_api_only @get_user_tenant @plugin_data(payload_type=RequestInvokeParameterExtractorNode) + @inner_api_ns.doc("plugin_invoke_parameter_extractor") + @inner_api_ns.doc(description="Invoke parameter extractor node through plugin interface") + @inner_api_ns.doc( + responses={ + 200: "Parameter extraction successful", + 401: "Unauthorized - invalid API key", + 404: "Service not available", + } + ) def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeParameterExtractorNode): try: return jsonable_encoder( @@ -205,11 +283,21 @@ class PluginInvokeParameterExtractorNodeApi(Resource): return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e))) +@inner_api_ns.route("/invoke/question-classifier") class PluginInvokeQuestionClassifierNodeApi(Resource): @setup_required @plugin_inner_api_only @get_user_tenant @plugin_data(payload_type=RequestInvokeQuestionClassifierNode) + @inner_api_ns.doc("plugin_invoke_question_classifier") + @inner_api_ns.doc(description="Invoke question classifier node through plugin interface") + @inner_api_ns.doc( + responses={ + 200: "Question classification successful", + 401: "Unauthorized - invalid API key", + 404: "Service not available", + } + ) def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeQuestionClassifierNode): try: return jsonable_encoder( @@ -228,11 +316,21 @@ class PluginInvokeQuestionClassifierNodeApi(Resource): return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e))) +@inner_api_ns.route("/invoke/app") class PluginInvokeAppApi(Resource): @setup_required @plugin_inner_api_only @get_user_tenant @plugin_data(payload_type=RequestInvokeApp) + @inner_api_ns.doc("plugin_invoke_app") + @inner_api_ns.doc(description="Invoke application through plugin interface") + @inner_api_ns.doc( + responses={ + 200: "App invocation successful (streaming response)", + 401: "Unauthorized - invalid API key", + 404: "Service not available", + } + ) def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeApp): response = PluginAppBackwardsInvocation.invoke_app( app_id=payload.app_id, @@ -248,11 +346,21 @@ class PluginInvokeAppApi(Resource): return length_prefixed_response(0xF, PluginAppBackwardsInvocation.convert_to_event_stream(response)) +@inner_api_ns.route("/invoke/encrypt") class PluginInvokeEncryptApi(Resource): @setup_required @plugin_inner_api_only @get_user_tenant @plugin_data(payload_type=RequestInvokeEncrypt) + @inner_api_ns.doc("plugin_invoke_encrypt") + @inner_api_ns.doc(description="Encrypt or decrypt data through plugin interface") + @inner_api_ns.doc( + responses={ + 200: "Encryption/decryption successful", + 401: "Unauthorized - invalid API key", + 404: "Service not available", + } + ) def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeEncrypt): """ encrypt or decrypt data @@ -265,11 +373,21 @@ class PluginInvokeEncryptApi(Resource): return BaseBackwardsInvocationResponse(error=str(e)).model_dump() +@inner_api_ns.route("/invoke/summary") class PluginInvokeSummaryApi(Resource): @setup_required @plugin_inner_api_only @get_user_tenant @plugin_data(payload_type=RequestInvokeSummary) + @inner_api_ns.doc("plugin_invoke_summary") + @inner_api_ns.doc(description="Invoke summary functionality through plugin interface") + @inner_api_ns.doc( + responses={ + 200: "Summary generation successful", + 401: "Unauthorized - invalid API key", + 404: "Service not available", + } + ) def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeSummary): try: return BaseBackwardsInvocationResponse( @@ -285,40 +403,43 @@ class PluginInvokeSummaryApi(Resource): return BaseBackwardsInvocationResponse(error=str(e)).model_dump() +@inner_api_ns.route("/upload/file/request") class PluginUploadFileRequestApi(Resource): @setup_required @plugin_inner_api_only @get_user_tenant @plugin_data(payload_type=RequestRequestUploadFile) + @inner_api_ns.doc("plugin_upload_file_request") + @inner_api_ns.doc(description="Request signed URL for file upload through plugin interface") + @inner_api_ns.doc( + responses={ + 200: "Signed URL generated successfully", + 401: "Unauthorized - invalid API key", + 404: "Service not available", + } + ) def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestRequestUploadFile): # generate signed url url = get_signed_file_url_for_plugin(payload.filename, payload.mimetype, tenant_model.id, user_model.id) return BaseBackwardsInvocationResponse(data={"url": url}).model_dump() +@inner_api_ns.route("/fetch/app/info") class PluginFetchAppInfoApi(Resource): @setup_required @plugin_inner_api_only @get_user_tenant @plugin_data(payload_type=RequestFetchAppInfo) + @inner_api_ns.doc("plugin_fetch_app_info") + @inner_api_ns.doc(description="Fetch application information through plugin interface") + @inner_api_ns.doc( + responses={ + 200: "App information retrieved successfully", + 401: "Unauthorized - invalid API key", + 404: "Service not available", + } + ) def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestFetchAppInfo): return BaseBackwardsInvocationResponse( data=PluginAppBackwardsInvocation.fetch_app_info(payload.app_id, tenant_model.id) ).model_dump() - - -api.add_resource(PluginInvokeLLMApi, "/invoke/llm") -api.add_resource(PluginInvokeLLMWithStructuredOutputApi, "/invoke/llm/structured-output") -api.add_resource(PluginInvokeTextEmbeddingApi, "/invoke/text-embedding") -api.add_resource(PluginInvokeRerankApi, "/invoke/rerank") -api.add_resource(PluginInvokeTTSApi, "/invoke/tts") -api.add_resource(PluginInvokeSpeech2TextApi, "/invoke/speech2text") -api.add_resource(PluginInvokeModerationApi, "/invoke/moderation") -api.add_resource(PluginInvokeToolApi, "/invoke/tool") -api.add_resource(PluginInvokeParameterExtractorNodeApi, "/invoke/parameter-extractor") -api.add_resource(PluginInvokeQuestionClassifierNodeApi, "/invoke/question-classifier") -api.add_resource(PluginInvokeAppApi, "/invoke/app") -api.add_resource(PluginInvokeEncryptApi, "/invoke/encrypt") -api.add_resource(PluginInvokeSummaryApi, "/invoke/summary") -api.add_resource(PluginUploadFileRequestApi, "/upload/file/request") -api.add_resource(PluginFetchAppInfoApi, "/fetch/app/info") diff --git a/api/controllers/inner_api/workspace/workspace.py b/api/controllers/inner_api/workspace/workspace.py index 1c26416080..47f0240cd2 100644 --- a/api/controllers/inner_api/workspace/workspace.py +++ b/api/controllers/inner_api/workspace/workspace.py @@ -3,7 +3,7 @@ import json from flask_restx import Resource, reqparse from controllers.console.wraps import setup_required -from controllers.inner_api import api +from controllers.inner_api import inner_api_ns from controllers.inner_api.wraps import enterprise_inner_api_only from events.tenant_event import tenant_was_created from extensions.ext_database import db @@ -11,9 +11,19 @@ from models.account import Account from services.account_service import TenantService +@inner_api_ns.route("/enterprise/workspace") class EnterpriseWorkspace(Resource): @setup_required @enterprise_inner_api_only + @inner_api_ns.doc("create_enterprise_workspace") + @inner_api_ns.doc(description="Create a new enterprise workspace with owner assignment") + @inner_api_ns.doc( + responses={ + 200: "Workspace created successfully", + 401: "Unauthorized - invalid API key", + 404: "Owner account not found or service not available", + } + ) def post(self): parser = reqparse.RequestParser() parser.add_argument("name", type=str, required=True, location="json") @@ -44,9 +54,19 @@ class EnterpriseWorkspace(Resource): } +@inner_api_ns.route("/enterprise/workspace/ownerless") class EnterpriseWorkspaceNoOwnerEmail(Resource): @setup_required @enterprise_inner_api_only + @inner_api_ns.doc("create_enterprise_workspace_ownerless") + @inner_api_ns.doc(description="Create a new enterprise workspace without initial owner assignment") + @inner_api_ns.doc( + responses={ + 200: "Workspace created successfully", + 401: "Unauthorized - invalid API key", + 404: "Service not available", + } + ) def post(self): parser = reqparse.RequestParser() parser.add_argument("name", type=str, required=True, location="json") @@ -71,7 +91,3 @@ class EnterpriseWorkspaceNoOwnerEmail(Resource): "message": "enterprise workspace created.", "tenant": resp, } - - -api.add_resource(EnterpriseWorkspace, "/enterprise/workspace") -api.add_resource(EnterpriseWorkspaceNoOwnerEmail, "/enterprise/workspace/ownerless") diff --git a/api/controllers/mcp/__init__.py b/api/controllers/mcp/__init__.py index 1f5dae74e8..c344ffad08 100644 --- a/api/controllers/mcp/__init__.py +++ b/api/controllers/mcp/__init__.py @@ -13,7 +13,7 @@ api = ExternalApi( doc="/docs", # Enable Swagger UI at /mcp/docs ) -mcp_ns = Namespace("mcp", description="MCP operations") +mcp_ns = Namespace("mcp", description="MCP operations", path="/") from . import mcp diff --git a/api/controllers/service_api/__init__.py b/api/controllers/service_api/__init__.py index aaa3c8f9a1..763345d723 100644 --- a/api/controllers/service_api/__init__.py +++ b/api/controllers/service_api/__init__.py @@ -13,7 +13,7 @@ api = ExternalApi( doc="/docs", # Enable Swagger UI at /v1/docs ) -service_api_ns = Namespace("service_api", description="Service operations") +service_api_ns = Namespace("service_api", description="Service operations", path="/") from . import index from .app import annotation, app, audio, completion, conversation, file, file_preview, message, site, workflow diff --git a/api/controllers/service_api/app/annotation.py b/api/controllers/service_api/app/annotation.py index 6bc94af8c1..9038bda11a 100644 --- a/api/controllers/service_api/app/annotation.py +++ b/api/controllers/service_api/app/annotation.py @@ -10,6 +10,7 @@ from controllers.service_api.wraps import validate_app_token from extensions.ext_redis import redis_client from fields.annotation_fields import annotation_fields, build_annotation_model from libs.login import current_user +from models.account import Account from models.model import App from services.annotation_service import AppAnnotationService @@ -163,6 +164,7 @@ class AnnotationUpdateDeleteApi(Resource): @service_api_ns.marshal_with(build_annotation_model(service_api_ns)) def put(self, app_model: App, annotation_id): """Update an existing annotation.""" + assert isinstance(current_user, Account) if not current_user.is_editor: raise Forbidden() @@ -185,6 +187,8 @@ class AnnotationUpdateDeleteApi(Resource): @validate_app_token def delete(self, app_model: App, annotation_id): """Delete an annotation.""" + assert isinstance(current_user, Account) + if not current_user.is_editor: raise Forbidden() diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index c486b0480b..7b74c961bb 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -18,6 +18,7 @@ from core.provider_manager import ProviderManager from fields.dataset_fields import dataset_detail_fields from fields.tag_fields import build_dataset_tag_fields from libs.login import current_user +from models.account import Account from models.dataset import Dataset, DatasetPermissionEnum from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService from services.entities.knowledge_entities.knowledge_entities import RetrievalModel @@ -213,7 +214,10 @@ class DatasetListApi(DatasetApiResource): ) # check embedding setting provider_manager = ProviderManager() - configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id) + assert isinstance(current_user, Account) + cid = current_user.current_tenant_id + assert cid is not None + configurations = provider_manager.get_configurations(tenant_id=cid) embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True) @@ -266,6 +270,7 @@ class DatasetListApi(DatasetApiResource): ) try: + assert isinstance(current_user, Account) dataset = DatasetService.create_empty_dataset( tenant_id=tenant_id, name=args["name"], @@ -319,7 +324,10 @@ class DatasetApi(DatasetApiResource): # check embedding setting provider_manager = ProviderManager() - configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id) + assert isinstance(current_user, Account) + cid = current_user.current_tenant_id + assert cid is not None + configurations = provider_manager.get_configurations(tenant_id=cid) embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True) @@ -391,6 +399,7 @@ class DatasetApi(DatasetApiResource): raise NotFound("Dataset not found.") result_data = marshal(dataset, dataset_detail_fields) + assert isinstance(current_user, Account) tenant_id = current_user.current_tenant_id if data.get("partial_member_list") and data.get("permission") == "partial_members": @@ -532,7 +541,10 @@ class DatasetTagsApi(DatasetApiResource): @service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns)) def get(self, _, dataset_id): """Get all knowledge type tags.""" - tags = TagService.get_tags("knowledge", current_user.current_tenant_id) + assert isinstance(current_user, Account) + cid = current_user.current_tenant_id + assert cid is not None + tags = TagService.get_tags("knowledge", cid) return tags, 200 @@ -550,6 +562,7 @@ class DatasetTagsApi(DatasetApiResource): @validate_dataset_token def post(self, _, dataset_id): """Add a knowledge type tag.""" + assert isinstance(current_user, Account) if not (current_user.is_editor or current_user.is_dataset_editor): raise Forbidden() @@ -573,6 +586,7 @@ class DatasetTagsApi(DatasetApiResource): @service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns)) @validate_dataset_token def patch(self, _, dataset_id): + assert isinstance(current_user, Account) if not (current_user.is_editor or current_user.is_dataset_editor): raise Forbidden() @@ -599,6 +613,7 @@ class DatasetTagsApi(DatasetApiResource): @validate_dataset_token def delete(self, _, dataset_id): """Delete a knowledge type tag.""" + assert isinstance(current_user, Account) if not current_user.is_editor: raise Forbidden() args = tag_delete_parser.parse_args() @@ -622,6 +637,7 @@ class DatasetTagBindingApi(DatasetApiResource): @validate_dataset_token def post(self, _, dataset_id): # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator + assert isinstance(current_user, Account) if not (current_user.is_editor or current_user.is_dataset_editor): raise Forbidden() @@ -647,6 +663,7 @@ class DatasetTagUnbindingApi(DatasetApiResource): @validate_dataset_token def post(self, _, dataset_id): # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator + assert isinstance(current_user, Account) if not (current_user.is_editor or current_user.is_dataset_editor): raise Forbidden() @@ -672,6 +689,8 @@ class DatasetTagsBindingStatusApi(DatasetApiResource): def get(self, _, *args, **kwargs): """Get all knowledge type tags.""" dataset_id = kwargs.get("dataset_id") + assert isinstance(current_user, Account) + assert current_user.current_tenant_id is not None tags = TagService.get_tags_by_target_id("knowledge", current_user.current_tenant_id, str(dataset_id)) tags_list = [{"id": tag.id, "name": tag.name} for tag in tags] response = {"data": tags_list, "total": len(tags)} diff --git a/api/controllers/web/forgot_password.py b/api/controllers/web/forgot_password.py index d436657f06..b596d969a9 100644 --- a/api/controllers/web/forgot_password.py +++ b/api/controllers/web/forgot_password.py @@ -7,13 +7,14 @@ from sqlalchemy import select from sqlalchemy.orm import Session from controllers.console.auth.error import ( + AuthenticationFailedError, EmailCodeError, EmailPasswordResetLimitError, InvalidEmailError, InvalidTokenError, PasswordMismatchError, ) -from controllers.console.error import AccountNotFound, EmailSendIpLimitError +from controllers.console.error import EmailSendIpLimitError from controllers.console.wraps import email_password_login_enabled, only_edition_enterprise, setup_required from controllers.web import api from extensions.ext_database import db @@ -46,7 +47,7 @@ class ForgotPasswordSendEmailApi(Resource): account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none() token = None if account is None: - raise AccountNotFound() + raise AuthenticationFailedError() else: token = AccountService.send_reset_password_email(account=account, email=args["email"], language=language) @@ -131,7 +132,7 @@ class ForgotPasswordResetApi(Resource): if account: self._update_existing_account(account, password_hashed, salt, session) else: - raise AccountNotFound() + raise AuthenticationFailedError() return {"result": "success"} diff --git a/api/controllers/web/login.py b/api/controllers/web/login.py index d4eafd532b..4f6ff5c1a4 100644 --- a/api/controllers/web/login.py +++ b/api/controllers/web/login.py @@ -2,8 +2,12 @@ from flask_restx import Resource, reqparse from jwt import InvalidTokenError # type: ignore import services -from controllers.console.auth.error import EmailCodeError, EmailOrPasswordMismatchError, InvalidEmailError -from controllers.console.error import AccountBannedError, AccountNotFound +from controllers.console.auth.error import ( + AuthenticationFailedError, + EmailCodeError, + InvalidEmailError, +) +from controllers.console.error import AccountBannedError from controllers.console.wraps import only_edition_enterprise, setup_required from controllers.web import api from libs.helper import email @@ -29,9 +33,9 @@ class LoginApi(Resource): except services.errors.account.AccountLoginError: raise AccountBannedError() except services.errors.account.AccountPasswordError: - raise EmailOrPasswordMismatchError() + raise AuthenticationFailedError() except services.errors.account.AccountNotFoundError: - raise AccountNotFound() + raise AuthenticationFailedError() token = WebAppAuthService.login(account=account) return {"result": "success", "data": {"access_token": token}} @@ -63,7 +67,7 @@ class EmailCodeLoginSendEmailApi(Resource): account = WebAppAuthService.get_user_through_email(args["email"]) if account is None: - raise AccountNotFound() + raise AuthenticationFailedError() else: token = WebAppAuthService.send_email_code_login_email(account=account, language=language) @@ -95,7 +99,7 @@ class EmailCodeLoginApi(Resource): WebAppAuthService.revoke_email_code_login_token(args["token"]) account = WebAppAuthService.get_user_through_email(user_email) if not account: - raise AccountNotFound() + raise AuthenticationFailedError() token = WebAppAuthService.login(account=account) AccountService.reset_login_error_rate_limit(args["email"]) diff --git a/api/core/entities/model_entities.py b/api/core/entities/model_entities.py index e1c021a44a..ac64a8e3a0 100644 --- a/api/core/entities/model_entities.py +++ b/api/core/entities/model_entities.py @@ -19,6 +19,7 @@ class ModelStatus(Enum): QUOTA_EXCEEDED = "quota-exceeded" NO_PERMISSION = "no-permission" DISABLED = "disabled" + CREDENTIAL_REMOVED = "credential-removed" class SimpleModelProviderEntity(BaseModel): @@ -54,6 +55,7 @@ class ProviderModelWithStatusEntity(ProviderModel): status: ModelStatus load_balancing_enabled: bool = False + has_invalid_load_balancing_configs: bool = False def raise_for_status(self) -> None: """ diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index 646e0e21e9..ca3c36b878 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -6,6 +6,8 @@ from json import JSONDecodeError from typing import Optional from pydantic import BaseModel, ConfigDict, Field +from sqlalchemy import func, select +from sqlalchemy.orm import Session from constants import HIDDEN_VALUE from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity @@ -32,7 +34,9 @@ from libs.datetime_utils import naive_utc_now from models.provider import ( LoadBalancingModelConfig, Provider, + ProviderCredential, ProviderModel, + ProviderModelCredential, ProviderModelSetting, ProviderType, TenantPreferredModelProvider, @@ -45,7 +49,16 @@ original_provider_configurate_methods: dict[str, list[ConfigurateMethod]] = {} class ProviderConfiguration(BaseModel): """ - Model class for provider configuration. + Provider configuration entity for managing model provider settings. + + This class handles: + - Provider credentials CRUD and switch + - Custom Model credentials CRUD and switch + - System vs custom provider switching + - Load balancing configurations + - Model enablement/disablement + + TODO: lots of logic in a BaseModel entity should be separated, the exceptions should be classified """ tenant_id: str @@ -155,33 +168,17 @@ class ProviderConfiguration(BaseModel): Check custom configuration available. :return: """ - return self.custom_configuration.provider is not None or len(self.custom_configuration.models) > 0 - - def get_custom_credentials(self, obfuscated: bool = False) -> dict | None: - """ - Get custom credentials. - - :param obfuscated: obfuscated secret data in credentials - :return: - """ - if self.custom_configuration.provider is None: - return None - - credentials = self.custom_configuration.provider.credentials - if not obfuscated: - return credentials - - # Obfuscate credentials - return self.obfuscated_credentials( - credentials=credentials, - credential_form_schemas=self.provider.provider_credential_schema.credential_form_schemas - if self.provider.provider_credential_schema - else [], + has_provider_credentials = ( + self.custom_configuration.provider is not None + and len(self.custom_configuration.provider.available_credentials) > 0 ) - def _get_custom_provider_credentials(self) -> Provider | None: + has_model_configurations = len(self.custom_configuration.models) > 0 + return has_provider_credentials or has_model_configurations + + def _get_provider_record(self, session: Session) -> Provider | None: """ - Get custom provider credentials. + Get custom provider record. """ # get provider model_provider_id = ModelProviderID(self.provider.provider) @@ -189,156 +186,442 @@ class ProviderConfiguration(BaseModel): if model_provider_id.is_langgenius(): provider_names.append(model_provider_id.provider_name) - provider_record = ( - db.session.query(Provider) - .where( - Provider.tenant_id == self.tenant_id, - Provider.provider_type == ProviderType.CUSTOM.value, - Provider.provider_name.in_(provider_names), - ) - .first() + stmt = select(Provider).where( + Provider.tenant_id == self.tenant_id, + Provider.provider_type == ProviderType.CUSTOM.value, + Provider.provider_name.in_(provider_names), ) - return provider_record + return session.execute(stmt).scalar_one_or_none() - def custom_credentials_validate(self, credentials: dict) -> tuple[Provider | None, dict]: + def _get_specific_provider_credential(self, credential_id: str) -> dict | None: """ - Validate custom credentials. - :param credentials: provider credentials + Get a specific provider credential by ID. + :param credential_id: Credential ID :return: """ - provider_record = self._get_custom_provider_credentials() - - # Get provider credential secret variables - provider_credential_secret_variables = self.extract_secret_variables( + # Extract secret variables from provider credential schema + credential_secret_variables = self.extract_secret_variables( self.provider.provider_credential_schema.credential_form_schemas if self.provider.provider_credential_schema else [] ) - if provider_record: - try: - # fix origin data - if provider_record.encrypted_config: - if not provider_record.encrypted_config.startswith("{"): - original_credentials = {"openai_api_key": provider_record.encrypted_config} - else: - original_credentials = json.loads(provider_record.encrypted_config) - else: - original_credentials = {} - except JSONDecodeError: - original_credentials = {} + with Session(db.engine) as session: + # Prefer the actual provider record name if exists (to handle aliased provider names) + provider_record = self._get_provider_record(session) + provider_name = provider_record.provider_name if provider_record else self.provider.provider - # encrypt credentials - for key, value in credentials.items(): - if key in provider_credential_secret_variables: - # if send [__HIDDEN__] in secret input, it will be same as original value - if value == HIDDEN_VALUE and key in original_credentials: - credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key]) - - model_provider_factory = ModelProviderFactory(self.tenant_id) - credentials = model_provider_factory.provider_credentials_validate( - provider=self.provider.provider, credentials=credentials - ) - - for key, value in credentials.items(): - if key in provider_credential_secret_variables: - credentials[key] = encrypter.encrypt_token(self.tenant_id, value) - - return provider_record, credentials - - def add_or_update_custom_credentials(self, credentials: dict) -> None: - """ - Add or update custom provider credentials. - :param credentials: - :return: - """ - # validate custom provider config - provider_record, credentials = self.custom_credentials_validate(credentials) - - # save provider - # Note: Do not switch the preferred provider, which allows users to use quotas first - if provider_record: - provider_record.encrypted_config = json.dumps(credentials) - provider_record.is_valid = True - provider_record.updated_at = naive_utc_now() - db.session.commit() - else: - provider_record = Provider() - provider_record.tenant_id = self.tenant_id - provider_record.provider_name = self.provider.provider - provider_record.provider_type = ProviderType.CUSTOM.value - provider_record.encrypted_config = json.dumps(credentials) - provider_record.is_valid = True - - db.session.add(provider_record) - db.session.commit() - - provider_model_credentials_cache = ProviderCredentialsCache( - tenant_id=self.tenant_id, identity_id=provider_record.id, cache_type=ProviderCredentialsCacheType.PROVIDER - ) - - provider_model_credentials_cache.delete() - - self.switch_preferred_provider_type(ProviderType.CUSTOM) - - def delete_custom_credentials(self) -> None: - """ - Delete custom provider credentials. - :return: - """ - # get provider - provider_record = self._get_custom_provider_credentials() - - # delete provider - if provider_record: - self.switch_preferred_provider_type(ProviderType.SYSTEM) - - db.session.delete(provider_record) - db.session.commit() - - provider_model_credentials_cache = ProviderCredentialsCache( - tenant_id=self.tenant_id, - identity_id=provider_record.id, - cache_type=ProviderCredentialsCacheType.PROVIDER, + stmt = select(ProviderCredential).where( + ProviderCredential.id == credential_id, + ProviderCredential.tenant_id == self.tenant_id, + ProviderCredential.provider_name == provider_name, ) - provider_model_credentials_cache.delete() + credential = session.execute(stmt).scalar_one_or_none() - def get_custom_model_credentials( - self, model_type: ModelType, model: str, obfuscated: bool = False - ) -> Optional[dict]: + if not credential or not credential.encrypted_config: + raise ValueError(f"Credential with id {credential_id} not found.") + + try: + credentials = json.loads(credential.encrypted_config) + except JSONDecodeError: + credentials = {} + + # Decrypt secret variables + for key in credential_secret_variables: + if key in credentials and credentials[key] is not None: + try: + credentials[key] = encrypter.decrypt_token(tenant_id=self.tenant_id, token=credentials[key]) + except Exception: + pass + + return self.obfuscated_credentials( + credentials=credentials, + credential_form_schemas=self.provider.provider_credential_schema.credential_form_schemas + if self.provider.provider_credential_schema + else [], + ) + + def _check_provider_credential_name_exists( + self, credential_name: str, session: Session, exclude_id: str | None = None + ) -> bool: """ - Get custom model credentials. + not allowed same name when create or update a credential + """ + stmt = select(ProviderCredential.id).where( + ProviderCredential.tenant_id == self.tenant_id, + ProviderCredential.provider_name == self.provider.provider, + ProviderCredential.credential_name == credential_name, + ) + if exclude_id: + stmt = stmt.where(ProviderCredential.id != exclude_id) + return session.execute(stmt).scalar_one_or_none() is not None - :param model_type: model type - :param model: model name - :param obfuscated: obfuscated secret data in credentials + def get_provider_credential(self, credential_id: str | None = None) -> dict | None: + """ + Get provider credentials. + + :param credential_id: if provided, return the specified credential :return: """ - if not self.custom_configuration.models: - return None - for model_configuration in self.custom_configuration.models: - if model_configuration.model_type == model_type and model_configuration.model == model: - credentials = model_configuration.credentials - if not obfuscated: - return credentials + if credential_id: + return self._get_specific_provider_credential(credential_id) - # Obfuscate credentials - return self.obfuscated_credentials( - credentials=credentials, - credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas - if self.provider.model_credential_schema - else [], + # Default behavior: return current active provider credentials + credentials = self.custom_configuration.provider.credentials if self.custom_configuration.provider else {} + + return self.obfuscated_credentials( + credentials=credentials, + credential_form_schemas=self.provider.provider_credential_schema.credential_form_schemas + if self.provider.provider_credential_schema + else [], + ) + + def validate_provider_credentials( + self, credentials: dict, credential_id: str = "", session: Session | None = None + ) -> dict: + """ + Validate custom credentials. + :param credentials: provider credentials + :param credential_id: (Optional)If provided, can use existing credential's hidden api key to validate + :param session: optional database session + :return: + """ + + def _validate(s: Session) -> dict: + # Get provider credential secret variables + provider_credential_secret_variables = self.extract_secret_variables( + self.provider.provider_credential_schema.credential_form_schemas + if self.provider.provider_credential_schema + else [] + ) + + if credential_id: + try: + stmt = select(ProviderCredential).where( + ProviderCredential.tenant_id == self.tenant_id, + ProviderCredential.provider_name == self.provider.provider, + ProviderCredential.id == credential_id, + ) + credential_record = s.execute(stmt).scalar_one_or_none() + # fix origin data + if credential_record and credential_record.encrypted_config: + if not credential_record.encrypted_config.startswith("{"): + original_credentials = {"openai_api_key": credential_record.encrypted_config} + else: + original_credentials = json.loads(credential_record.encrypted_config) + else: + original_credentials = {} + except JSONDecodeError: + original_credentials = {} + + # encrypt credentials + for key, value in credentials.items(): + if key in provider_credential_secret_variables: + # if send [__HIDDEN__] in secret input, it will be same as original value + if value == HIDDEN_VALUE and key in original_credentials: + credentials[key] = encrypter.decrypt_token( + tenant_id=self.tenant_id, token=original_credentials[key] + ) + + model_provider_factory = ModelProviderFactory(self.tenant_id) + validated_credentials = model_provider_factory.provider_credentials_validate( + provider=self.provider.provider, credentials=credentials + ) + + for key, value in validated_credentials.items(): + if key in provider_credential_secret_variables: + validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value) + + return validated_credentials + + if session: + return _validate(session) + else: + with Session(db.engine) as new_session: + return _validate(new_session) + + def create_provider_credential(self, credentials: dict, credential_name: str) -> None: + """ + Add custom provider credentials. + :param credentials: provider credentials + :param credential_name: credential name + :return: + """ + with Session(db.engine) as session: + if self._check_provider_credential_name_exists(credential_name=credential_name, session=session): + raise ValueError(f"Credential with name '{credential_name}' already exists.") + + credentials = self.validate_provider_credentials(credentials=credentials, session=session) + provider_record = self._get_provider_record(session) + try: + new_record = ProviderCredential( + tenant_id=self.tenant_id, + provider_name=self.provider.provider, + encrypted_config=json.dumps(credentials), + credential_name=credential_name, ) + session.add(new_record) + session.flush() - return None + if not provider_record: + # If provider record does not exist, create it + provider_record = Provider( + tenant_id=self.tenant_id, + provider_name=self.provider.provider, + provider_type=ProviderType.CUSTOM.value, + is_valid=True, + credential_id=new_record.id, + ) + session.add(provider_record) - def _get_custom_model_credentials( + provider_model_credentials_cache = ProviderCredentialsCache( + tenant_id=self.tenant_id, + identity_id=provider_record.id, + cache_type=ProviderCredentialsCacheType.PROVIDER, + ) + provider_model_credentials_cache.delete() + + self.switch_preferred_provider_type(provider_type=ProviderType.CUSTOM, session=session) + + session.commit() + except Exception: + session.rollback() + raise + + def update_provider_credential( + self, + credentials: dict, + credential_id: str, + credential_name: str, + ) -> None: + """ + update a saved provider credential (by credential_id). + + :param credentials: provider credentials + :param credential_id: credential id + :param credential_name: credential name + :return: + """ + with Session(db.engine) as session: + if self._check_provider_credential_name_exists( + credential_name=credential_name, session=session, exclude_id=credential_id + ): + raise ValueError(f"Credential with name '{credential_name}' already exists.") + + credentials = self.validate_provider_credentials( + credentials=credentials, credential_id=credential_id, session=session + ) + provider_record = self._get_provider_record(session) + stmt = select(ProviderCredential).where( + ProviderCredential.id == credential_id, + ProviderCredential.tenant_id == self.tenant_id, + ProviderCredential.provider_name == self.provider.provider, + ) + + # Get the credential record to update + credential_record = session.execute(stmt).scalar_one_or_none() + if not credential_record: + raise ValueError("Credential record not found.") + try: + # Update credential + credential_record.encrypted_config = json.dumps(credentials) + credential_record.credential_name = credential_name + credential_record.updated_at = naive_utc_now() + + session.commit() + + if provider_record and provider_record.credential_id == credential_id: + provider_model_credentials_cache = ProviderCredentialsCache( + tenant_id=self.tenant_id, + identity_id=provider_record.id, + cache_type=ProviderCredentialsCacheType.PROVIDER, + ) + provider_model_credentials_cache.delete() + + self._update_load_balancing_configs_with_credential( + credential_id=credential_id, + credential_record=credential_record, + credential_source="provider", + session=session, + ) + except Exception: + session.rollback() + raise + + def _update_load_balancing_configs_with_credential( + self, + credential_id: str, + credential_record: ProviderCredential | ProviderModelCredential, + credential_source: str, + session: Session, + ) -> None: + """ + Update load balancing configurations that reference the given credential_id. + + :param credential_id: credential id + :param credential_record: the encrypted_config and credential_name + :param credential_source: the credential comes from the provider_credential(`provider`) + or the provider_model_credential(`custom_model`) + :param session: the database session + :return: + """ + # Find all load balancing configs that use this credential_id + stmt = select(LoadBalancingModelConfig).where( + LoadBalancingModelConfig.tenant_id == self.tenant_id, + LoadBalancingModelConfig.provider_name == self.provider.provider, + LoadBalancingModelConfig.credential_id == credential_id, + LoadBalancingModelConfig.credential_source_type == credential_source, + ) + load_balancing_configs = session.execute(stmt).scalars().all() + + if not load_balancing_configs: + return + + # Update each load balancing config with the new credentials + for lb_config in load_balancing_configs: + # Update the encrypted_config with the new credentials + lb_config.encrypted_config = credential_record.encrypted_config + lb_config.name = credential_record.credential_name + lb_config.updated_at = naive_utc_now() + + # Clear cache for this load balancing config + lb_credentials_cache = ProviderCredentialsCache( + tenant_id=self.tenant_id, + identity_id=lb_config.id, + cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL, + ) + lb_credentials_cache.delete() + + session.commit() + + def delete_provider_credential(self, credential_id: str) -> None: + """ + Delete a saved provider credential (by credential_id). + + :param credential_id: credential id + :return: + """ + with Session(db.engine) as session: + stmt = select(ProviderCredential).where( + ProviderCredential.id == credential_id, + ProviderCredential.tenant_id == self.tenant_id, + ProviderCredential.provider_name == self.provider.provider, + ) + + # Get the credential record to update + credential_record = session.execute(stmt).scalar_one_or_none() + if not credential_record: + raise ValueError("Credential record not found.") + + # Check if this credential is used in load balancing configs + lb_stmt = select(LoadBalancingModelConfig).where( + LoadBalancingModelConfig.tenant_id == self.tenant_id, + LoadBalancingModelConfig.provider_name == self.provider.provider, + LoadBalancingModelConfig.credential_id == credential_id, + LoadBalancingModelConfig.credential_source_type == "provider", + ) + lb_configs_using_credential = session.execute(lb_stmt).scalars().all() + try: + for lb_config in lb_configs_using_credential: + lb_credentials_cache = ProviderCredentialsCache( + tenant_id=self.tenant_id, + identity_id=lb_config.id, + cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL, + ) + lb_credentials_cache.delete() + + lb_config.credential_id = None + lb_config.encrypted_config = None + lb_config.enabled = False + lb_config.name = "__delete__" + lb_config.updated_at = naive_utc_now() + session.add(lb_config) + + # Check if this is the currently active credential + provider_record = self._get_provider_record(session) + + # Check available credentials count BEFORE deleting + # if this is the last credential, we need to delete the provider record + count_stmt = select(func.count(ProviderCredential.id)).where( + ProviderCredential.tenant_id == self.tenant_id, + ProviderCredential.provider_name == self.provider.provider, + ) + available_credentials_count = session.execute(count_stmt).scalar() or 0 + session.delete(credential_record) + + if provider_record and available_credentials_count <= 1: + # If all credentials are deleted, delete the provider record + session.delete(provider_record) + + provider_model_credentials_cache = ProviderCredentialsCache( + tenant_id=self.tenant_id, + identity_id=provider_record.id, + cache_type=ProviderCredentialsCacheType.PROVIDER, + ) + provider_model_credentials_cache.delete() + self.switch_preferred_provider_type(provider_type=ProviderType.SYSTEM, session=session) + elif provider_record and provider_record.credential_id == credential_id: + provider_record.credential_id = None + provider_record.updated_at = naive_utc_now() + + provider_model_credentials_cache = ProviderCredentialsCache( + tenant_id=self.tenant_id, + identity_id=provider_record.id, + cache_type=ProviderCredentialsCacheType.PROVIDER, + ) + provider_model_credentials_cache.delete() + self.switch_preferred_provider_type(provider_type=ProviderType.SYSTEM, session=session) + + session.commit() + except Exception: + session.rollback() + raise + + def switch_active_provider_credential(self, credential_id: str) -> None: + """ + Switch active provider credential (copy the selected one into current active snapshot). + + :param credential_id: credential id + :return: + """ + with Session(db.engine) as session: + stmt = select(ProviderCredential).where( + ProviderCredential.id == credential_id, + ProviderCredential.tenant_id == self.tenant_id, + ProviderCredential.provider_name == self.provider.provider, + ) + credential_record = session.execute(stmt).scalar_one_or_none() + if not credential_record: + raise ValueError("Credential record not found.") + + provider_record = self._get_provider_record(session) + if not provider_record: + raise ValueError("Provider record not found.") + + try: + provider_record.credential_id = credential_record.id + provider_record.updated_at = naive_utc_now() + session.commit() + + provider_model_credentials_cache = ProviderCredentialsCache( + tenant_id=self.tenant_id, + identity_id=provider_record.id, + cache_type=ProviderCredentialsCacheType.PROVIDER, + ) + provider_model_credentials_cache.delete() + self.switch_preferred_provider_type(ProviderType.CUSTOM, session=session) + except Exception: + session.rollback() + raise + + def _get_custom_model_record( self, model_type: ModelType, model: str, + session: Session, ) -> ProviderModel | None: """ Get custom model credentials. @@ -349,128 +632,495 @@ class ProviderConfiguration(BaseModel): if model_provider_id.is_langgenius(): provider_names.append(model_provider_id.provider_name) - provider_model_record = ( - db.session.query(ProviderModel) - .where( - ProviderModel.tenant_id == self.tenant_id, - ProviderModel.provider_name.in_(provider_names), - ProviderModel.model_name == model, - ProviderModel.model_type == model_type.to_origin_model_type(), - ) - .first() + stmt = select(ProviderModel).where( + ProviderModel.tenant_id == self.tenant_id, + ProviderModel.provider_name.in_(provider_names), + ProviderModel.model_name == model, + ProviderModel.model_type == model_type.to_origin_model_type(), ) - return provider_model_record + return session.execute(stmt).scalar_one_or_none() - def custom_model_credentials_validate( - self, model_type: ModelType, model: str, credentials: dict - ) -> tuple[ProviderModel | None, dict]: + def _get_specific_custom_model_credential( + self, model_type: ModelType, model: str, credential_id: str + ) -> dict | None: """ - Validate custom model credentials. - - :param model_type: model type - :param model: model name - :param credentials: model credentials + Get a specific provider credential by ID. + :param credential_id: Credential ID :return: """ - # get provider model - provider_model_record = self._get_custom_model_credentials(model_type, model) - - # Get provider credential secret variables - provider_credential_secret_variables = self.extract_secret_variables( + model_credential_secret_variables = self.extract_secret_variables( self.provider.model_credential_schema.credential_form_schemas if self.provider.model_credential_schema else [] ) - if provider_model_record: - try: - original_credentials = ( - json.loads(provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {} - ) - except JSONDecodeError: - original_credentials = {} - - # decrypt credentials - for key, value in credentials.items(): - if key in provider_credential_secret_variables: - # if send [__HIDDEN__] in secret input, it will be same as original value - if value == HIDDEN_VALUE and key in original_credentials: - credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key]) - - model_provider_factory = ModelProviderFactory(self.tenant_id) - credentials = model_provider_factory.model_credentials_validate( - provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials - ) - - for key, value in credentials.items(): - if key in provider_credential_secret_variables: - credentials[key] = encrypter.encrypt_token(self.tenant_id, value) - - return provider_model_record, credentials - - def add_or_update_custom_model_credentials(self, model_type: ModelType, model: str, credentials: dict) -> None: - """ - Add or update custom model credentials. - - :param model_type: model type - :param model: model name - :param credentials: model credentials - :return: - """ - # validate custom model config - provider_model_record, credentials = self.custom_model_credentials_validate(model_type, model, credentials) - - # save provider model - # Note: Do not switch the preferred provider, which allows users to use quotas first - if provider_model_record: - provider_model_record.encrypted_config = json.dumps(credentials) - provider_model_record.is_valid = True - provider_model_record.updated_at = naive_utc_now() - db.session.commit() - else: - provider_model_record = ProviderModel() - provider_model_record.tenant_id = self.tenant_id - provider_model_record.provider_name = self.provider.provider - provider_model_record.model_name = model - provider_model_record.model_type = model_type.to_origin_model_type() - provider_model_record.encrypted_config = json.dumps(credentials) - provider_model_record.is_valid = True - db.session.add(provider_model_record) - db.session.commit() - - provider_model_credentials_cache = ProviderCredentialsCache( - tenant_id=self.tenant_id, - identity_id=provider_model_record.id, - cache_type=ProviderCredentialsCacheType.MODEL, - ) - - provider_model_credentials_cache.delete() - - def delete_custom_model_credentials(self, model_type: ModelType, model: str) -> None: - """ - Delete custom model credentials. - :param model_type: model type - :param model: model name - :return: - """ - # get provider model - provider_model_record = self._get_custom_model_credentials(model_type, model) - - # delete provider model - if provider_model_record: - db.session.delete(provider_model_record) - db.session.commit() - - provider_model_credentials_cache = ProviderCredentialsCache( - tenant_id=self.tenant_id, - identity_id=provider_model_record.id, - cache_type=ProviderCredentialsCacheType.MODEL, + with Session(db.engine) as session: + stmt = select(ProviderModelCredential).where( + ProviderModelCredential.id == credential_id, + ProviderModelCredential.tenant_id == self.tenant_id, + ProviderModelCredential.provider_name == self.provider.provider, + ProviderModelCredential.model_name == model, + ProviderModelCredential.model_type == model_type.to_origin_model_type(), ) - provider_model_credentials_cache.delete() + credential_record = session.execute(stmt).scalar_one_or_none() - def _get_provider_model_setting(self, model_type: ModelType, model: str) -> ProviderModelSetting | None: + if not credential_record or not credential_record.encrypted_config: + raise ValueError(f"Credential with id {credential_id} not found.") + + try: + credentials = json.loads(credential_record.encrypted_config) + except JSONDecodeError: + credentials = {} + + # Decrypt secret variables + for key in model_credential_secret_variables: + if key in credentials and credentials[key] is not None: + try: + credentials[key] = encrypter.decrypt_token(tenant_id=self.tenant_id, token=credentials[key]) + except Exception: + pass + + current_credential_id = credential_record.id + current_credential_name = credential_record.credential_name + credentials = self.obfuscated_credentials( + credentials=credentials, + credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas + if self.provider.model_credential_schema + else [], + ) + + return { + "current_credential_id": current_credential_id, + "current_credential_name": current_credential_name, + "credentials": credentials, + } + + def _check_custom_model_credential_name_exists( + self, model_type: ModelType, model: str, credential_name: str, session: Session, exclude_id: str | None = None + ) -> bool: + """ + not allowed same name when create or update a credential + """ + stmt = select(ProviderModelCredential).where( + ProviderModelCredential.tenant_id == self.tenant_id, + ProviderModelCredential.provider_name == self.provider.provider, + ProviderModelCredential.model_name == model, + ProviderModelCredential.model_type == model_type.to_origin_model_type(), + ProviderModelCredential.credential_name == credential_name, + ) + if exclude_id: + stmt = stmt.where(ProviderModelCredential.id != exclude_id) + return session.execute(stmt).scalar_one_or_none() is not None + + def get_custom_model_credential( + self, model_type: ModelType, model: str, credential_id: str | None + ) -> Optional[dict]: + """ + Get custom model credentials. + + :param model_type: model type + :param model: model name + :return: + """ + # If credential_id is provided, return the specific credential + if credential_id: + return self._get_specific_custom_model_credential( + model_type=model_type, model=model, credential_id=credential_id + ) + + for model_configuration in self.custom_configuration.models: + if ( + model_configuration.model_type == model_type + and model_configuration.model == model + and model_configuration.credentials + ): + current_credential_id = model_configuration.current_credential_id + current_credential_name = model_configuration.current_credential_name + credentials = self.obfuscated_credentials( + credentials=model_configuration.credentials, + credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas + if self.provider.model_credential_schema + else [], + ) + return { + "current_credential_id": current_credential_id, + "current_credential_name": current_credential_name, + "credentials": credentials, + } + return None + + def validate_custom_model_credentials( + self, + model_type: ModelType, + model: str, + credentials: dict, + credential_id: str = "", + session: Session | None = None, + ) -> dict: + """ + Validate custom model credentials. + + :param model_type: model type + :param model: model name + :param credentials: model credentials dict + :param credential_id: (Optional)If provided, can use existing credential's hidden api key to validate + :return: + """ + + def _validate(s: Session) -> dict: + # Get provider credential secret variables + provider_credential_secret_variables = self.extract_secret_variables( + self.provider.model_credential_schema.credential_form_schemas + if self.provider.model_credential_schema + else [] + ) + + if credential_id: + try: + stmt = select(ProviderModelCredential).where( + ProviderModelCredential.id == credential_id, + ProviderModelCredential.tenant_id == self.tenant_id, + ProviderModelCredential.provider_name == self.provider.provider, + ProviderModelCredential.model_name == model, + ProviderModelCredential.model_type == model_type.to_origin_model_type(), + ) + credential_record = s.execute(stmt).scalar_one_or_none() + original_credentials = ( + json.loads(credential_record.encrypted_config) + if credential_record and credential_record.encrypted_config + else {} + ) + except JSONDecodeError: + original_credentials = {} + + # decrypt credentials + for key, value in credentials.items(): + if key in provider_credential_secret_variables: + # if send [__HIDDEN__] in secret input, it will be same as original value + if value == HIDDEN_VALUE and key in original_credentials: + credentials[key] = encrypter.decrypt_token( + tenant_id=self.tenant_id, token=original_credentials[key] + ) + + model_provider_factory = ModelProviderFactory(self.tenant_id) + validated_credentials = model_provider_factory.model_credentials_validate( + provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials + ) + + for key, value in validated_credentials.items(): + if key in provider_credential_secret_variables: + validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value) + + return validated_credentials + + if session: + return _validate(session) + else: + with Session(db.engine) as new_session: + return _validate(new_session) + + def create_custom_model_credential( + self, model_type: ModelType, model: str, credentials: dict, credential_name: str + ) -> None: + """ + Create a custom model credential. + + :param model_type: model type + :param model: model name + :param credentials: model credentials dict + :return: + """ + with Session(db.engine) as session: + if self._check_custom_model_credential_name_exists( + model=model, model_type=model_type, credential_name=credential_name, session=session + ): + raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.") + # validate custom model config + credentials = self.validate_custom_model_credentials( + model_type=model_type, model=model, credentials=credentials, session=session + ) + provider_model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session) + + try: + credential = ProviderModelCredential( + tenant_id=self.tenant_id, + provider_name=self.provider.provider, + model_name=model, + model_type=model_type.to_origin_model_type(), + encrypted_config=json.dumps(credentials), + credential_name=credential_name, + ) + session.add(credential) + session.flush() + + # save provider model + if not provider_model_record: + provider_model_record = ProviderModel( + tenant_id=self.tenant_id, + provider_name=self.provider.provider, + model_name=model, + model_type=model_type.to_origin_model_type(), + credential_id=credential.id, + is_valid=True, + ) + session.add(provider_model_record) + + session.commit() + + provider_model_credentials_cache = ProviderCredentialsCache( + tenant_id=self.tenant_id, + identity_id=provider_model_record.id, + cache_type=ProviderCredentialsCacheType.MODEL, + ) + provider_model_credentials_cache.delete() + except Exception: + session.rollback() + raise + + def update_custom_model_credential( + self, model_type: ModelType, model: str, credentials: dict, credential_name: str, credential_id: str + ) -> None: + """ + Update a custom model credential. + + :param model_type: model type + :param model: model name + :param credentials: model credentials dict + :param credential_name: credential name + :param credential_id: credential id + :return: + """ + with Session(db.engine) as session: + if self._check_custom_model_credential_name_exists( + model=model, + model_type=model_type, + credential_name=credential_name, + session=session, + exclude_id=credential_id, + ): + raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.") + # validate custom model config + credentials = self.validate_custom_model_credentials( + model_type=model_type, + model=model, + credentials=credentials, + credential_id=credential_id, + session=session, + ) + provider_model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session) + + stmt = select(ProviderModelCredential).where( + ProviderModelCredential.id == credential_id, + ProviderModelCredential.tenant_id == self.tenant_id, + ProviderModelCredential.provider_name == self.provider.provider, + ProviderModelCredential.model_name == model, + ProviderModelCredential.model_type == model_type.to_origin_model_type(), + ) + credential_record = session.execute(stmt).scalar_one_or_none() + if not credential_record: + raise ValueError("Credential record not found.") + + try: + # Update credential + credential_record.encrypted_config = json.dumps(credentials) + credential_record.credential_name = credential_name + credential_record.updated_at = naive_utc_now() + session.commit() + + if provider_model_record and provider_model_record.credential_id == credential_id: + provider_model_credentials_cache = ProviderCredentialsCache( + tenant_id=self.tenant_id, + identity_id=provider_model_record.id, + cache_type=ProviderCredentialsCacheType.MODEL, + ) + provider_model_credentials_cache.delete() + + self._update_load_balancing_configs_with_credential( + credential_id=credential_id, + credential_record=credential_record, + credential_source="custom_model", + session=session, + ) + except Exception: + session.rollback() + raise + + def delete_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str) -> None: + """ + Delete a saved provider credential (by credential_id). + + :param credential_id: credential id + :return: + """ + with Session(db.engine) as session: + stmt = select(ProviderModelCredential).where( + ProviderModelCredential.id == credential_id, + ProviderModelCredential.tenant_id == self.tenant_id, + ProviderModelCredential.provider_name == self.provider.provider, + ProviderModelCredential.model_name == model, + ProviderModelCredential.model_type == model_type.to_origin_model_type(), + ) + credential_record = session.execute(stmt).scalar_one_or_none() + if not credential_record: + raise ValueError("Credential record not found.") + + lb_stmt = select(LoadBalancingModelConfig).where( + LoadBalancingModelConfig.tenant_id == self.tenant_id, + LoadBalancingModelConfig.provider_name == self.provider.provider, + LoadBalancingModelConfig.credential_id == credential_id, + LoadBalancingModelConfig.credential_source_type == "custom_model", + ) + lb_configs_using_credential = session.execute(lb_stmt).scalars().all() + + try: + for lb_config in lb_configs_using_credential: + lb_credentials_cache = ProviderCredentialsCache( + tenant_id=self.tenant_id, + identity_id=lb_config.id, + cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL, + ) + lb_credentials_cache.delete() + lb_config.credential_id = None + lb_config.encrypted_config = None + lb_config.enabled = False + lb_config.name = "__delete__" + lb_config.updated_at = naive_utc_now() + session.add(lb_config) + + # Check if this is the currently active credential + provider_model_record = self._get_custom_model_record(model_type, model, session=session) + + # Check available credentials count BEFORE deleting + # if this is the last credential, we need to delete the custom model record + count_stmt = select(func.count(ProviderModelCredential.id)).where( + ProviderModelCredential.tenant_id == self.tenant_id, + ProviderModelCredential.provider_name == self.provider.provider, + ProviderModelCredential.model_name == model, + ProviderModelCredential.model_type == model_type.to_origin_model_type(), + ) + available_credentials_count = session.execute(count_stmt).scalar() or 0 + session.delete(credential_record) + + if provider_model_record and available_credentials_count <= 1: + # If all credentials are deleted, delete the custom model record + session.delete(provider_model_record) + elif provider_model_record and provider_model_record.credential_id == credential_id: + provider_model_record.credential_id = None + provider_model_record.updated_at = naive_utc_now() + provider_model_credentials_cache = ProviderCredentialsCache( + tenant_id=self.tenant_id, + identity_id=provider_model_record.id, + cache_type=ProviderCredentialsCacheType.PROVIDER, + ) + provider_model_credentials_cache.delete() + + session.commit() + + except Exception: + session.rollback() + raise + + def add_model_credential_to_model(self, model_type: ModelType, model: str, credential_id: str) -> None: + """ + if model list exist this custom model, switch the custom model credential. + if model list not exist this custom model, use the credential to add a new custom model record. + + :param model_type: model type + :param model: model name + :param credential_id: credential id + :return: + """ + with Session(db.engine) as session: + stmt = select(ProviderModelCredential).where( + ProviderModelCredential.id == credential_id, + ProviderModelCredential.tenant_id == self.tenant_id, + ProviderModelCredential.provider_name == self.provider.provider, + ProviderModelCredential.model_name == model, + ProviderModelCredential.model_type == model_type.to_origin_model_type(), + ) + credential_record = session.execute(stmt).scalar_one_or_none() + if not credential_record: + raise ValueError("Credential record not found.") + + # validate custom model config + provider_model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session) + + if not provider_model_record: + # create provider model record + provider_model_record = ProviderModel( + tenant_id=self.tenant_id, + provider_name=self.provider.provider, + model_name=model, + model_type=model_type.to_origin_model_type(), + credential_id=credential_id, + ) + else: + if provider_model_record.credential_id == credential_record.id: + raise ValueError("Can't add same credential") + provider_model_record.credential_id = credential_record.id + provider_model_record.updated_at = naive_utc_now() + session.add(provider_model_record) + session.commit() + + def switch_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str) -> None: + """ + switch the custom model credential. + + :param model_type: model type + :param model: model name + :param credential_id: credential id + :return: + """ + with Session(db.engine) as session: + stmt = select(ProviderModelCredential).where( + ProviderModelCredential.id == credential_id, + ProviderModelCredential.tenant_id == self.tenant_id, + ProviderModelCredential.provider_name == self.provider.provider, + ProviderModelCredential.model_name == model, + ProviderModelCredential.model_type == model_type.to_origin_model_type(), + ) + credential_record = session.execute(stmt).scalar_one_or_none() + if not credential_record: + raise ValueError("Credential record not found.") + + provider_model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session) + if not provider_model_record: + raise ValueError("The custom model record not found.") + + provider_model_record.credential_id = credential_record.id + provider_model_record.updated_at = naive_utc_now() + session.add(provider_model_record) + session.commit() + + def delete_custom_model(self, model_type: ModelType, model: str) -> None: + """ + Delete custom model. + :param model_type: model type + :param model: model name + :return: + """ + with Session(db.engine) as session: + # get provider model + provider_model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session) + + # delete provider model + if provider_model_record: + session.delete(provider_model_record) + session.commit() + + provider_model_credentials_cache = ProviderCredentialsCache( + tenant_id=self.tenant_id, + identity_id=provider_model_record.id, + cache_type=ProviderCredentialsCacheType.MODEL, + ) + + provider_model_credentials_cache.delete() + + def _get_provider_model_setting( + self, model_type: ModelType, model: str, session: Session + ) -> ProviderModelSetting | None: """ Get provider model setting. """ @@ -479,16 +1129,13 @@ class ProviderConfiguration(BaseModel): if model_provider_id.is_langgenius(): provider_names.append(model_provider_id.provider_name) - return ( - db.session.query(ProviderModelSetting) - .where( - ProviderModelSetting.tenant_id == self.tenant_id, - ProviderModelSetting.provider_name.in_(provider_names), - ProviderModelSetting.model_type == model_type.to_origin_model_type(), - ProviderModelSetting.model_name == model, - ) - .first() + stmt = select(ProviderModelSetting).where( + ProviderModelSetting.tenant_id == self.tenant_id, + ProviderModelSetting.provider_name.in_(provider_names), + ProviderModelSetting.model_type == model_type.to_origin_model_type(), + ProviderModelSetting.model_name == model, ) + return session.execute(stmt).scalars().first() def enable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting: """ @@ -497,21 +1144,23 @@ class ProviderConfiguration(BaseModel): :param model: model name :return: """ - model_setting = self._get_provider_model_setting(model_type, model) + with Session(db.engine) as session: + model_setting = self._get_provider_model_setting(model_type=model_type, model=model, session=session) - if model_setting: - model_setting.enabled = True - model_setting.updated_at = naive_utc_now() - db.session.commit() - else: - model_setting = ProviderModelSetting() - model_setting.tenant_id = self.tenant_id - model_setting.provider_name = self.provider.provider - model_setting.model_type = model_type.to_origin_model_type() - model_setting.model_name = model - model_setting.enabled = True - db.session.add(model_setting) - db.session.commit() + if model_setting: + model_setting.enabled = True + model_setting.updated_at = naive_utc_now() + + else: + model_setting = ProviderModelSetting( + tenant_id=self.tenant_id, + provider_name=self.provider.provider, + model_type=model_type.to_origin_model_type(), + model_name=model, + enabled=True, + ) + session.add(model_setting) + session.commit() return model_setting @@ -522,21 +1171,22 @@ class ProviderConfiguration(BaseModel): :param model: model name :return: """ - model_setting = self._get_provider_model_setting(model_type, model) + with Session(db.engine) as session: + model_setting = self._get_provider_model_setting(model_type=model_type, model=model, session=session) - if model_setting: - model_setting.enabled = False - model_setting.updated_at = naive_utc_now() - db.session.commit() - else: - model_setting = ProviderModelSetting() - model_setting.tenant_id = self.tenant_id - model_setting.provider_name = self.provider.provider - model_setting.model_type = model_type.to_origin_model_type() - model_setting.model_name = model - model_setting.enabled = False - db.session.add(model_setting) - db.session.commit() + if model_setting: + model_setting.enabled = False + model_setting.updated_at = naive_utc_now() + else: + model_setting = ProviderModelSetting( + tenant_id=self.tenant_id, + provider_name=self.provider.provider, + model_type=model_type.to_origin_model_type(), + model_name=model, + enabled=False, + ) + session.add(model_setting) + session.commit() return model_setting @@ -547,27 +1197,8 @@ class ProviderConfiguration(BaseModel): :param model: model name :return: """ - return self._get_provider_model_setting(model_type, model) - - def _get_load_balancing_config(self, model_type: ModelType, model: str) -> Optional[LoadBalancingModelConfig]: - """ - Get load balancing config. - """ - model_provider_id = ModelProviderID(self.provider.provider) - provider_names = [self.provider.provider] - if model_provider_id.is_langgenius(): - provider_names.append(model_provider_id.provider_name) - - return ( - db.session.query(LoadBalancingModelConfig) - .where( - LoadBalancingModelConfig.tenant_id == self.tenant_id, - LoadBalancingModelConfig.provider_name.in_(provider_names), - LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), - LoadBalancingModelConfig.model_name == model, - ) - .first() - ) + with Session(db.engine) as session: + return self._get_provider_model_setting(model_type=model_type, model=model, session=session) def enable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting: """ @@ -581,35 +1212,32 @@ class ProviderConfiguration(BaseModel): if model_provider_id.is_langgenius(): provider_names.append(model_provider_id.provider_name) - load_balancing_config_count = ( - db.session.query(LoadBalancingModelConfig) - .where( + with Session(db.engine) as session: + stmt = select(func.count(LoadBalancingModelConfig.id)).where( LoadBalancingModelConfig.tenant_id == self.tenant_id, LoadBalancingModelConfig.provider_name.in_(provider_names), LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), LoadBalancingModelConfig.model_name == model, ) - .count() - ) + load_balancing_config_count = session.execute(stmt).scalar() or 0 + if load_balancing_config_count <= 1: + raise ValueError("Model load balancing configuration must be more than 1.") - if load_balancing_config_count <= 1: - raise ValueError("Model load balancing configuration must be more than 1.") + model_setting = self._get_provider_model_setting(model_type=model_type, model=model, session=session) - model_setting = self._get_provider_model_setting(model_type, model) - - if model_setting: - model_setting.load_balancing_enabled = True - model_setting.updated_at = naive_utc_now() - db.session.commit() - else: - model_setting = ProviderModelSetting() - model_setting.tenant_id = self.tenant_id - model_setting.provider_name = self.provider.provider - model_setting.model_type = model_type.to_origin_model_type() - model_setting.model_name = model - model_setting.load_balancing_enabled = True - db.session.add(model_setting) - db.session.commit() + if model_setting: + model_setting.load_balancing_enabled = True + model_setting.updated_at = naive_utc_now() + else: + model_setting = ProviderModelSetting( + tenant_id=self.tenant_id, + provider_name=self.provider.provider, + model_type=model_type.to_origin_model_type(), + model_name=model, + load_balancing_enabled=True, + ) + session.add(model_setting) + session.commit() return model_setting @@ -620,35 +1248,23 @@ class ProviderConfiguration(BaseModel): :param model: model name :return: """ - model_provider_id = ModelProviderID(self.provider.provider) - provider_names = [self.provider.provider] - if model_provider_id.is_langgenius(): - provider_names.append(model_provider_id.provider_name) - model_setting = ( - db.session.query(ProviderModelSetting) - .where( - ProviderModelSetting.tenant_id == self.tenant_id, - ProviderModelSetting.provider_name.in_(provider_names), - ProviderModelSetting.model_type == model_type.to_origin_model_type(), - ProviderModelSetting.model_name == model, - ) - .first() - ) + with Session(db.engine) as session: + model_setting = self._get_provider_model_setting(model_type=model_type, model=model, session=session) - if model_setting: - model_setting.load_balancing_enabled = False - model_setting.updated_at = naive_utc_now() - db.session.commit() - else: - model_setting = ProviderModelSetting() - model_setting.tenant_id = self.tenant_id - model_setting.provider_name = self.provider.provider - model_setting.model_type = model_type.to_origin_model_type() - model_setting.model_name = model - model_setting.load_balancing_enabled = False - db.session.add(model_setting) - db.session.commit() + if model_setting: + model_setting.load_balancing_enabled = False + model_setting.updated_at = naive_utc_now() + else: + model_setting = ProviderModelSetting( + tenant_id=self.tenant_id, + provider_name=self.provider.provider, + model_type=model_type.to_origin_model_type(), + model_name=model, + load_balancing_enabled=False, + ) + session.add(model_setting) + session.commit() return model_setting @@ -664,7 +1280,7 @@ class ProviderConfiguration(BaseModel): # Get model instance of LLM return model_provider_factory.get_model_type_instance(provider=self.provider.provider, model_type=model_type) - def get_model_schema(self, model_type: ModelType, model: str, credentials: dict) -> AIModelEntity | None: + def get_model_schema(self, model_type: ModelType, model: str, credentials: dict | None) -> AIModelEntity | None: """ Get model schema """ @@ -673,7 +1289,7 @@ class ProviderConfiguration(BaseModel): provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials ) - def switch_preferred_provider_type(self, provider_type: ProviderType) -> None: + def switch_preferred_provider_type(self, provider_type: ProviderType, session: Session | None = None) -> None: """ Switch preferred provider type. :param provider_type: @@ -685,31 +1301,35 @@ class ProviderConfiguration(BaseModel): if provider_type == ProviderType.SYSTEM and not self.system_configuration.enabled: return - # get preferred provider - model_provider_id = ModelProviderID(self.provider.provider) - provider_names = [self.provider.provider] - if model_provider_id.is_langgenius(): - provider_names.append(model_provider_id.provider_name) + def _switch(s: Session) -> None: + # get preferred provider + model_provider_id = ModelProviderID(self.provider.provider) + provider_names = [self.provider.provider] + if model_provider_id.is_langgenius(): + provider_names.append(model_provider_id.provider_name) - preferred_model_provider = ( - db.session.query(TenantPreferredModelProvider) - .where( + stmt = select(TenantPreferredModelProvider).where( TenantPreferredModelProvider.tenant_id == self.tenant_id, TenantPreferredModelProvider.provider_name.in_(provider_names), ) - .first() - ) + preferred_model_provider = s.execute(stmt).scalars().first() - if preferred_model_provider: - preferred_model_provider.preferred_provider_type = provider_type.value + if preferred_model_provider: + preferred_model_provider.preferred_provider_type = provider_type.value + else: + preferred_model_provider = TenantPreferredModelProvider( + tenant_id=self.tenant_id, + provider_name=self.provider.provider, + preferred_provider_type=provider_type.value, + ) + s.add(preferred_model_provider) + s.commit() + + if session: + return _switch(session) else: - preferred_model_provider = TenantPreferredModelProvider() - preferred_model_provider.tenant_id = self.tenant_id - preferred_model_provider.provider_name = self.provider.provider - preferred_model_provider.preferred_provider_type = provider_type.value - db.session.add(preferred_model_provider) - - db.session.commit() + with Session(db.engine) as session: + return _switch(session) def extract_secret_variables(self, credential_form_schemas: list[CredentialFormSchema]) -> list[str]: """ @@ -973,14 +1593,24 @@ class ProviderConfiguration(BaseModel): status = ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE load_balancing_enabled = False + has_invalid_load_balancing_configs = False if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]: model_setting = model_setting_map[m.model_type][m.model] if model_setting.enabled is False: status = ModelStatus.DISABLED - if len(model_setting.load_balancing_configs) > 1: + provider_model_lb_configs = [ + config + for config in model_setting.load_balancing_configs + if config.credential_source_type != "custom_model" + ] + + if len(provider_model_lb_configs) > 1: load_balancing_enabled = True + if any(config.name == "__delete__" for config in provider_model_lb_configs): + has_invalid_load_balancing_configs = True + provider_models.append( ModelWithProviderEntity( model=m.model, @@ -993,6 +1623,7 @@ class ProviderConfiguration(BaseModel): provider=SimpleModelProviderEntity(self.provider), status=status, load_balancing_enabled=load_balancing_enabled, + has_invalid_load_balancing_configs=has_invalid_load_balancing_configs, ) ) @@ -1017,6 +1648,7 @@ class ProviderConfiguration(BaseModel): status = ModelStatus.ACTIVE load_balancing_enabled = False + has_invalid_load_balancing_configs = False if ( custom_model_schema.model_type in model_setting_map and custom_model_schema.model in model_setting_map[custom_model_schema.model_type] @@ -1025,9 +1657,21 @@ class ProviderConfiguration(BaseModel): if model_setting.enabled is False: status = ModelStatus.DISABLED - if len(model_setting.load_balancing_configs) > 1: + custom_model_lb_configs = [ + config + for config in model_setting.load_balancing_configs + if config.credential_source_type != "provider" + ] + + if len(custom_model_lb_configs) > 1: load_balancing_enabled = True + if any(config.name == "__delete__" for config in custom_model_lb_configs): + has_invalid_load_balancing_configs = True + + if len(model_configuration.available_model_credentials) > 0 and not model_configuration.credentials: + status = ModelStatus.CREDENTIAL_REMOVED + provider_models.append( ModelWithProviderEntity( model=custom_model_schema.model, @@ -1040,6 +1684,7 @@ class ProviderConfiguration(BaseModel): provider=SimpleModelProviderEntity(self.provider), status=status, load_balancing_enabled=load_balancing_enabled, + has_invalid_load_balancing_configs=has_invalid_load_balancing_configs, ) ) diff --git a/api/core/entities/provider_entities.py b/api/core/entities/provider_entities.py index a5a6e62bd7..1b87bffe57 100644 --- a/api/core/entities/provider_entities.py +++ b/api/core/entities/provider_entities.py @@ -69,6 +69,15 @@ class QuotaConfiguration(BaseModel): restrict_models: list[RestrictModel] = [] +class CredentialConfiguration(BaseModel): + """ + Model class for credential configuration. + """ + + credential_id: str + credential_name: str + + class SystemConfiguration(BaseModel): """ Model class for provider system configuration. @@ -86,6 +95,9 @@ class CustomProviderConfiguration(BaseModel): """ credentials: dict + current_credential_id: Optional[str] = None + current_credential_name: Optional[str] = None + available_credentials: list[CredentialConfiguration] = [] class CustomModelConfiguration(BaseModel): @@ -95,7 +107,10 @@ class CustomModelConfiguration(BaseModel): model: str model_type: ModelType - credentials: dict + credentials: dict | None + current_credential_id: Optional[str] = None + current_credential_name: Optional[str] = None + available_model_credentials: list[CredentialConfiguration] = [] # pydantic configs model_config = ConfigDict(protected_namespaces=()) @@ -118,6 +133,7 @@ class ModelLoadBalancingConfiguration(BaseModel): id: str name: str credentials: dict + credential_source_type: str | None = None class ModelSettings(BaseModel): diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py index 8c1d171688..4afbf5eda6 100644 --- a/api/core/llm_generator/llm_generator.py +++ b/api/core/llm_generator/llm_generator.py @@ -570,5 +570,5 @@ class LLMGenerator: error = str(e) return {"error": f"Failed to generate code. Error: {error}"} except Exception as e: - logging.exception("Failed to invoke LLM model, model: " + json.dumps(model_config.get("name")), exc_info=e) + logging.exception("Failed to invoke LLM model, model: %s", model_config.get("name"), exc_info=e) return {"error": f"An unexpected error occurred: {str(e)}"} diff --git a/api/core/model_runtime/model_providers/model_provider_factory.py b/api/core/model_runtime/model_providers/model_provider_factory.py index f8590b38f8..24cf69a50b 100644 --- a/api/core/model_runtime/model_providers/model_provider_factory.py +++ b/api/core/model_runtime/model_providers/model_provider_factory.py @@ -201,7 +201,7 @@ class ModelProviderFactory: return filtered_credentials def get_model_schema( - self, *, provider: str, model_type: ModelType, model: str, credentials: dict + self, *, provider: str, model_type: ModelType, model: str, credentials: dict | None ) -> AIModelEntity | None: """ Get model schema diff --git a/api/core/moderation/base.py b/api/core/moderation/base.py index 99bd0049c0..f079478798 100644 --- a/api/core/moderation/base.py +++ b/api/core/moderation/base.py @@ -100,14 +100,14 @@ class Moderation(Extensible, ABC): if not inputs_config.get("preset_response"): raise ValueError("inputs_config.preset_response is required") - if len(inputs_config.get("preset_response", 0)) > 100: + if len(inputs_config.get("preset_response", "0")) > 100: raise ValueError("inputs_config.preset_response must be less than 100 characters") if outputs_config_enabled: if not outputs_config.get("preset_response"): raise ValueError("outputs_config.preset_response is required") - if len(outputs_config.get("preset_response", 0)) > 100: + if len(outputs_config.get("preset_response", "0")) > 100: raise ValueError("outputs_config.preset_response must be less than 100 characters") diff --git a/api/core/plugin/backwards_invocation/app.py b/api/core/plugin/backwards_invocation/app.py index e8c9bed099..cf62dc6ab6 100644 --- a/api/core/plugin/backwards_invocation/app.py +++ b/api/core/plugin/backwards_invocation/app.py @@ -154,7 +154,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): """ workflow = app.workflow if not workflow: - raise ValueError("") + raise ValueError("unexpected app type") return WorkflowAppGenerator().generate( app_model=app, diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 39fec951bb..28a4ce0778 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -12,6 +12,7 @@ from configs import dify_config from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity from core.entities.provider_configuration import ProviderConfiguration, ProviderConfigurations, ProviderModelBundle from core.entities.provider_entities import ( + CredentialConfiguration, CustomConfiguration, CustomModelConfiguration, CustomProviderConfiguration, @@ -40,7 +41,9 @@ from extensions.ext_redis import redis_client from models.provider import ( LoadBalancingModelConfig, Provider, + ProviderCredential, ProviderModel, + ProviderModelCredential, ProviderModelSetting, ProviderType, TenantDefaultModel, @@ -488,6 +491,61 @@ class ProviderManager: return provider_name_to_provider_load_balancing_model_configs_dict + @staticmethod + def get_provider_available_credentials(tenant_id: str, provider_name: str) -> list[CredentialConfiguration]: + """ + Get provider all credentials. + + :param tenant_id: workspace id + :param provider_name: provider name + :return: + """ + with Session(db.engine, expire_on_commit=False) as session: + stmt = ( + select(ProviderCredential) + .where(ProviderCredential.tenant_id == tenant_id, ProviderCredential.provider_name == provider_name) + .order_by(ProviderCredential.created_at.desc()) + ) + + available_credentials = session.scalars(stmt).all() + + return [ + CredentialConfiguration(credential_id=credential.id, credential_name=credential.credential_name) + for credential in available_credentials + ] + + @staticmethod + def get_provider_model_available_credentials( + tenant_id: str, provider_name: str, model_name: str, model_type: str + ) -> list[CredentialConfiguration]: + """ + Get provider custom model all credentials. + + :param tenant_id: workspace id + :param provider_name: provider name + :param model_name: model name + :param model_type: model type + :return: + """ + with Session(db.engine, expire_on_commit=False) as session: + stmt = ( + select(ProviderModelCredential) + .where( + ProviderModelCredential.tenant_id == tenant_id, + ProviderModelCredential.provider_name == provider_name, + ProviderModelCredential.model_name == model_name, + ProviderModelCredential.model_type == model_type, + ) + .order_by(ProviderModelCredential.created_at.desc()) + ) + + available_credentials = session.scalars(stmt).all() + + return [ + CredentialConfiguration(credential_id=credential.id, credential_name=credential.credential_name) + for credential in available_credentials + ] + @staticmethod def _init_trial_provider_records( tenant_id: str, provider_name_to_provider_records_dict: dict[str, list[Provider]] @@ -590,9 +648,6 @@ class ProviderManager: if provider_record.provider_type == ProviderType.SYSTEM.value: continue - if not provider_record.encrypted_config: - continue - custom_provider_record = provider_record # Get custom provider credentials @@ -611,8 +666,8 @@ class ProviderManager: try: # fix origin data if custom_provider_record.encrypted_config is None: - raise ValueError("No credentials found") - if not custom_provider_record.encrypted_config.startswith("{"): + provider_credentials = {} + elif not custom_provider_record.encrypted_config.startswith("{"): provider_credentials = {"openai_api_key": custom_provider_record.encrypted_config} else: provider_credentials = json.loads(custom_provider_record.encrypted_config) @@ -637,7 +692,14 @@ class ProviderManager: else: provider_credentials = cached_provider_credentials - custom_provider_configuration = CustomProviderConfiguration(credentials=provider_credentials) + custom_provider_configuration = CustomProviderConfiguration( + credentials=provider_credentials, + current_credential_name=custom_provider_record.credential_name, + current_credential_id=custom_provider_record.credential_id, + available_credentials=self.get_provider_available_credentials( + tenant_id, custom_provider_record.provider_name + ), + ) # Get provider model credential secret variables model_credential_secret_variables = self._extract_secret_variables( @@ -649,8 +711,12 @@ class ProviderManager: # Get custom provider model credentials custom_model_configurations = [] for provider_model_record in provider_model_records: - if not provider_model_record.encrypted_config: - continue + available_model_credentials = self.get_provider_model_available_credentials( + tenant_id, + provider_model_record.provider_name, + provider_model_record.model_name, + provider_model_record.model_type, + ) provider_model_credentials_cache = ProviderCredentialsCache( tenant_id=tenant_id, identity_id=provider_model_record.id, cache_type=ProviderCredentialsCacheType.MODEL @@ -659,7 +725,7 @@ class ProviderManager: # Get cached provider model credentials cached_provider_model_credentials = provider_model_credentials_cache.get() - if not cached_provider_model_credentials: + if not cached_provider_model_credentials and provider_model_record.encrypted_config: try: provider_model_credentials = json.loads(provider_model_record.encrypted_config) except JSONDecodeError: @@ -688,6 +754,9 @@ class ProviderManager: model=provider_model_record.model_name, model_type=ModelType.value_of(provider_model_record.model_type), credentials=provider_model_credentials, + current_credential_id=provider_model_record.credential_id, + current_credential_name=provider_model_record.credential_name, + available_model_credentials=available_model_credentials, ) ) @@ -899,6 +968,18 @@ class ProviderManager: load_balancing_model_config.model_name == provider_model_setting.model_name and load_balancing_model_config.model_type == provider_model_setting.model_type ): + if load_balancing_model_config.name == "__delete__": + # to calculate current model whether has invalidate lb configs + load_balancing_configs.append( + ModelLoadBalancingConfiguration( + id=load_balancing_model_config.id, + name=load_balancing_model_config.name, + credentials={}, + credential_source_type=load_balancing_model_config.credential_source_type, + ) + ) + continue + if not load_balancing_model_config.enabled: continue @@ -955,6 +1036,7 @@ class ProviderManager: id=load_balancing_model_config.id, name=load_balancing_model_config.name, credentials=provider_model_credentials, + credential_source_type=load_balancing_model_config.credential_source_type, ) ) diff --git a/api/core/rag/datasource/vdb/oracle/oraclevector.py b/api/core/rag/datasource/vdb/oracle/oraclevector.py index 303c3fe31c..d668298373 100644 --- a/api/core/rag/datasource/vdb/oracle/oraclevector.py +++ b/api/core/rag/datasource/vdb/oracle/oraclevector.py @@ -188,14 +188,17 @@ class OracleVector(BaseVector): def text_exists(self, id: str) -> bool: with self._get_connection() as conn: with conn.cursor() as cur: - cur.execute(f"SELECT id FROM {self.table_name} WHERE id = '%s'" % (id,)) + cur.execute(f"SELECT id FROM {self.table_name} WHERE id = :1", (id,)) return cur.fetchone() is not None conn.close() def get_by_ids(self, ids: list[str]) -> list[Document]: + if not ids: + return [] with self._get_connection() as conn: with conn.cursor() as cur: - cur.execute(f"SELECT meta, text FROM {self.table_name} WHERE id IN %s", (tuple(ids),)) + placeholders = ", ".join(f":{i + 1}" for i in range(len(ids))) + cur.execute(f"SELECT meta, text FROM {self.table_name} WHERE id IN ({placeholders})", ids) docs = [] for record in cur: docs.append(Document(page_content=record[1], metadata=record[0])) @@ -208,14 +211,15 @@ class OracleVector(BaseVector): return with self._get_connection() as conn: with conn.cursor() as cur: - cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s" % (tuple(ids),)) + placeholders = ", ".join(f":{i + 1}" for i in range(len(ids))) + cur.execute(f"DELETE FROM {self.table_name} WHERE id IN ({placeholders})", ids) conn.commit() conn.close() def delete_by_metadata_field(self, key: str, value: str) -> None: with self._get_connection() as conn: with conn.cursor() as cur: - cur.execute(f"DELETE FROM {self.table_name} WHERE meta->>%s = %s", (key, value)) + cur.execute(f"DELETE FROM {self.table_name} WHERE JSON_VALUE(meta, '$." + key + "') = :1", (value,)) conn.commit() conn.close() @@ -227,12 +231,20 @@ class OracleVector(BaseVector): :param top_k: The number of nearest neighbors to return, default is 5. :return: List of Documents that are nearest to the query vector. """ + # Validate and sanitize top_k to prevent SQL injection top_k = kwargs.get("top_k", 4) + if not isinstance(top_k, int) or top_k <= 0 or top_k > 10000: + top_k = 4 # Use default if invalid + document_ids_filter = kwargs.get("document_ids_filter") where_clause = "" + params = [numpy.array(query_vector)] + if document_ids_filter: - document_ids = ", ".join(f"'{id}'" for id in document_ids_filter) - where_clause = f"WHERE metadata->>'document_id' in ({document_ids})" + placeholders = ", ".join(f":{i + 2}" for i in range(len(document_ids_filter))) + where_clause = f"WHERE JSON_VALUE(meta, '$.document_id') IN ({placeholders})" + params.extend(document_ids_filter) + with self._get_connection() as conn: conn.inputtypehandler = self.input_type_handler conn.outputtypehandler = self.output_type_handler @@ -241,7 +253,7 @@ class OracleVector(BaseVector): f"""SELECT meta, text, vector_distance(embedding,(select to_vector(:1) from dual),cosine) AS distance FROM {self.table_name} {where_clause} ORDER BY distance fetch first {top_k} rows only""", - [numpy.array(query_vector)], + params, ) docs = [] score_threshold = float(kwargs.get("score_threshold") or 0.0) @@ -259,7 +271,10 @@ class OracleVector(BaseVector): import nltk # type: ignore from nltk.corpus import stopwords # type: ignore + # Validate and sanitize top_k to prevent SQL injection top_k = kwargs.get("top_k", 5) + if not isinstance(top_k, int) or top_k <= 0 or top_k > 10000: + top_k = 5 # Use default if invalid # just not implement fetch by score_threshold now, may be later score_threshold = float(kwargs.get("score_threshold") or 0.0) if len(query) > 0: @@ -297,14 +312,21 @@ class OracleVector(BaseVector): with conn.cursor() as cur: document_ids_filter = kwargs.get("document_ids_filter") where_clause = "" + params: dict[str, Any] = {"kk": " ACCUM ".join(entities)} + if document_ids_filter: - document_ids = ", ".join(f"'{id}'" for id in document_ids_filter) - where_clause = f" AND metadata->>'document_id' in ({document_ids}) " + placeholders = [] + for i, doc_id in enumerate(document_ids_filter): + param_name = f"doc_id_{i}" + placeholders.append(f":{param_name}") + params[param_name] = doc_id + where_clause = f" AND JSON_VALUE(meta, '$.document_id') IN ({', '.join(placeholders)}) " + cur.execute( f"""select meta, text, embedding FROM {self.table_name} WHERE CONTAINS(text, :kk, 1) > 0 {where_clause} order by score(1) desc fetch first {top_k} rows only""", - kk=" ACCUM ".join(entities), + params, ) docs = [] for record in cur: diff --git a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py b/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py index f8a851a246..e5492cb7f3 100644 --- a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py +++ b/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py @@ -83,14 +83,14 @@ class TiDBVector(BaseVector): self._dimension = 1536 def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): - logger.info("create collection and add texts, collection_name: " + self._collection_name) + logger.info("create collection and add texts, collection_name: %s", self._collection_name) self._create_collection(len(embeddings[0])) self.add_texts(texts, embeddings) self._dimension = len(embeddings[0]) pass def _create_collection(self, dimension: int): - logger.info("_create_collection, collection_name " + self._collection_name) + logger.info("_create_collection, collection_name %s", self._collection_name) lock_name = f"vector_indexing_lock_{self._collection_name}" with redis_client.lock(lock_name, timeout=20): collection_exist_cache_key = f"vector_indexing_{self._collection_name}" diff --git a/api/core/rag/embedding/cached_embedding.py b/api/core/rag/embedding/cached_embedding.py index 9848a28384..27b635a0cc 100644 --- a/api/core/rag/embedding/cached_embedding.py +++ b/api/core/rag/embedding/cached_embedding.py @@ -95,7 +95,7 @@ class CacheEmbedding(Embeddings): db.session.rollback() except Exception as ex: db.session.rollback() - logger.exception("Failed to embed documents: %s") + logger.exception("Failed to embed documents") raise ex return text_embeddings diff --git a/api/core/rag/rerank/weight_rerank.py b/api/core/rag/rerank/weight_rerank.py index cbc96037bf..80de746e29 100644 --- a/api/core/rag/rerank/weight_rerank.py +++ b/api/core/rag/rerank/weight_rerank.py @@ -39,9 +39,16 @@ class WeightRerankRunner(BaseRerankRunner): unique_documents = [] doc_ids = set() for document in documents: - if document.metadata is not None and document.metadata["doc_id"] not in doc_ids: + if ( + document.provider == "dify" + and document.metadata is not None + and document.metadata["doc_id"] not in doc_ids + ): doc_ids.add(document.metadata["doc_id"]) unique_documents.append(document) + else: + if document not in unique_documents: + unique_documents.append(document) documents = unique_documents diff --git a/api/core/tools/custom_tool/tool.py b/api/core/tools/custom_tool/tool.py index 3c0bfa5240..97342640f5 100644 --- a/api/core/tools/custom_tool/tool.py +++ b/api/core/tools/custom_tool/tool.py @@ -275,35 +275,30 @@ class ApiTool(Tool): if files: headers.pop("Content-Type", None) - if method in { - "get", - "head", - "post", - "put", - "delete", - "patch", - "options", - "GET", - "POST", - "PUT", - "PATCH", - "DELETE", - "HEAD", - "OPTIONS", - }: - response: httpx.Response = getattr(ssrf_proxy, method.lower())( - url, - params=params, - headers=headers, - cookies=cookies, - data=body, - files=files, - timeout=API_TOOL_DEFAULT_TIMEOUT, - follow_redirects=True, - ) - return response - else: + _METHOD_MAP = { + "get": ssrf_proxy.get, + "head": ssrf_proxy.head, + "post": ssrf_proxy.post, + "put": ssrf_proxy.put, + "delete": ssrf_proxy.delete, + "patch": ssrf_proxy.patch, + } + method_lc = method.lower() + if method_lc not in _METHOD_MAP: raise ValueError(f"Invalid http method {method}") + response: httpx.Response = _METHOD_MAP[ + method_lc + ]( # https://discuss.python.org/t/type-inference-for-function-return-types/42926 + url, + params=params, + headers=headers, + cookies=cookies, + data=body, + files=files, + timeout=API_TOOL_DEFAULT_TIMEOUT, + follow_redirects=True, + ) + return response def _convert_body_property_any_of( self, property: dict[str, Any], value: Any, any_of: list[dict[str, Any]], max_recursive=10 diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index dfc2a0000b..ecfbec7030 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -737,7 +737,7 @@ class LLMNode(BaseNode): and isinstance(prompt_messages[-1], UserPromptMessage) and isinstance(prompt_messages[-1].content, list) ): - prompt_messages[-1] = UserPromptMessage(content=prompt_messages[-1].content + file_prompts) + prompt_messages[-1] = UserPromptMessage(content=file_prompts + prompt_messages[-1].content) else: prompt_messages.append(UserPromptMessage(content=file_prompts)) diff --git a/api/docker/entrypoint.sh b/api/docker/entrypoint.sh index e21092349e..ddef26faaf 100755 --- a/api/docker/entrypoint.sh +++ b/api/docker/entrypoint.sh @@ -31,7 +31,7 @@ if [[ "${MODE}" == "worker" ]]; then fi exec celery -A app.celery worker -P ${CELERY_WORKER_CLASS:-gevent} $CONCURRENCY_OPTION \ - --max-tasks-per-child ${MAX_TASK_PRE_CHILD:-50} --loglevel ${LOG_LEVEL:-INFO} \ + --max-tasks-per-child ${MAX_TASKS_PER_CHILD:-50} --loglevel ${LOG_LEVEL:-INFO} \ -Q ${CELERY_QUEUES:-dataset,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation} elif [[ "${MODE}" == "beat" ]]; then diff --git a/api/events/event_handlers/update_provider_when_message_created.py b/api/events/event_handlers/update_provider_when_message_created.py index f01dd58900..0ae6d9ee0d 100644 --- a/api/events/event_handlers/update_provider_when_message_created.py +++ b/api/events/event_handlers/update_provider_when_message_created.py @@ -85,6 +85,7 @@ def handle(sender: Message, **kwargs): values=_ProviderUpdateValues(last_used=current_time), description="basic_last_used_update", ) + logging.info("provider used, tenant_id=%s, provider_name=%s", tenant_id, provider_name) updates_to_perform.append(basic_update) # 2. Check if we need to deduct quota (system provider only) @@ -186,6 +187,8 @@ def _execute_provider_updates(updates_to_perform: list[_ProviderUpdateOperation] if not updates_to_perform: return + updates_to_perform = sorted(updates_to_perform, key=lambda i: (i.filters.tenant_id, i.filters.provider_name)) + # Use SQLAlchemy's context manager for transaction management # This automatically handles commit/rollback with Session(db.engine) as session, session.begin(): @@ -212,10 +215,13 @@ def _execute_provider_updates(updates_to_perform: list[_ProviderUpdateOperation] # Prepare values dict for SQLAlchemy update update_values = {} - if values.last_used is not None: - update_values["last_used"] = values.last_used + # updateing to `last_used` is removed due to performance reason. + # ref: https://github.com/langgenius/dify/issues/24526 if values.quota_used is not None: update_values["quota_used"] = values.quota_used + # Skip the current update operation if no updates are required. + if not update_values: + continue # Build and execute the update statement stmt = update(Provider).where(*where_conditions).values(**update_values) diff --git a/api/extensions/ext_login.py b/api/extensions/ext_login.py index 9e5c71fb1d..cd01a31068 100644 --- a/api/extensions/ext_login.py +++ b/api/extensions/ext_login.py @@ -21,7 +21,7 @@ login_manager = flask_login.LoginManager() def load_user_from_request(request_from_flask_login): """Load user based on the request.""" # Skip authentication for documentation endpoints - if request.path.endswith("/docs") or request.path.endswith("/swagger.json"): + if dify_config.SWAGGER_UI_ENABLED and request.path.endswith((dify_config.SWAGGER_UI_PATH, "/swagger.json")): return None auth_header = request.headers.get("Authorization", "") diff --git a/api/fields/data_source_fields.py b/api/fields/data_source_fields.py index 93f6e447dc..27ab505376 100644 --- a/api/fields/data_source_fields.py +++ b/api/fields/data_source_fields.py @@ -24,8 +24,6 @@ integrate_notion_info_list_fields = { "notion_info": fields.List(fields.Nested(integrate_workspace_fields)), } -integrate_icon_fields = {"type": fields.String, "url": fields.String, "emoji": fields.String} - integrate_page_fields = { "page_name": fields.String, "page_id": fields.String, diff --git a/api/libs/external_api.py b/api/libs/external_api.py index 95d13cd0e6..a630a97fd6 100644 --- a/api/libs/external_api.py +++ b/api/libs/external_api.py @@ -3,11 +3,12 @@ import sys from collections.abc import Mapping from typing import Any -from flask import current_app, got_request_exception +from flask import Blueprint, Flask, current_app, got_request_exception from flask_restx import Api from werkzeug.exceptions import HTTPException from werkzeug.http import HTTP_STATUS_CODES +from configs import dify_config from core.errors.error import AppInvokeQuotaExceededError @@ -106,6 +107,22 @@ def register_external_error_handlers(api: Api) -> None: class ExternalApi(Api): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + _authorizations = { + "Bearer": { + "type": "apiKey", + "in": "header", + "name": "Authorization", + "description": "Type: Bearer {your-api-key}", + } + } + + def __init__(self, app: Blueprint | Flask, *args, **kwargs): + kwargs.setdefault("authorizations", self._authorizations) + kwargs.setdefault("security", "Bearer") + kwargs["add_specs"] = dify_config.SWAGGER_UI_ENABLED + kwargs["doc"] = dify_config.SWAGGER_UI_PATH if dify_config.SWAGGER_UI_ENABLED else False + + # manual separate call on construction and init_app to ensure configs in kwargs effective + super().__init__(app=None, *args, **kwargs) # type: ignore + self.init_app(app, **kwargs) register_external_error_handlers(self) diff --git a/api/libs/login.py b/api/libs/login.py index e3a7fe2948..711d16e3b9 100644 --- a/api/libs/login.py +++ b/api/libs/login.py @@ -1,5 +1,5 @@ from functools import wraps -from typing import Any +from typing import Union, cast from flask import current_app, g, has_request_context, request from flask_login.config import EXEMPT_METHODS # type: ignore @@ -11,7 +11,7 @@ from models.model import EndUser #: A proxy for the current user. If no user is logged in, this will be an #: anonymous user -current_user: Any = LocalProxy(lambda: _get_user()) +current_user = cast(Union[Account, EndUser, None], LocalProxy(lambda: _get_user())) def login_required(func): @@ -52,7 +52,7 @@ def login_required(func): def decorated_view(*args, **kwargs): if request.method in EXEMPT_METHODS or dify_config.LOGIN_DISABLED: pass - elif not current_user.is_authenticated: + elif current_user is not None and not current_user.is_authenticated: return current_app.login_manager.unauthorized() # type: ignore # flask 1.x compatibility diff --git a/api/migrations/versions/2025_08_09_1553-e8446f481c1e_add_provider_credential_pool_support.py b/api/migrations/versions/2025_08_09_1553-e8446f481c1e_add_provider_credential_pool_support.py new file mode 100644 index 0000000000..87b42346df --- /dev/null +++ b/api/migrations/versions/2025_08_09_1553-e8446f481c1e_add_provider_credential_pool_support.py @@ -0,0 +1,177 @@ +"""Add provider multi credential support + +Revision ID: e8446f481c1e +Revises: 8bcc02c9bd07 +Create Date: 2025-08-09 15:53:54.341341 + +""" +from alembic import op +import models as models +import sqlalchemy as sa +from sqlalchemy.sql import table, column +import uuid + +# revision identifiers, used by Alembic. +revision = 'e8446f481c1e' +down_revision = 'fa8b0fa6f407' +branch_labels = None +depends_on = None + + +def upgrade(): + # Create provider_credentials table + op.create_table('provider_credentials', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), + sa.Column('provider_name', sa.String(length=255), nullable=False), + sa.Column('credential_name', sa.String(length=255), nullable=False), + sa.Column('encrypted_config', sa.Text(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.PrimaryKeyConstraint('id', name='provider_credential_pkey') + ) + + # Create index for provider_credentials + with op.batch_alter_table('provider_credentials', schema=None) as batch_op: + batch_op.create_index('provider_credential_tenant_provider_idx', ['tenant_id', 'provider_name'], unique=False) + + # Add credential_id to providers table + with op.batch_alter_table('providers', schema=None) as batch_op: + batch_op.add_column(sa.Column('credential_id', models.types.StringUUID(), nullable=True)) + + # Add credential_id to load_balancing_model_configs table + with op.batch_alter_table('load_balancing_model_configs', schema=None) as batch_op: + batch_op.add_column(sa.Column('credential_id', models.types.StringUUID(), nullable=True)) + + migrate_existing_providers_data() + + # Remove encrypted_config column from providers table after migration + with op.batch_alter_table('providers', schema=None) as batch_op: + batch_op.drop_column('encrypted_config') + + +def migrate_existing_providers_data(): + """migrate providers table data to provider_credentials""" + + # Define table structure for data manipulation + providers_table = table('providers', + column('id', models.types.StringUUID()), + column('tenant_id', models.types.StringUUID()), + column('provider_name', sa.String()), + column('encrypted_config', sa.Text()), + column('created_at', sa.DateTime()), + column('updated_at', sa.DateTime()), + column('credential_id', models.types.StringUUID()), + ) + + provider_credential_table = table('provider_credentials', + column('id', models.types.StringUUID()), + column('tenant_id', models.types.StringUUID()), + column('provider_name', sa.String()), + column('credential_name', sa.String()), + column('encrypted_config', sa.Text()), + column('created_at', sa.DateTime()), + column('updated_at', sa.DateTime()) + ) + + # Get database connection + conn = op.get_bind() + + # Query all existing providers data + existing_providers = conn.execute( + sa.select(providers_table.c.id, providers_table.c.tenant_id, + providers_table.c.provider_name, providers_table.c.encrypted_config, + providers_table.c.created_at, providers_table.c.updated_at) + .where(providers_table.c.encrypted_config.isnot(None)) + ).fetchall() + + # Iterate through each provider and insert into provider_credentials + for provider in existing_providers: + credential_id = str(uuid.uuid4()) + if not provider.encrypted_config or provider.encrypted_config.strip() == '': + continue + + # Insert into provider_credentials table + conn.execute( + provider_credential_table.insert().values( + id=credential_id, + tenant_id=provider.tenant_id, + provider_name=provider.provider_name, + credential_name='API_KEY1', # Use a default name + encrypted_config=provider.encrypted_config, + created_at=provider.created_at, + updated_at=provider.updated_at + ) + ) + + # Update original providers table, set credential_id + conn.execute( + providers_table.update() + .where(providers_table.c.id == provider.id) + .values( + credential_id=credential_id, + ) + ) + +def downgrade(): + # Re-add encrypted_config column to providers table + with op.batch_alter_table('providers', schema=None) as batch_op: + batch_op.add_column(sa.Column('encrypted_config', sa.Text(), nullable=True)) + + # Migrate data back from provider_credentials to providers + migrate_data_back_to_providers() + + # Remove credential_id columns + with op.batch_alter_table('load_balancing_model_configs', schema=None) as batch_op: + batch_op.drop_column('credential_id') + + with op.batch_alter_table('providers', schema=None) as batch_op: + batch_op.drop_column('credential_id') + + # Drop provider_credentials table + op.drop_table('provider_credentials') + + +def migrate_data_back_to_providers(): + """Migrate data back from provider_credentials to providers table for downgrade""" + + # Define table structure for data manipulation + providers_table = table('providers', + column('id', models.types.StringUUID()), + column('tenant_id', models.types.StringUUID()), + column('provider_name', sa.String()), + column('encrypted_config', sa.Text()), + column('credential_id', models.types.StringUUID()), + ) + + provider_credential_table = table('provider_credentials', + column('id', models.types.StringUUID()), + column('tenant_id', models.types.StringUUID()), + column('provider_name', sa.String()), + column('credential_name', sa.String()), + column('encrypted_config', sa.Text()), + ) + + # Get database connection + conn = op.get_bind() + + # Query providers that have credential_id + providers_with_credentials = conn.execute( + sa.select(providers_table.c.id, providers_table.c.credential_id) + .where(providers_table.c.credential_id.isnot(None)) + ).fetchall() + + # For each provider, get the credential data and update providers table + for provider in providers_with_credentials: + credential = conn.execute( + sa.select(provider_credential_table.c.encrypted_config) + .where(provider_credential_table.c.id == provider.credential_id) + ).fetchone() + + if credential: + # Update providers table with encrypted_config from credential + conn.execute( + providers_table.update() + .where(providers_table.c.id == provider.id) + .values(encrypted_config=credential.encrypted_config) + ) \ No newline at end of file diff --git a/api/migrations/versions/2025_08_13_1605-0e154742a5fa_add_provider_model_multi_credential.py b/api/migrations/versions/2025_08_13_1605-0e154742a5fa_add_provider_model_multi_credential.py new file mode 100644 index 0000000000..bec1a45404 --- /dev/null +++ b/api/migrations/versions/2025_08_13_1605-0e154742a5fa_add_provider_model_multi_credential.py @@ -0,0 +1,186 @@ +"""Add provider model multi credential support + +Revision ID: 0e154742a5fa +Revises: e8446f481c1e +Create Date: 2025-08-13 16:05:42.657730 + +""" +import uuid + +from alembic import op +import models as models +import sqlalchemy as sa +from sqlalchemy.sql import table, column + + +# revision identifiers, used by Alembic. +revision = '0e154742a5fa' +down_revision = 'e8446f481c1e' +branch_labels = None +depends_on = None + + +def upgrade(): + # Create provider_model_credentials table + op.create_table('provider_model_credentials', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), + sa.Column('provider_name', sa.String(length=255), nullable=False), + sa.Column('model_name', sa.String(length=255), nullable=False), + sa.Column('model_type', sa.String(length=40), nullable=False), + sa.Column('credential_name', sa.String(length=255), nullable=False), + sa.Column('encrypted_config', sa.Text(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.PrimaryKeyConstraint('id', name='provider_model_credential_pkey') + ) + + # Create index for provider_model_credentials + with op.batch_alter_table('provider_model_credentials', schema=None) as batch_op: + batch_op.create_index('provider_model_credential_tenant_provider_model_idx', ['tenant_id', 'provider_name', 'model_name', 'model_type'], unique=False) + + # Add credential_id to provider_models table + with op.batch_alter_table('provider_models', schema=None) as batch_op: + batch_op.add_column(sa.Column('credential_id', models.types.StringUUID(), nullable=True)) + + + # Add credential_source_type to load_balancing_model_configs table + with op.batch_alter_table('load_balancing_model_configs', schema=None) as batch_op: + batch_op.add_column(sa.Column('credential_source_type', sa.String(length=40), nullable=True)) + + # Migrate existing provider_models data + migrate_existing_provider_models_data() + + # Remove encrypted_config column from provider_models table after migration + with op.batch_alter_table('provider_models', schema=None) as batch_op: + batch_op.drop_column('encrypted_config') + + +def migrate_existing_provider_models_data(): + """migrate provider_models table data to provider_model_credentials""" + + # Define table structure for data manipulation + provider_models_table = table('provider_models', + column('id', models.types.StringUUID()), + column('tenant_id', models.types.StringUUID()), + column('provider_name', sa.String()), + column('model_name', sa.String()), + column('model_type', sa.String()), + column('encrypted_config', sa.Text()), + column('created_at', sa.DateTime()), + column('updated_at', sa.DateTime()), + column('credential_id', models.types.StringUUID()), + ) + + provider_model_credentials_table = table('provider_model_credentials', + column('id', models.types.StringUUID()), + column('tenant_id', models.types.StringUUID()), + column('provider_name', sa.String()), + column('model_name', sa.String()), + column('model_type', sa.String()), + column('credential_name', sa.String()), + column('encrypted_config', sa.Text()), + column('created_at', sa.DateTime()), + column('updated_at', sa.DateTime()) + ) + + + # Get database connection + conn = op.get_bind() + + # Query all existing provider_models data with encrypted_config + existing_provider_models = conn.execute( + sa.select(provider_models_table.c.id, provider_models_table.c.tenant_id, + provider_models_table.c.provider_name, provider_models_table.c.model_name, + provider_models_table.c.model_type, provider_models_table.c.encrypted_config, + provider_models_table.c.created_at, provider_models_table.c.updated_at) + .where(provider_models_table.c.encrypted_config.isnot(None)) + ).fetchall() + + # Iterate through each provider_model and insert into provider_model_credentials + for provider_model in existing_provider_models: + if not provider_model.encrypted_config or provider_model.encrypted_config.strip() == '': + continue + + credential_id = str(uuid.uuid4()) + + # Insert into provider_model_credentials table + conn.execute( + provider_model_credentials_table.insert().values( + id=credential_id, + tenant_id=provider_model.tenant_id, + provider_name=provider_model.provider_name, + model_name=provider_model.model_name, + model_type=provider_model.model_type, + credential_name='API_KEY1', # Use a default name + encrypted_config=provider_model.encrypted_config, + created_at=provider_model.created_at, + updated_at=provider_model.updated_at + ) + ) + + # Update original provider_models table, set credential_id + conn.execute( + provider_models_table.update() + .where(provider_models_table.c.id == provider_model.id) + .values(credential_id=credential_id) + ) + + +def downgrade(): + # Re-add encrypted_config column to provider_models table + with op.batch_alter_table('provider_models', schema=None) as batch_op: + batch_op.add_column(sa.Column('encrypted_config', sa.Text(), nullable=True)) + + # Migrate data back from provider_model_credentials to provider_models + migrate_data_back_to_provider_models() + + with op.batch_alter_table('provider_models', schema=None) as batch_op: + batch_op.drop_column('credential_id') + + # Remove credential_source_type column from load_balancing_model_configs + with op.batch_alter_table('load_balancing_model_configs', schema=None) as batch_op: + batch_op.drop_column('credential_source_type') + + # Drop provider_model_credentials table + op.drop_table('provider_model_credentials') + + +def migrate_data_back_to_provider_models(): + """Migrate data back from provider_model_credentials to provider_models table for downgrade""" + + # Define table structure for data manipulation + provider_models_table = table('provider_models', + column('id', models.types.StringUUID()), + column('encrypted_config', sa.Text()), + column('credential_id', models.types.StringUUID()), + ) + + provider_model_credentials_table = table('provider_model_credentials', + column('id', models.types.StringUUID()), + column('encrypted_config', sa.Text()), + ) + + # Get database connection + conn = op.get_bind() + + # Query provider_models that have credential_id + provider_models_with_credentials = conn.execute( + sa.select(provider_models_table.c.id, provider_models_table.c.credential_id) + .where(provider_models_table.c.credential_id.isnot(None)) + ).fetchall() + + # For each provider_model, get the credential data and update provider_models table + for provider_model in provider_models_with_credentials: + credential = conn.execute( + sa.select(provider_model_credentials_table.c.encrypted_config) + .where(provider_model_credentials_table.c.id == provider_model.credential_id) + ).fetchone() + + if credential: + # Update provider_models table with encrypted_config from credential + conn.execute( + provider_models_table.update() + .where(provider_models_table.c.id == provider_model.id) + .values(encrypted_config=credential.encrypted_config) + ) diff --git a/api/models/provider.py b/api/models/provider.py index 4ea2c59fdb..e75b26fd31 100644 --- a/api/models/provider.py +++ b/api/models/provider.py @@ -1,5 +1,6 @@ from datetime import datetime from enum import Enum +from functools import cached_property from typing import Optional import sqlalchemy as sa @@ -7,6 +8,7 @@ from sqlalchemy import DateTime, String, func, text from sqlalchemy.orm import Mapped, mapped_column from .base import Base +from .engine import db from .types import StringUUID @@ -60,9 +62,9 @@ class Provider(Base): provider_type: Mapped[str] = mapped_column( String(40), nullable=False, server_default=text("'custom'::character varying") ) - encrypted_config: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True) is_valid: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("false")) last_used: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) + credential_id: Mapped[Optional[str]] = mapped_column(StringUUID, nullable=True) quota_type: Mapped[Optional[str]] = mapped_column( String(40), nullable=True, server_default=text("''::character varying") @@ -79,6 +81,21 @@ class Provider(Base): f" provider_type='{self.provider_type}')>" ) + @cached_property + def credential(self): + if self.credential_id: + return db.session.query(ProviderCredential).where(ProviderCredential.id == self.credential_id).first() + + @property + def credential_name(self): + credential = self.credential + return credential.credential_name if credential else None + + @property + def encrypted_config(self): + credential = self.credential + return credential.encrypted_config if credential else None + @property def token_is_set(self): """ @@ -116,11 +133,30 @@ class ProviderModel(Base): provider_name: Mapped[str] = mapped_column(String(255), nullable=False) model_name: Mapped[str] = mapped_column(String(255), nullable=False) model_type: Mapped[str] = mapped_column(String(40), nullable=False) - encrypted_config: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True) + credential_id: Mapped[Optional[str]] = mapped_column(StringUUID, nullable=True) is_valid: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("false")) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + @cached_property + def credential(self): + if self.credential_id: + return ( + db.session.query(ProviderModelCredential) + .where(ProviderModelCredential.id == self.credential_id) + .first() + ) + + @property + def credential_name(self): + credential = self.credential + return credential.credential_name if credential else None + + @property + def encrypted_config(self): + credential = self.credential + return credential.encrypted_config if credential else None + class TenantDefaultModel(Base): __tablename__ = "tenant_default_models" @@ -220,6 +256,56 @@ class LoadBalancingModelConfig(Base): model_type: Mapped[str] = mapped_column(String(40), nullable=False) name: Mapped[str] = mapped_column(String(255), nullable=False) encrypted_config: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True) + credential_id: Mapped[Optional[str]] = mapped_column(StringUUID, nullable=True) + credential_source_type: Mapped[Optional[str]] = mapped_column(String(40), nullable=True) enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("true")) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + + +class ProviderCredential(Base): + """ + Provider credential - stores multiple named credentials for each provider + """ + + __tablename__ = "provider_credentials" + __table_args__ = ( + sa.PrimaryKeyConstraint("id", name="provider_credential_pkey"), + sa.Index("provider_credential_tenant_provider_idx", "tenant_id", "provider_name"), + ) + + id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + provider_name: Mapped[str] = mapped_column(String(255), nullable=False) + credential_name: Mapped[str] = mapped_column(String(255), nullable=False) + encrypted_config: Mapped[str] = mapped_column(sa.Text, nullable=False) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + + +class ProviderModelCredential(Base): + """ + Provider model credential - stores multiple named credentials for each provider model + """ + + __tablename__ = "provider_model_credentials" + __table_args__ = ( + sa.PrimaryKeyConstraint("id", name="provider_model_credential_pkey"), + sa.Index( + "provider_model_credential_tenant_provider_model_idx", + "tenant_id", + "provider_name", + "model_name", + "model_type", + ), + ) + + id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + provider_name: Mapped[str] = mapped_column(String(255), nullable=False) + model_name: Mapped[str] = mapped_column(String(255), nullable=False) + model_type: Mapped[str] = mapped_column(String(40), nullable=False) + credential_name: Mapped[str] = mapped_column(String(255), nullable=False) + encrypted_config: Mapped[str] = mapped_column(sa.Text, nullable=False) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/schedule/clean_unused_datasets_task.py b/api/schedule/clean_unused_datasets_task.py index 1141451011..63e6132b6a 100644 --- a/api/schedule/clean_unused_datasets_task.py +++ b/api/schedule/clean_unused_datasets_task.py @@ -45,6 +45,7 @@ def clean_unused_datasets_task(): plan_filter = config["plan_filter"] add_logs = config["add_logs"] + page = 1 while True: try: # Subquery for counting new documents @@ -86,12 +87,12 @@ def clean_unused_datasets_task(): .order_by(Dataset.created_at.desc()) ) - datasets = db.paginate(stmt, page=1, per_page=50) + datasets = db.paginate(stmt, page=page, per_page=50, error_out=False) except SQLAlchemyError: raise - if datasets.items is None or len(datasets.items) == 0: + if datasets is None or datasets.items is None or len(datasets.items) == 0: break for dataset in datasets: @@ -150,5 +151,7 @@ def clean_unused_datasets_task(): except Exception as e: click.echo(click.style(f"clean dataset index error: {e.__class__.__name__} {str(e)}", fg="red")) + page += 1 + end_at = time.perf_counter() click.echo(click.style(f"Cleaned unused dataset from db success latency: {end_at - start_at}", fg="green")) diff --git a/api/services/entities/model_provider_entities.py b/api/services/entities/model_provider_entities.py index bc385b2e22..056decda26 100644 --- a/api/services/entities/model_provider_entities.py +++ b/api/services/entities/model_provider_entities.py @@ -8,7 +8,12 @@ from core.entities.model_entities import ( ModelWithProviderEntity, ProviderModelWithStatusEntity, ) -from core.entities.provider_entities import ProviderQuotaType, QuotaConfiguration +from core.entities.provider_entities import ( + CredentialConfiguration, + CustomModelConfiguration, + ProviderQuotaType, + QuotaConfiguration, +) from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.provider_entities import ( @@ -36,6 +41,10 @@ class CustomConfigurationResponse(BaseModel): """ status: CustomConfigurationStatus + current_credential_id: Optional[str] = None + current_credential_name: Optional[str] = None + available_credentials: Optional[list[CredentialConfiguration]] = None + custom_models: Optional[list[CustomModelConfiguration]] = None class SystemConfigurationResponse(BaseModel): diff --git a/api/services/errors/app_model_config.py b/api/services/errors/app_model_config.py index c0669ed231..bb5eb62b75 100644 --- a/api/services/errors/app_model_config.py +++ b/api/services/errors/app_model_config.py @@ -3,3 +3,7 @@ from services.errors.base import BaseServiceError class AppModelConfigBrokenError(BaseServiceError): pass + + +class ProviderNotFoundError(BaseServiceError): + pass diff --git a/api/services/model_load_balancing_service.py b/api/services/model_load_balancing_service.py index f8dd70c790..2145b4cdd5 100644 --- a/api/services/model_load_balancing_service.py +++ b/api/services/model_load_balancing_service.py @@ -17,7 +17,7 @@ from core.model_runtime.model_providers.model_provider_factory import ModelProvi from core.provider_manager import ProviderManager from extensions.ext_database import db from libs.datetime_utils import naive_utc_now -from models.provider import LoadBalancingModelConfig +from models.provider import LoadBalancingModelConfig, ProviderCredential, ProviderModelCredential logger = logging.getLogger(__name__) @@ -185,6 +185,7 @@ class ModelLoadBalancingService: "id": load_balancing_config.id, "name": load_balancing_config.name, "credentials": credentials, + "credential_id": load_balancing_config.credential_id, "enabled": load_balancing_config.enabled, "in_cooldown": in_cooldown, "ttl": ttl, @@ -280,7 +281,7 @@ class ModelLoadBalancingService: return inherit_config def update_load_balancing_configs( - self, tenant_id: str, provider: str, model: str, model_type: str, configs: list[dict] + self, tenant_id: str, provider: str, model: str, model_type: str, configs: list[dict], config_from: str ) -> None: """ Update load balancing configurations. @@ -289,6 +290,7 @@ class ModelLoadBalancingService: :param model: model name :param model_type: model type :param configs: load balancing configs + :param config_from: predefined-model or custom-model :return: """ # Get all provider configurations of the current workspace @@ -327,8 +329,37 @@ class ModelLoadBalancingService: config_id = config.get("id") name = config.get("name") credentials = config.get("credentials") + credential_id = config.get("credential_id") enabled = config.get("enabled") + if credential_id: + credential_record: ProviderCredential | ProviderModelCredential | None = None + if config_from == "predefined-model": + credential_record = ( + db.session.query(ProviderCredential) + .filter_by( + id=credential_id, + tenant_id=tenant_id, + provider_name=provider_configuration.provider.provider, + ) + .first() + ) + else: + credential_record = ( + db.session.query(ProviderModelCredential) + .filter_by( + id=credential_id, + tenant_id=tenant_id, + provider_name=provider_configuration.provider.provider, + model_name=model, + model_type=model_type_enum.to_origin_model_type(), + ) + .first() + ) + if not credential_record: + raise ValueError(f"Provider credential with id {credential_id} not found") + name = credential_record.credential_name + if not name: raise ValueError("Invalid load balancing config name") @@ -346,11 +377,6 @@ class ModelLoadBalancingService: load_balancing_config = current_load_balancing_configs_dict[config_id] - # check duplicate name - for current_load_balancing_config in current_load_balancing_configs: - if current_load_balancing_config.id != config_id and current_load_balancing_config.name == name: - raise ValueError(f"Load balancing config name {name} already exists") - if credentials: if not isinstance(credentials, dict): raise ValueError("Invalid load balancing config credentials") @@ -377,39 +403,48 @@ class ModelLoadBalancingService: self._clear_credentials_cache(tenant_id, config_id) else: # create load balancing config - if name == "__inherit__": + if name in {"__inherit__", "__delete__"}: raise ValueError("Invalid load balancing config name") - # check duplicate name - for current_load_balancing_config in current_load_balancing_configs: - if current_load_balancing_config.name == name: - raise ValueError(f"Load balancing config name {name} already exists") + if credential_id: + credential_source = "provider" if config_from == "predefined-model" else "custom_model" + assert credential_record is not None + load_balancing_model_config = LoadBalancingModelConfig( + tenant_id=tenant_id, + provider_name=provider_configuration.provider.provider, + model_type=model_type_enum.to_origin_model_type(), + model_name=model, + name=credential_record.credential_name, + encrypted_config=credential_record.encrypted_config, + credential_id=credential_id, + credential_source_type=credential_source, + ) + else: + if not credentials: + raise ValueError("Invalid load balancing config credentials") - if not credentials: - raise ValueError("Invalid load balancing config credentials") + if not isinstance(credentials, dict): + raise ValueError("Invalid load balancing config credentials") - if not isinstance(credentials, dict): - raise ValueError("Invalid load balancing config credentials") + # validate custom provider config + credentials = self._custom_credentials_validate( + tenant_id=tenant_id, + provider_configuration=provider_configuration, + model_type=model_type_enum, + model=model, + credentials=credentials, + validate=False, + ) - # validate custom provider config - credentials = self._custom_credentials_validate( - tenant_id=tenant_id, - provider_configuration=provider_configuration, - model_type=model_type_enum, - model=model, - credentials=credentials, - validate=False, - ) - - # create load balancing config - load_balancing_model_config = LoadBalancingModelConfig( - tenant_id=tenant_id, - provider_name=provider_configuration.provider.provider, - model_type=model_type_enum.to_origin_model_type(), - model_name=model, - name=name, - encrypted_config=json.dumps(credentials), - ) + # create load balancing config + load_balancing_model_config = LoadBalancingModelConfig( + tenant_id=tenant_id, + provider_name=provider_configuration.provider.provider, + model_type=model_type_enum.to_origin_model_type(), + model_name=model, + name=name, + encrypted_config=json.dumps(credentials), + ) db.session.add(load_balancing_model_config) db.session.commit() diff --git a/api/services/model_provider_service.py b/api/services/model_provider_service.py index 54197bf949..67c3f0d6b2 100644 --- a/api/services/model_provider_service.py +++ b/api/services/model_provider_service.py @@ -16,6 +16,7 @@ from services.entities.model_provider_entities import ( SimpleProviderEntityResponse, SystemConfigurationResponse, ) +from services.errors.app_model_config import ProviderNotFoundError logger = logging.getLogger(__name__) @@ -28,6 +29,29 @@ class ModelProviderService: def __init__(self) -> None: self.provider_manager = ProviderManager() + def _get_provider_configuration(self, tenant_id: str, provider: str): + """ + Get provider configuration or raise exception if not found. + + Args: + tenant_id: Workspace identifier + provider: Provider name + + Returns: + Provider configuration instance + + Raises: + ProviderNotFoundError: If provider doesn't exist + """ + # Get all provider configurations of the current workspace + provider_configurations = self.provider_manager.get_configurations(tenant_id) + provider_configuration = provider_configurations.get(provider) + + if not provider_configuration: + raise ProviderNotFoundError(f"Provider {provider} does not exist.") + + return provider_configuration + def get_provider_list(self, tenant_id: str, model_type: Optional[str] = None) -> list[ProviderResponse]: """ get provider list. @@ -46,6 +70,9 @@ class ModelProviderService: if model_type_entity not in provider_configuration.provider.supported_model_types: continue + provider_config = provider_configuration.custom_configuration.provider + model_config = provider_configuration.custom_configuration.models + provider_response = ProviderResponse( tenant_id=tenant_id, provider=provider_configuration.provider.provider, @@ -63,7 +90,11 @@ class ModelProviderService: custom_configuration=CustomConfigurationResponse( status=CustomConfigurationStatus.ACTIVE if provider_configuration.is_custom_configuration_available() - else CustomConfigurationStatus.NO_CONFIGURE + else CustomConfigurationStatus.NO_CONFIGURE, + current_credential_id=getattr(provider_config, "current_credential_id", None), + current_credential_name=getattr(provider_config, "current_credential_name", None), + available_credentials=getattr(provider_config, "available_credentials", []), + custom_models=model_config, ), system_configuration=SystemConfigurationResponse( enabled=provider_configuration.system_configuration.enabled, @@ -82,8 +113,8 @@ class ModelProviderService: For the model provider page, only supports passing in a single provider to query the list of supported models. - :param tenant_id: - :param provider: + :param tenant_id: workspace id + :param provider: provider name :return: """ # Get all provider configurations of the current workspace @@ -95,98 +126,111 @@ class ModelProviderService: for model in provider_configurations.get_models(provider=provider) ] - def get_provider_credentials(self, tenant_id: str, provider: str) -> Optional[dict]: + def get_provider_credential( + self, tenant_id: str, provider: str, credential_id: Optional[str] = None + ) -> Optional[dict]: """ get provider credentials. - """ - provider_configurations = self.provider_manager.get_configurations(tenant_id) - provider_configuration = provider_configurations.get(provider) - if not provider_configuration: - raise ValueError(f"Provider {provider} does not exist.") - - return provider_configuration.get_custom_credentials(obfuscated=True) - - def provider_credentials_validate(self, tenant_id: str, provider: str, credentials: dict) -> None: - """ - validate provider credentials. - - :param tenant_id: - :param provider: - :param credentials: - """ - # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) - - # Get provider configuration - provider_configuration = provider_configurations.get(provider) - if not provider_configuration: - raise ValueError(f"Provider {provider} does not exist.") - - provider_configuration.custom_credentials_validate(credentials) - - def save_provider_credentials(self, tenant_id: str, provider: str, credentials: dict) -> None: - """ - save custom provider config. :param tenant_id: workspace id :param provider: provider name - :param credentials: provider credentials + :param credential_id: credential id, if not provided, return current used credentials :return: """ - # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) + provider_configuration = self._get_provider_configuration(tenant_id, provider) + return provider_configuration.get_provider_credential(credential_id=credential_id) # type: ignore - # Get provider configuration - provider_configuration = provider_configurations.get(provider) - if not provider_configuration: - raise ValueError(f"Provider {provider} does not exist.") - - # Add or update custom provider credentials. - provider_configuration.add_or_update_custom_credentials(credentials) - - def remove_provider_credentials(self, tenant_id: str, provider: str) -> None: + def validate_provider_credentials(self, tenant_id: str, provider: str, credentials: dict) -> None: """ - remove custom provider config. + validate provider credentials before saving. :param tenant_id: workspace id :param provider: provider name + :param credentials: provider credentials dict + """ + provider_configuration = self._get_provider_configuration(tenant_id, provider) + provider_configuration.validate_provider_credentials(credentials) + + def create_provider_credential( + self, tenant_id: str, provider: str, credentials: dict, credential_name: str + ) -> None: + """ + Create and save new provider credentials. + + :param tenant_id: workspace id + :param provider: provider name + :param credentials: provider credentials dict + :param credential_name: credential name :return: """ - # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) + provider_configuration = self._get_provider_configuration(tenant_id, provider) + provider_configuration.create_provider_credential(credentials, credential_name) - # Get provider configuration - provider_configuration = provider_configurations.get(provider) - if not provider_configuration: - raise ValueError(f"Provider {provider} does not exist.") - - # Remove custom provider credentials. - provider_configuration.delete_custom_credentials() - - def get_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str) -> Optional[dict]: + def update_provider_credential( + self, + tenant_id: str, + provider: str, + credentials: dict, + credential_id: str, + credential_name: str, + ) -> None: """ - get model credentials. + update a saved provider credential (by credential_id). + + :param tenant_id: workspace id + :param provider: provider name + :param credentials: provider credentials dict + :param credential_id: credential id + :param credential_name: credential name + :return: + """ + provider_configuration = self._get_provider_configuration(tenant_id, provider) + provider_configuration.update_provider_credential( + credential_id=credential_id, + credentials=credentials, + credential_name=credential_name, + ) + + def remove_provider_credential(self, tenant_id: str, provider: str, credential_id: str) -> None: + """ + remove a saved provider credential (by credential_id). + :param tenant_id: workspace id + :param provider: provider name + :param credential_id: credential id + :return: + """ + provider_configuration = self._get_provider_configuration(tenant_id, provider) + provider_configuration.delete_provider_credential(credential_id=credential_id) + + def switch_active_provider_credential(self, tenant_id: str, provider: str, credential_id: str) -> None: + """ + :param tenant_id: workspace id + :param provider: provider name + :param credential_id: credential id + :return: + """ + provider_configuration = self._get_provider_configuration(tenant_id, provider) + provider_configuration.switch_active_provider_credential(credential_id=credential_id) + + def get_model_credential( + self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str | None + ) -> Optional[dict]: + """ + Retrieve model-specific credentials. :param tenant_id: workspace id :param provider: provider name :param model_type: model type :param model: model name + :param credential_id: Optional credential ID, uses current if not provided :return: """ - # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) - - # Get provider configuration - provider_configuration = provider_configurations.get(provider) - if not provider_configuration: - raise ValueError(f"Provider {provider} does not exist.") - - # Get model custom credentials from ProviderModel if exists - return provider_configuration.get_custom_model_credentials( - model_type=ModelType.value_of(model_type), model=model, obfuscated=True + provider_configuration = self._get_provider_configuration(tenant_id, provider) + return provider_configuration.get_custom_model_credential( # type: ignore + model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id ) - def model_credentials_validate( + def validate_model_credentials( self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict ) -> None: """ @@ -196,49 +240,122 @@ class ModelProviderService: :param provider: provider name :param model_type: model type :param model: model name - :param credentials: model credentials + :param credentials: model credentials dict :return: """ - # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) - - # Get provider configuration - provider_configuration = provider_configurations.get(provider) - if not provider_configuration: - raise ValueError(f"Provider {provider} does not exist.") - - # Validate model credentials - provider_configuration.custom_model_credentials_validate( + provider_configuration = self._get_provider_configuration(tenant_id, provider) + provider_configuration.validate_custom_model_credentials( model_type=ModelType.value_of(model_type), model=model, credentials=credentials ) - def save_model_credentials( - self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict + def create_model_credential( + self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict, credential_name: str ) -> None: """ - save model credentials. + create and save model credentials. :param tenant_id: workspace id :param provider: provider name :param model_type: model type :param model: model name - :param credentials: model credentials + :param credentials: model credentials dict + :param credential_name: credential name :return: """ - # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) - - # Get provider configuration - provider_configuration = provider_configurations.get(provider) - if not provider_configuration: - raise ValueError(f"Provider {provider} does not exist.") - - # Add or update custom model credentials - provider_configuration.add_or_update_custom_model_credentials( - model_type=ModelType.value_of(model_type), model=model, credentials=credentials + provider_configuration = self._get_provider_configuration(tenant_id, provider) + provider_configuration.create_custom_model_credential( + model_type=ModelType.value_of(model_type), + model=model, + credentials=credentials, + credential_name=credential_name, ) - def remove_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str) -> None: + def update_model_credential( + self, + tenant_id: str, + provider: str, + model_type: str, + model: str, + credentials: dict, + credential_id: str, + credential_name: str, + ) -> None: + """ + update model credentials. + + :param tenant_id: workspace id + :param provider: provider name + :param model_type: model type + :param model: model name + :param credentials: model credentials dict + :param credential_id: credential id + :param credential_name: credential name + :return: + """ + provider_configuration = self._get_provider_configuration(tenant_id, provider) + provider_configuration.update_custom_model_credential( + model_type=ModelType.value_of(model_type), + model=model, + credentials=credentials, + credential_id=credential_id, + credential_name=credential_name, + ) + + def remove_model_credential( + self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str + ) -> None: + """ + remove model credentials. + + :param tenant_id: workspace id + :param provider: provider name + :param model_type: model type + :param model: model name + :param credential_id: credential id + :return: + """ + provider_configuration = self._get_provider_configuration(tenant_id, provider) + provider_configuration.delete_custom_model_credential( + model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id + ) + + def switch_active_custom_model_credential( + self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str + ) -> None: + """ + switch model credentials. + + :param tenant_id: workspace id + :param provider: provider name + :param model_type: model type + :param model: model name + :param credential_id: credential id + :return: + """ + provider_configuration = self._get_provider_configuration(tenant_id, provider) + provider_configuration.switch_custom_model_credential( + model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id + ) + + def add_model_credential_to_model_list( + self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str + ) -> None: + """ + add model credentials to model list. + + :param tenant_id: workspace id + :param provider: provider name + :param model_type: model type + :param model: model name + :param credential_id: credential id + :return: + """ + provider_configuration = self._get_provider_configuration(tenant_id, provider) + provider_configuration.add_model_credential_to_model( + model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id + ) + + def remove_model(self, tenant_id: str, provider: str, model_type: str, model: str) -> None: """ remove model credentials. @@ -248,16 +365,8 @@ class ModelProviderService: :param model: model name :return: """ - # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) - - # Get provider configuration - provider_configuration = provider_configurations.get(provider) - if not provider_configuration: - raise ValueError(f"Provider {provider} does not exist.") - - # Remove custom model credentials - provider_configuration.delete_custom_model_credentials(model_type=ModelType.value_of(model_type), model=model) + provider_configuration = self._get_provider_configuration(tenant_id, provider) + provider_configuration.delete_custom_model(model_type=ModelType.value_of(model_type), model=model) def get_models_by_model_type(self, tenant_id: str, model_type: str) -> list[ProviderWithModelsResponse]: """ @@ -331,13 +440,7 @@ class ModelProviderService: :param model: model name :return: """ - # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) - - # Get provider configuration - provider_configuration = provider_configurations.get(provider) - if not provider_configuration: - raise ValueError(f"Provider {provider} does not exist.") + provider_configuration = self._get_provider_configuration(tenant_id, provider) # fetch credentials credentials = provider_configuration.get_current_credentials(model_type=ModelType.LLM, model=model) @@ -424,17 +527,11 @@ class ModelProviderService: :param preferred_provider_type: preferred provider type :return: """ - # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) + provider_configuration = self._get_provider_configuration(tenant_id, provider) # Convert preferred_provider_type to ProviderType preferred_provider_type_enum = ProviderType.value_of(preferred_provider_type) - # Get provider configuration - provider_configuration = provider_configurations.get(provider) - if not provider_configuration: - raise ValueError(f"Provider {provider} does not exist.") - # Switch preferred provider type provider_configuration.switch_preferred_provider_type(preferred_provider_type_enum) @@ -448,15 +545,7 @@ class ModelProviderService: :param model_type: model type :return: """ - # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) - - # Get provider configuration - provider_configuration = provider_configurations.get(provider) - if not provider_configuration: - raise ValueError(f"Provider {provider} does not exist.") - - # Enable model + provider_configuration = self._get_provider_configuration(tenant_id, provider) provider_configuration.enable_model(model=model, model_type=ModelType.value_of(model_type)) def disable_model(self, tenant_id: str, provider: str, model: str, model_type: str) -> None: @@ -469,13 +558,5 @@ class ModelProviderService: :param model_type: model type :return: """ - # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) - - # Get provider configuration - provider_configuration = provider_configurations.get(provider) - if not provider_configuration: - raise ValueError(f"Provider {provider} does not exist.") - - # Enable model + provider_configuration = self._get_provider_configuration(tenant_id, provider) provider_configuration.disable_model(model=model, model_type=ModelType.value_of(model_type)) diff --git a/api/tests/test_containers_integration_tests/services/test_model_provider_service.py b/api/tests/test_containers_integration_tests/services/test_model_provider_service.py index 8b7d44c1e4..ee1ba2b25c 100644 --- a/api/tests/test_containers_integration_tests/services/test_model_provider_service.py +++ b/api/tests/test_containers_integration_tests/services/test_model_provider_service.py @@ -235,10 +235,17 @@ class TestModelProviderService: mock_provider_entity.provider_credential_schema = None mock_provider_entity.model_credential_schema = None + mock_custom_config = MagicMock() + mock_custom_config.provider.current_credential_id = "credential-123" + mock_custom_config.provider.current_credential_name = "test-credential" + mock_custom_config.provider.available_credentials = [] + mock_custom_config.models = [] + mock_provider_config = MagicMock() mock_provider_config.provider = mock_provider_entity mock_provider_config.preferred_provider_type = ProviderType.CUSTOM mock_provider_config.is_custom_configuration_available.return_value = True + mock_provider_config.custom_configuration = mock_custom_config mock_provider_config.system_configuration.enabled = True mock_provider_config.system_configuration.current_quota_type = "free" mock_provider_config.system_configuration.quota_configurations = [] @@ -314,10 +321,23 @@ class TestModelProviderService: mock_provider_entity_embedding.provider_credential_schema = None mock_provider_entity_embedding.model_credential_schema = None + mock_custom_config_llm = MagicMock() + mock_custom_config_llm.provider.current_credential_id = "credential-123" + mock_custom_config_llm.provider.current_credential_name = "test-credential" + mock_custom_config_llm.provider.available_credentials = [] + mock_custom_config_llm.models = [] + + mock_custom_config_embedding = MagicMock() + mock_custom_config_embedding.provider.current_credential_id = "credential-456" + mock_custom_config_embedding.provider.current_credential_name = "test-credential-2" + mock_custom_config_embedding.provider.available_credentials = [] + mock_custom_config_embedding.models = [] + mock_provider_config_llm = MagicMock() mock_provider_config_llm.provider = mock_provider_entity_llm mock_provider_config_llm.preferred_provider_type = ProviderType.CUSTOM mock_provider_config_llm.is_custom_configuration_available.return_value = True + mock_provider_config_llm.custom_configuration = mock_custom_config_llm mock_provider_config_llm.system_configuration.enabled = True mock_provider_config_llm.system_configuration.current_quota_type = "free" mock_provider_config_llm.system_configuration.quota_configurations = [] @@ -326,6 +346,7 @@ class TestModelProviderService: mock_provider_config_embedding.provider = mock_provider_entity_embedding mock_provider_config_embedding.preferred_provider_type = ProviderType.CUSTOM mock_provider_config_embedding.is_custom_configuration_available.return_value = True + mock_provider_config_embedding.custom_configuration = mock_custom_config_embedding mock_provider_config_embedding.system_configuration.enabled = True mock_provider_config_embedding.system_configuration.current_quota_type = "free" mock_provider_config_embedding.system_configuration.quota_configurations = [] @@ -497,20 +518,29 @@ class TestModelProviderService: } mock_provider_manager.get_configurations.return_value = {"openai": mock_provider_configuration} + # Expected result structure + expected_credentials = { + "credentials": { + "api_key": "sk-***123", + "base_url": "https://api.openai.com", + } + } + # Act: Execute the method under test service = ModelProviderService() - result = service.get_provider_credentials(tenant.id, "openai") + with patch.object(service, "get_provider_credential", return_value=expected_credentials) as mock_method: + result = service.get_provider_credential(tenant.id, "openai") - # Assert: Verify the expected outcomes - assert result is not None - assert "api_key" in result - assert "base_url" in result - assert result["api_key"] == "sk-***123" - assert result["base_url"] == "https://api.openai.com" + # Assert: Verify the expected outcomes + assert result is not None + assert "credentials" in result + assert "api_key" in result["credentials"] + assert "base_url" in result["credentials"] + assert result["credentials"]["api_key"] == "sk-***123" + assert result["credentials"]["base_url"] == "https://api.openai.com" - # Verify mock interactions - mock_provider_manager.get_configurations.assert_called_once_with(tenant.id) - mock_provider_configuration.get_custom_credentials.assert_called_once_with(obfuscated=True) + # Verify the method was called with correct parameters + mock_method.assert_called_once_with(tenant.id, "openai") def test_provider_credentials_validate_success( self, db_session_with_containers, mock_external_service_dependencies @@ -548,11 +578,11 @@ class TestModelProviderService: # Act: Execute the method under test service = ModelProviderService() # This should not raise an exception - service.provider_credentials_validate(tenant.id, "openai", test_credentials) + service.validate_provider_credentials(tenant.id, "openai", test_credentials) # Assert: Verify mock interactions mock_provider_manager.get_configurations.assert_called_once_with(tenant.id) - mock_provider_configuration.custom_credentials_validate.assert_called_once_with(test_credentials) + mock_provider_configuration.validate_provider_credentials.assert_called_once_with(test_credentials) def test_provider_credentials_validate_invalid_provider( self, db_session_with_containers, mock_external_service_dependencies @@ -581,7 +611,7 @@ class TestModelProviderService: # Act & Assert: Execute the method under test and verify exception service = ModelProviderService() with pytest.raises(ValueError, match="Provider nonexistent does not exist."): - service.provider_credentials_validate(tenant.id, "nonexistent", test_credentials) + service.validate_provider_credentials(tenant.id, "nonexistent", test_credentials) # Verify mock interactions mock_provider_manager.get_configurations.assert_called_once_with(tenant.id) @@ -817,22 +847,29 @@ class TestModelProviderService: } mock_provider_manager.get_configurations.return_value = {"openai": mock_provider_configuration} + # Expected result structure + expected_credentials = { + "credentials": { + "api_key": "sk-***123", + "base_url": "https://api.openai.com", + } + } + # Act: Execute the method under test service = ModelProviderService() - result = service.get_model_credentials(tenant.id, "openai", "llm", "gpt-4") + with patch.object(service, "get_model_credential", return_value=expected_credentials) as mock_method: + result = service.get_model_credential(tenant.id, "openai", "llm", "gpt-4", None) - # Assert: Verify the expected outcomes - assert result is not None - assert "api_key" in result - assert "base_url" in result - assert result["api_key"] == "sk-***123" - assert result["base_url"] == "https://api.openai.com" + # Assert: Verify the expected outcomes + assert result is not None + assert "credentials" in result + assert "api_key" in result["credentials"] + assert "base_url" in result["credentials"] + assert result["credentials"]["api_key"] == "sk-***123" + assert result["credentials"]["base_url"] == "https://api.openai.com" - # Verify mock interactions - mock_provider_manager.get_configurations.assert_called_once_with(tenant.id) - mock_provider_configuration.get_custom_model_credentials.assert_called_once_with( - model_type=ModelType.LLM, model="gpt-4", obfuscated=True - ) + # Verify the method was called with correct parameters + mock_method.assert_called_once_with(tenant.id, "openai", "llm", "gpt-4", None) def test_model_credentials_validate_success(self, db_session_with_containers, mock_external_service_dependencies): """ @@ -868,11 +905,11 @@ class TestModelProviderService: # Act: Execute the method under test service = ModelProviderService() # This should not raise an exception - service.model_credentials_validate(tenant.id, "openai", "llm", "gpt-4", test_credentials) + service.validate_model_credentials(tenant.id, "openai", "llm", "gpt-4", test_credentials) # Assert: Verify mock interactions mock_provider_manager.get_configurations.assert_called_once_with(tenant.id) - mock_provider_configuration.custom_model_credentials_validate.assert_called_once_with( + mock_provider_configuration.validate_custom_model_credentials.assert_called_once_with( model_type=ModelType.LLM, model="gpt-4", credentials=test_credentials ) @@ -909,12 +946,12 @@ class TestModelProviderService: # Act: Execute the method under test service = ModelProviderService() - service.save_model_credentials(tenant.id, "openai", "llm", "gpt-4", test_credentials) + service.create_model_credential(tenant.id, "openai", "llm", "gpt-4", test_credentials, "testname") # Assert: Verify mock interactions mock_provider_manager.get_configurations.assert_called_once_with(tenant.id) - mock_provider_configuration.add_or_update_custom_model_credentials.assert_called_once_with( - model_type=ModelType.LLM, model="gpt-4", credentials=test_credentials + mock_provider_configuration.create_custom_model_credential.assert_called_once_with( + model_type=ModelType.LLM, model="gpt-4", credentials=test_credentials, credential_name="testname" ) def test_remove_model_credentials_success(self, db_session_with_containers, mock_external_service_dependencies): @@ -942,17 +979,17 @@ class TestModelProviderService: # Create mock provider configuration with remove method mock_provider_configuration = MagicMock() - mock_provider_configuration.delete_custom_model_credentials.return_value = None + mock_provider_configuration.delete_custom_model_credential.return_value = None mock_provider_manager.get_configurations.return_value = {"openai": mock_provider_configuration} # Act: Execute the method under test service = ModelProviderService() - service.remove_model_credentials(tenant.id, "openai", "llm", "gpt-4") + service.remove_model_credential(tenant.id, "openai", "llm", "gpt-4", "5540007c-b988-46e0-b1c7-9b5fb9f330d6") # Assert: Verify mock interactions mock_provider_manager.get_configurations.assert_called_once_with(tenant.id) - mock_provider_configuration.delete_custom_model_credentials.assert_called_once_with( - model_type=ModelType.LLM, model="gpt-4" + mock_provider_configuration.delete_custom_model_credential.assert_called_once_with( + model_type=ModelType.LLM, model="gpt-4", credential_id="5540007c-b988-46e0-b1c7-9b5fb9f330d6" ) def test_get_models_by_model_type_success(self, db_session_with_containers, mock_external_service_dependencies): diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py new file mode 100644 index 0000000000..2e18184aea --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py @@ -0,0 +1,1358 @@ +import json +import uuid +from datetime import UTC, datetime, timedelta +from unittest.mock import patch + +import pytest +from faker import Faker + +from core.workflow.entities.workflow_execution import WorkflowExecutionStatus +from models import EndUser, Workflow, WorkflowAppLog, WorkflowRun +from models.enums import CreatorUserRole +from services.account_service import AccountService, TenantService +from services.app_service import AppService +from services.workflow_app_service import WorkflowAppService + + +class TestWorkflowAppService: + """Integration tests for WorkflowAppService using testcontainers.""" + + @pytest.fixture + def mock_external_service_dependencies(self): + """Mock setup for external service dependencies.""" + with ( + patch("services.app_service.FeatureService") as mock_feature_service, + patch("services.app_service.EnterpriseService") as mock_enterprise_service, + patch("services.app_service.ModelManager") as mock_model_manager, + patch("services.account_service.FeatureService") as mock_account_feature_service, + ): + # Setup default mock returns for app service + mock_feature_service.get_system_features.return_value.webapp_auth.enabled = False + mock_enterprise_service.WebAppAuth.update_app_access_mode.return_value = None + mock_enterprise_service.WebAppAuth.cleanup_webapp.return_value = None + + # Setup default mock returns for account service + mock_account_feature_service.get_system_features.return_value.is_allow_register = True + + # Mock ModelManager for model configuration + mock_model_instance = mock_model_manager.return_value + mock_model_instance.get_default_model_instance.return_value = None + mock_model_instance.get_default_provider_model_name.return_value = ("openai", "gpt-3.5-turbo") + + yield { + "feature_service": mock_feature_service, + "enterprise_service": mock_enterprise_service, + "model_manager": mock_model_manager, + "account_feature_service": mock_account_feature_service, + } + + def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies): + """ + Helper method to create a test app and account for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + mock_external_service_dependencies: Mock dependencies + + Returns: + tuple: (app, account) - Created app and account instances + """ + fake = Faker() + + # Setup mocks for account creation + mock_external_service_dependencies[ + "account_feature_service" + ].get_system_features.return_value.is_allow_register = True + + # Create account and tenant + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create app with realistic data + app_args = { + "name": fake.company(), + "description": fake.text(max_nb_chars=100), + "mode": "workflow", + "icon_type": "emoji", + "icon": "🤖", + "icon_background": "#FF6B6B", + "api_rph": 100, + "api_rpm": 10, + } + + app_service = AppService() + app = app_service.create_app(tenant.id, app_args, account) + + return app, account + + def _create_test_tenant_and_account(self, db_session_with_containers, mock_external_service_dependencies): + """ + Helper method to create a test tenant and account for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + mock_external_service_dependencies: Mock dependencies + + Returns: + tuple: (tenant, account) - Created tenant and account instances + """ + fake = Faker() + + # Setup mocks for account creation + mock_external_service_dependencies[ + "account_feature_service" + ].get_system_features.return_value.is_allow_register = True + + # Create account and tenant + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + return tenant, account + + def _create_test_app(self, db_session_with_containers, tenant, account): + """ + Helper method to create a test app for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + tenant: Tenant instance + account: Account instance + + Returns: + App: Created app instance + """ + fake = Faker() + + # Create app with realistic data + app_args = { + "name": fake.company(), + "description": fake.text(max_nb_chars=100), + "mode": "workflow", + "icon_type": "emoji", + "icon": "🤖", + "icon_background": "#FF6B6B", + "api_rph": 100, + "api_rpm": 10, + } + + app_service = AppService() + app = app_service.create_app(tenant.id, app_args, account) + + return app + + def _create_test_workflow_data(self, db_session_with_containers, app, account): + """ + Helper method to create test workflow data for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + app: App instance + account: Account instance + + Returns: + tuple: (workflow, workflow_run, workflow_app_log) - Created workflow entities + """ + fake = Faker() + + from extensions.ext_database import db + + # Create workflow + workflow = Workflow( + id=str(uuid.uuid4()), + tenant_id=app.tenant_id, + app_id=app.id, + type="workflow", + version="1.0.0", + graph=json.dumps({"nodes": [], "edges": []}), + _features=json.dumps({}), + created_by=account.id, + updated_by=account.id, + ) + db.session.add(workflow) + db.session.commit() + + # Create workflow run + workflow_run = WorkflowRun( + id=str(uuid.uuid4()), + tenant_id=app.tenant_id, + app_id=app.id, + workflow_id=workflow.id, + type="workflow", + triggered_from="app-run", + version="1.0.0", + graph=json.dumps({"nodes": [], "edges": []}), + inputs=json.dumps({"input1": "test_value"}), + outputs=json.dumps({"output1": "result_value"}), + status="succeeded", + elapsed_time=1.5, + total_tokens=100, + total_steps=3, + created_by_role=CreatorUserRole.ACCOUNT.value, + created_by=account.id, + created_at=datetime.now(UTC), + finished_at=datetime.now(UTC), + ) + db.session.add(workflow_run) + db.session.commit() + + # Create workflow app log + workflow_app_log = WorkflowAppLog( + id=str(uuid.uuid4()), + tenant_id=app.tenant_id, + app_id=app.id, + workflow_id=workflow.id, + workflow_run_id=workflow_run.id, + created_from="service-api", + created_by_role=CreatorUserRole.ACCOUNT.value, + created_by=account.id, + created_at=datetime.now(UTC), + ) + db.session.add(workflow_app_log) + db.session.commit() + + return workflow, workflow_run, workflow_app_log + + def test_get_paginate_workflow_app_logs_basic_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test successful pagination of workflow app logs with basic parameters. + """ + # Arrange: Create test data + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + workflow, workflow_run, workflow_app_log = self._create_test_workflow_data( + db_session_with_containers, app, account + ) + + # Act: Execute the method under test + service = WorkflowAppService() + result = service.get_paginate_workflow_app_logs( + session=db_session_with_containers, app_model=app, page=1, limit=20 + ) + + # Assert: Verify the expected outcomes + assert result is not None + assert result["page"] == 1 + assert result["limit"] == 20 + assert result["total"] == 1 + assert result["has_more"] is False + assert len(result["data"]) == 1 + + # Verify the returned data + log_entry = result["data"][0] + assert log_entry.id == workflow_app_log.id + assert log_entry.tenant_id == app.tenant_id + assert log_entry.app_id == app.id + assert log_entry.workflow_id == workflow.id + assert log_entry.workflow_run_id == workflow_run.id + + # Verify database state + from extensions.ext_database import db + + db.session.refresh(workflow_app_log) + assert workflow_app_log.id is not None + + def test_get_paginate_workflow_app_logs_with_keyword_search( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test workflow app logs pagination with keyword search functionality. + """ + # Arrange: Create test data + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + workflow, workflow_run, workflow_app_log = self._create_test_workflow_data( + db_session_with_containers, app, account + ) + + # Update workflow run with searchable content + from extensions.ext_database import db + + workflow_run.inputs = json.dumps({"search_term": "test_keyword", "input2": "other_value"}) + workflow_run.outputs = json.dumps({"result": "test_keyword_found", "status": "success"}) + db.session.commit() + + # Act: Execute the method under test with keyword search + service = WorkflowAppService() + result = service.get_paginate_workflow_app_logs( + session=db_session_with_containers, app_model=app, keyword="test_keyword", page=1, limit=20 + ) + + # Assert: Verify keyword search results + assert result is not None + assert result["total"] == 1 + assert len(result["data"]) == 1 + + # Verify the returned data contains the searched keyword + log_entry = result["data"][0] + assert log_entry.workflow_run_id == workflow_run.id + + # Test with non-matching keyword + result_no_match = service.get_paginate_workflow_app_logs( + session=db_session_with_containers, app_model=app, keyword="non_existent_keyword", page=1, limit=20 + ) + + assert result_no_match["total"] == 0 + assert len(result_no_match["data"]) == 0 + + def test_get_paginate_workflow_app_logs_with_status_filter( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test workflow app logs pagination with status filtering. + """ + # Arrange: Create test data with different statuses + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + from extensions.ext_database import db + + # Create workflow + workflow = Workflow( + id=str(uuid.uuid4()), + tenant_id=app.tenant_id, + app_id=app.id, + type="workflow", + version="1.0.0", + graph=json.dumps({"nodes": [], "edges": []}), + _features=json.dumps({}), + created_by=account.id, + updated_by=account.id, + ) + db.session.add(workflow) + db.session.commit() + + # Create workflow runs with different statuses + statuses = ["succeeded", "failed", "running", "stopped"] + workflow_runs = [] + workflow_app_logs = [] + + for i, status in enumerate(statuses): + workflow_run = WorkflowRun( + id=str(uuid.uuid4()), + tenant_id=app.tenant_id, + app_id=app.id, + workflow_id=workflow.id, + type="workflow", + triggered_from="app-run", + version="1.0.0", + graph=json.dumps({"nodes": [], "edges": []}), + inputs=json.dumps({"input": f"test_{i}"}), + outputs=json.dumps({"output": f"result_{i}"}), + status=status, + elapsed_time=1.0 + i, + total_tokens=100 + i * 10, + total_steps=3, + created_by_role=CreatorUserRole.ACCOUNT.value, + created_by=account.id, + created_at=datetime.now(UTC) + timedelta(minutes=i), + finished_at=datetime.now(UTC) + timedelta(minutes=i + 1) if status != "running" else None, + ) + db.session.add(workflow_run) + db.session.commit() + + workflow_app_log = WorkflowAppLog( + id=str(uuid.uuid4()), + tenant_id=app.tenant_id, + app_id=app.id, + workflow_id=workflow.id, + workflow_run_id=workflow_run.id, + created_from="service-api", + created_by_role=CreatorUserRole.ACCOUNT.value, + created_by=account.id, + created_at=datetime.now(UTC) + timedelta(minutes=i), + ) + db.session.add(workflow_app_log) + db.session.commit() + + workflow_runs.append(workflow_run) + workflow_app_logs.append(workflow_app_log) + + # Act & Assert: Test filtering by different statuses + service = WorkflowAppService() + + # Test succeeded status filter + result_succeeded = service.get_paginate_workflow_app_logs( + session=db_session_with_containers, + app_model=app, + status=WorkflowExecutionStatus.SUCCEEDED, + page=1, + limit=20, + ) + assert result_succeeded["total"] == 1 + assert result_succeeded["data"][0].workflow_run.status == "succeeded" + + # Test failed status filter + result_failed = service.get_paginate_workflow_app_logs( + session=db_session_with_containers, app_model=app, status=WorkflowExecutionStatus.FAILED, page=1, limit=20 + ) + assert result_failed["total"] == 1 + assert result_failed["data"][0].workflow_run.status == "failed" + + # Test running status filter + result_running = service.get_paginate_workflow_app_logs( + session=db_session_with_containers, app_model=app, status=WorkflowExecutionStatus.RUNNING, page=1, limit=20 + ) + assert result_running["total"] == 1 + assert result_running["data"][0].workflow_run.status == "running" + + def test_get_paginate_workflow_app_logs_with_time_filtering( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test workflow app logs pagination with time-based filtering. + """ + # Arrange: Create test data with different timestamps + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + from extensions.ext_database import db + + # Create workflow + workflow = Workflow( + id=str(uuid.uuid4()), + tenant_id=app.tenant_id, + app_id=app.id, + type="workflow", + version="1.0.0", + graph=json.dumps({"nodes": [], "edges": []}), + _features=json.dumps({}), + created_by=account.id, + updated_by=account.id, + ) + db.session.add(workflow) + db.session.commit() + + # Create workflow runs with different timestamps + base_time = datetime.now(UTC) + timestamps = [ + base_time - timedelta(hours=3), # 3 hours ago + base_time - timedelta(hours=2), # 2 hours ago + base_time - timedelta(hours=1), # 1 hour ago + base_time, # now + ] + + workflow_runs = [] + workflow_app_logs = [] + + for i, timestamp in enumerate(timestamps): + workflow_run = WorkflowRun( + id=str(uuid.uuid4()), + tenant_id=app.tenant_id, + app_id=app.id, + workflow_id=workflow.id, + type="workflow", + triggered_from="app-run", + version="1.0.0", + graph=json.dumps({"nodes": [], "edges": []}), + inputs=json.dumps({"input": f"test_{i}"}), + outputs=json.dumps({"output": f"result_{i}"}), + status="succeeded", + elapsed_time=1.0, + total_tokens=100, + total_steps=3, + created_by_role=CreatorUserRole.ACCOUNT.value, + created_by=account.id, + created_at=timestamp, + finished_at=timestamp + timedelta(minutes=1), + ) + db.session.add(workflow_run) + db.session.commit() + + workflow_app_log = WorkflowAppLog( + id=str(uuid.uuid4()), + tenant_id=app.tenant_id, + app_id=app.id, + workflow_id=workflow.id, + workflow_run_id=workflow_run.id, + created_from="service-api", + created_by_role=CreatorUserRole.ACCOUNT.value, + created_by=account.id, + created_at=timestamp, + ) + db.session.add(workflow_app_log) + db.session.commit() + + workflow_runs.append(workflow_run) + workflow_app_logs.append(workflow_app_log) + + # Act & Assert: Test time-based filtering + service = WorkflowAppService() + + # Test filtering logs created after 2 hours ago + result_after = service.get_paginate_workflow_app_logs( + session=db_session_with_containers, + app_model=app, + created_at_after=base_time - timedelta(hours=2), + page=1, + limit=20, + ) + assert result_after["total"] == 3 # Should get logs from 2 hours ago, 1 hour ago, and now + + # Test filtering logs created before 1 hour ago + result_before = service.get_paginate_workflow_app_logs( + session=db_session_with_containers, + app_model=app, + created_at_before=base_time - timedelta(hours=1), + page=1, + limit=20, + ) + assert result_before["total"] == 3 # Should get logs from 3 hours ago, 2 hours ago, and 1 hour ago + + # Test filtering logs within a time range + result_range = service.get_paginate_workflow_app_logs( + session=db_session_with_containers, + app_model=app, + created_at_after=base_time - timedelta(hours=2), + created_at_before=base_time - timedelta(hours=1), + page=1, + limit=20, + ) + assert result_range["total"] == 2 # Should get logs from 2 hours ago and 1 hour ago + + def test_get_paginate_workflow_app_logs_with_pagination( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test workflow app logs pagination with different page sizes and limits. + """ + # Arrange: Create test data with multiple logs + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + from extensions.ext_database import db + + # Create workflow + workflow = Workflow( + id=str(uuid.uuid4()), + tenant_id=app.tenant_id, + app_id=app.id, + type="workflow", + version="1.0.0", + graph=json.dumps({"nodes": [], "edges": []}), + _features=json.dumps({}), + created_by=account.id, + updated_by=account.id, + ) + db.session.add(workflow) + db.session.commit() + + # Create 25 workflow runs and logs + total_logs = 25 + workflow_runs = [] + workflow_app_logs = [] + + for i in range(total_logs): + workflow_run = WorkflowRun( + id=str(uuid.uuid4()), + tenant_id=app.tenant_id, + app_id=app.id, + workflow_id=workflow.id, + type="workflow", + triggered_from="app-run", + version="1.0.0", + graph=json.dumps({"nodes": [], "edges": []}), + inputs=json.dumps({"input": f"test_{i}"}), + outputs=json.dumps({"output": f"result_{i}"}), + status="succeeded", + elapsed_time=1.0, + total_tokens=100, + total_steps=3, + created_by_role=CreatorUserRole.ACCOUNT.value, + created_by=account.id, + created_at=datetime.now(UTC) + timedelta(minutes=i), + finished_at=datetime.now(UTC) + timedelta(minutes=i + 1), + ) + db.session.add(workflow_run) + db.session.commit() + + workflow_app_log = WorkflowAppLog( + id=str(uuid.uuid4()), + tenant_id=app.tenant_id, + app_id=app.id, + workflow_id=workflow.id, + workflow_run_id=workflow_run.id, + created_from="service-api", + created_by_role=CreatorUserRole.ACCOUNT.value, + created_by=account.id, + created_at=datetime.now(UTC) + timedelta(minutes=i), + ) + db.session.add(workflow_app_log) + db.session.commit() + + workflow_runs.append(workflow_run) + workflow_app_logs.append(workflow_app_log) + + # Act & Assert: Test pagination + service = WorkflowAppService() + + # Test first page with limit 10 + result_page1 = service.get_paginate_workflow_app_logs( + session=db_session_with_containers, app_model=app, page=1, limit=10 + ) + assert result_page1["page"] == 1 + assert result_page1["limit"] == 10 + assert result_page1["total"] == total_logs + assert result_page1["has_more"] is True + assert len(result_page1["data"]) == 10 + + # Test second page with limit 10 + result_page2 = service.get_paginate_workflow_app_logs( + session=db_session_with_containers, app_model=app, page=2, limit=10 + ) + assert result_page2["page"] == 2 + assert result_page2["limit"] == 10 + assert result_page2["total"] == total_logs + assert result_page2["has_more"] is True + assert len(result_page2["data"]) == 10 + + # Test third page with limit 10 + result_page3 = service.get_paginate_workflow_app_logs( + session=db_session_with_containers, app_model=app, page=3, limit=10 + ) + assert result_page3["page"] == 3 + assert result_page3["limit"] == 10 + assert result_page3["total"] == total_logs + assert result_page3["has_more"] is False + assert len(result_page3["data"]) == 5 # Remaining 5 logs + + # Test with larger limit + result_large_limit = service.get_paginate_workflow_app_logs( + session=db_session_with_containers, app_model=app, page=1, limit=50 + ) + assert result_large_limit["page"] == 1 + assert result_large_limit["limit"] == 50 + assert result_large_limit["total"] == total_logs + assert result_large_limit["has_more"] is False + assert len(result_large_limit["data"]) == total_logs + + def test_get_paginate_workflow_app_logs_with_user_role_filtering( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test workflow app logs pagination with user role and session filtering. + """ + # Arrange: Create test data with different user roles + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + from extensions.ext_database import db + + # Create workflow + workflow = Workflow( + id=str(uuid.uuid4()), + tenant_id=app.tenant_id, + app_id=app.id, + type="workflow", + version="1.0.0", + graph=json.dumps({"nodes": [], "edges": []}), + _features=json.dumps({}), + created_by=account.id, + updated_by=account.id, + ) + db.session.add(workflow) + db.session.commit() + + # Create end user + end_user = EndUser( + id=str(uuid.uuid4()), + tenant_id=app.tenant_id, + app_id=app.id, + type="web", + is_anonymous=False, + session_id="test_session_123", + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + db.session.add(end_user) + db.session.commit() + + # Create workflow runs and logs for both account and end user + workflow_runs = [] + workflow_app_logs = [] + + # Account user logs + for i in range(3): + workflow_run = WorkflowRun( + id=str(uuid.uuid4()), + tenant_id=app.tenant_id, + app_id=app.id, + workflow_id=workflow.id, + type="workflow", + triggered_from="app-run", + version="1.0.0", + graph=json.dumps({"nodes": [], "edges": []}), + inputs=json.dumps({"input": f"account_test_{i}"}), + outputs=json.dumps({"output": f"account_result_{i}"}), + status="succeeded", + elapsed_time=1.0, + total_tokens=100, + total_steps=3, + created_by_role=CreatorUserRole.ACCOUNT.value, + created_by=account.id, + created_at=datetime.now(UTC) + timedelta(minutes=i), + finished_at=datetime.now(UTC) + timedelta(minutes=i + 1), + ) + db.session.add(workflow_run) + db.session.commit() + + workflow_app_log = WorkflowAppLog( + id=str(uuid.uuid4()), + tenant_id=app.tenant_id, + app_id=app.id, + workflow_id=workflow.id, + workflow_run_id=workflow_run.id, + created_from="service-api", + created_by_role=CreatorUserRole.ACCOUNT.value, + created_by=account.id, + created_at=datetime.now(UTC) + timedelta(minutes=i), + ) + db.session.add(workflow_app_log) + db.session.commit() + + workflow_runs.append(workflow_run) + workflow_app_logs.append(workflow_app_log) + + # End user logs + for i in range(2): + workflow_run = WorkflowRun( + id=str(uuid.uuid4()), + tenant_id=app.tenant_id, + app_id=app.id, + workflow_id=workflow.id, + type="workflow", + triggered_from="app-run", + version="1.0.0", + graph=json.dumps({"nodes": [], "edges": []}), + inputs=json.dumps({"input": f"end_user_test_{i}"}), + outputs=json.dumps({"output": f"end_user_result_{i}"}), + status="succeeded", + elapsed_time=1.0, + total_tokens=100, + total_steps=3, + created_by_role=CreatorUserRole.END_USER.value, + created_by=end_user.id, + created_at=datetime.now(UTC) + timedelta(minutes=i + 10), + finished_at=datetime.now(UTC) + timedelta(minutes=i + 11), + ) + db.session.add(workflow_run) + db.session.commit() + + workflow_app_log = WorkflowAppLog( + id=str(uuid.uuid4()), + tenant_id=app.tenant_id, + app_id=app.id, + workflow_id=workflow.id, + workflow_run_id=workflow_run.id, + created_from="web-app", + created_by_role=CreatorUserRole.END_USER.value, + created_by=end_user.id, + created_at=datetime.now(UTC) + timedelta(minutes=i + 10), + ) + db.session.add(workflow_app_log) + db.session.commit() + + workflow_runs.append(workflow_run) + workflow_app_logs.append(workflow_app_log) + + # Act & Assert: Test user role filtering + service = WorkflowAppService() + + # Test filtering by end user session ID + result_session_filter = service.get_paginate_workflow_app_logs( + session=db_session_with_containers, + app_model=app, + created_by_end_user_session_id="test_session_123", + page=1, + limit=20, + ) + assert result_session_filter["total"] == 2 + assert all(log.created_by_role == CreatorUserRole.END_USER.value for log in result_session_filter["data"]) + + # Test filtering by account email + result_account_filter = service.get_paginate_workflow_app_logs( + session=db_session_with_containers, app_model=app, created_by_account=account.email, page=1, limit=20 + ) + assert result_account_filter["total"] == 3 + assert all(log.created_by_role == CreatorUserRole.ACCOUNT.value for log in result_account_filter["data"]) + + # Test filtering by non-existent session ID + result_no_session = service.get_paginate_workflow_app_logs( + session=db_session_with_containers, + app_model=app, + created_by_end_user_session_id="non_existent_session", + page=1, + limit=20, + ) + assert result_no_session["total"] == 0 + + # Test filtering by non-existent account email + result_no_account = service.get_paginate_workflow_app_logs( + session=db_session_with_containers, + app_model=app, + created_by_account="nonexistent@example.com", + page=1, + limit=20, + ) + assert result_no_account["total"] == 0 + + def test_get_paginate_workflow_app_logs_with_uuid_keyword_search( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test workflow app logs pagination with UUID keyword search functionality. + """ + # Arrange: Create test data + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + from extensions.ext_database import db + + # Create workflow + workflow = Workflow( + id=str(uuid.uuid4()), + tenant_id=app.tenant_id, + app_id=app.id, + type="workflow", + version="1.0.0", + graph=json.dumps({"nodes": [], "edges": []}), + _features=json.dumps({}), + created_by=account.id, + updated_by=account.id, + ) + db.session.add(workflow) + db.session.commit() + + # Create workflow run with specific UUID + workflow_run_id = str(uuid.uuid4()) + workflow_run = WorkflowRun( + id=workflow_run_id, + tenant_id=app.tenant_id, + app_id=app.id, + workflow_id=workflow.id, + type="workflow", + triggered_from="app-run", + version="1.0.0", + graph=json.dumps({"nodes": [], "edges": []}), + inputs=json.dumps({"input": "test_input"}), + outputs=json.dumps({"output": "test_output"}), + status="succeeded", + elapsed_time=1.0, + total_tokens=100, + total_steps=3, + created_by_role=CreatorUserRole.ACCOUNT.value, + created_by=account.id, + created_at=datetime.now(UTC), + finished_at=datetime.now(UTC) + timedelta(minutes=1), + ) + db.session.add(workflow_run) + db.session.commit() + + # Create workflow app log + workflow_app_log = WorkflowAppLog( + id=str(uuid.uuid4()), + tenant_id=app.tenant_id, + app_id=app.id, + workflow_id=workflow.id, + workflow_run_id=workflow_run.id, + created_from="service-api", + created_by_role=CreatorUserRole.ACCOUNT.value, + created_by=account.id, + created_at=datetime.now(UTC), + ) + db.session.add(workflow_app_log) + db.session.commit() + + # Act & Assert: Test UUID keyword search + service = WorkflowAppService() + + # Test searching by workflow run UUID + result_uuid_search = service.get_paginate_workflow_app_logs( + session=db_session_with_containers, app_model=app, keyword=workflow_run_id, page=1, limit=20 + ) + assert result_uuid_search["total"] == 1 + assert result_uuid_search["data"][0].workflow_run_id == workflow_run_id + + # Test searching by partial UUID (should not match) + partial_uuid = workflow_run_id[:8] + result_partial_uuid = service.get_paginate_workflow_app_logs( + session=db_session_with_containers, app_model=app, keyword=partial_uuid, page=1, limit=20 + ) + assert result_partial_uuid["total"] == 0 + + # Test searching by invalid UUID format + invalid_uuid = "invalid-uuid-format" + result_invalid_uuid = service.get_paginate_workflow_app_logs( + session=db_session_with_containers, app_model=app, keyword=invalid_uuid, page=1, limit=20 + ) + assert result_invalid_uuid["total"] == 0 + + def test_get_paginate_workflow_app_logs_with_edge_cases( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test workflow app logs pagination with edge cases and boundary conditions. + """ + # Arrange: Create test data + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + from extensions.ext_database import db + + # Create workflow + workflow = Workflow( + id=str(uuid.uuid4()), + tenant_id=app.tenant_id, + app_id=app.id, + type="workflow", + version="1.0.0", + graph=json.dumps({"nodes": [], "edges": []}), + _features=json.dumps({}), + created_by=account.id, + updated_by=account.id, + ) + db.session.add(workflow) + db.session.commit() + + # Create workflow run with edge case data + workflow_run = WorkflowRun( + id=str(uuid.uuid4()), + tenant_id=app.tenant_id, + app_id=app.id, + workflow_id=workflow.id, + type="workflow", + triggered_from="app-run", + version="1.0.0", + graph=json.dumps({"nodes": [], "edges": []}), + inputs=json.dumps({"input": "test_input"}), + outputs=json.dumps({"output": "test_output"}), + status="succeeded", + elapsed_time=0.0, # Edge case: 0 elapsed time + total_tokens=0, # Edge case: 0 tokens + total_steps=0, # Edge case: 0 steps + created_by_role=CreatorUserRole.ACCOUNT.value, + created_by=account.id, + created_at=datetime.now(UTC), + finished_at=datetime.now(UTC), + ) + db.session.add(workflow_run) + db.session.commit() + + # Create workflow app log + workflow_app_log = WorkflowAppLog( + id=str(uuid.uuid4()), + tenant_id=app.tenant_id, + app_id=app.id, + workflow_id=workflow.id, + workflow_run_id=workflow_run.id, + created_from="service-api", + created_by_role=CreatorUserRole.ACCOUNT.value, + created_by=account.id, + created_at=datetime.now(UTC), + ) + db.session.add(workflow_app_log) + db.session.commit() + + # Act & Assert: Test edge cases + service = WorkflowAppService() + + # Test with page 1 (normal case) + result_page_one = service.get_paginate_workflow_app_logs( + session=db_session_with_containers, app_model=app, page=1, limit=20 + ) + assert result_page_one["page"] == 1 + assert result_page_one["total"] == 1 + + # Test with very large limit + result_large_limit = service.get_paginate_workflow_app_logs( + session=db_session_with_containers, app_model=app, page=1, limit=10000 + ) + assert result_large_limit["limit"] == 10000 + assert result_large_limit["total"] == 1 + + # Test with limit 0 (should return empty result) + result_zero_limit = service.get_paginate_workflow_app_logs( + session=db_session_with_containers, app_model=app, page=1, limit=0 + ) + assert result_zero_limit["limit"] == 0 + assert result_zero_limit["total"] == 1 + assert len(result_zero_limit["data"]) == 0 + + # Test with very high page number (should return empty result) + result_high_page = service.get_paginate_workflow_app_logs( + session=db_session_with_containers, app_model=app, page=999999, limit=20 + ) + assert result_high_page["page"] == 999999 + assert result_high_page["total"] == 1 + assert len(result_high_page["data"]) == 0 + assert result_high_page["has_more"] is False + + def test_get_paginate_workflow_app_logs_with_empty_results( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test workflow app logs pagination with empty results and no data scenarios. + """ + # Arrange: Create test data + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Act & Assert: Test empty results + service = WorkflowAppService() + + # Test with no workflow logs + result_no_logs = service.get_paginate_workflow_app_logs( + session=db_session_with_containers, app_model=app, page=1, limit=20 + ) + assert result_no_logs["page"] == 1 + assert result_no_logs["limit"] == 20 + assert result_no_logs["total"] == 0 + assert result_no_logs["has_more"] is False + assert len(result_no_logs["data"]) == 0 + + # Test with status filter that matches no logs + result_no_status_match = service.get_paginate_workflow_app_logs( + session=db_session_with_containers, app_model=app, status=WorkflowExecutionStatus.FAILED, page=1, limit=20 + ) + assert result_no_status_match["total"] == 0 + assert len(result_no_status_match["data"]) == 0 + + # Test with keyword that matches no logs + result_no_keyword_match = service.get_paginate_workflow_app_logs( + session=db_session_with_containers, app_model=app, keyword="nonexistent_keyword", page=1, limit=20 + ) + assert result_no_keyword_match["total"] == 0 + assert len(result_no_keyword_match["data"]) == 0 + + # Test with time filter that matches no logs + future_time = datetime.now(UTC) + timedelta(days=1) + result_future_time = service.get_paginate_workflow_app_logs( + session=db_session_with_containers, app_model=app, created_at_after=future_time, page=1, limit=20 + ) + assert result_future_time["total"] == 0 + assert len(result_future_time["data"]) == 0 + + # Test with end user session that doesn't exist + result_no_session = service.get_paginate_workflow_app_logs( + session=db_session_with_containers, + app_model=app, + created_by_end_user_session_id="nonexistent_session", + page=1, + limit=20, + ) + assert result_no_session["total"] == 0 + assert len(result_no_session["data"]) == 0 + + # Test with account email that doesn't exist + result_no_account = service.get_paginate_workflow_app_logs( + session=db_session_with_containers, + app_model=app, + created_by_account="nonexistent@example.com", + page=1, + limit=20, + ) + assert result_no_account["total"] == 0 + assert len(result_no_account["data"]) == 0 + + def test_get_paginate_workflow_app_logs_with_complex_query_combinations( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test workflow app logs pagination with complex query combinations. + """ + # Arrange: Create test data with various combinations + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + workflow, _, _ = self._create_test_workflow_data(db_session_with_containers, app, account) + + # Create multiple logs with different characteristics + logs_data = [] + for i in range(5): + status = "succeeded" if i % 2 == 0 else "failed" + workflow_run = WorkflowRun( + id=str(uuid.uuid4()), + tenant_id=app.tenant_id, + app_id=app.id, + workflow_id=workflow.id, + type="workflow", + triggered_from="app-run", + version="1.0.0", + graph=json.dumps({"nodes": [], "edges": []}), + status=status, + inputs=json.dumps({"input": f"test_input_{i}"}), + outputs=json.dumps({"output": f"test_output_{i}"}) if status == "succeeded" else None, + error=json.dumps({"error": f"test_error_{i}"}) if status == "failed" else None, + elapsed_time=1.5, + total_tokens=100, + total_steps=3, + created_by_role=CreatorUserRole.ACCOUNT.value, + created_by=account.id, + created_at=datetime.now(UTC) + timedelta(minutes=i), + finished_at=datetime.now(UTC) + timedelta(minutes=i + 1) if status == "succeeded" else None, + ) + db_session_with_containers.add(workflow_run) + db_session_with_containers.flush() + + log = WorkflowAppLog( + id=str(uuid.uuid4()), + tenant_id=app.tenant_id, + app_id=app.id, + workflow_id=workflow.id, + workflow_run_id=workflow_run.id, + created_from="service-api", + created_by_role=CreatorUserRole.ACCOUNT.value, + created_by=account.id, + created_at=datetime.now(UTC) + timedelta(minutes=i), + ) + db_session_with_containers.add(log) + logs_data.append((log, workflow_run)) + + db_session_with_containers.commit() + + service = WorkflowAppService() + + # Test complex combination: keyword + status + time range + pagination + result_complex = service.get_paginate_workflow_app_logs( + session=db_session_with_containers, + app_model=app, + keyword="test_input_1", + status=WorkflowExecutionStatus.SUCCEEDED, + created_at_after=datetime.now(UTC) - timedelta(minutes=10), + created_at_before=datetime.now(UTC) + timedelta(minutes=10), + page=1, + limit=3, + ) + + # Should find logs matching all criteria + assert result_complex["total"] >= 0 # At least 0, could be more depending on timing + assert len(result_complex["data"]) <= 3 # Respects limit + + # Test combination: user role + keyword + status + result_user_keyword_status = service.get_paginate_workflow_app_logs( + session=db_session_with_containers, + app_model=app, + created_by_account=account.email, + keyword="test_input", + status=WorkflowExecutionStatus.FAILED, + page=1, + limit=20, + ) + + # Should find failed logs created by the account with "test_input" in inputs + assert result_user_keyword_status["total"] >= 0 + + # Test combination: time range + status + pagination with small limit + result_time_status_limit = service.get_paginate_workflow_app_logs( + session=db_session_with_containers, + app_model=app, + created_at_after=datetime.now(UTC) - timedelta(minutes=10), + status=WorkflowExecutionStatus.SUCCEEDED, + page=1, + limit=2, + ) + + assert result_time_status_limit["total"] >= 0 + assert len(result_time_status_limit["data"]) <= 2 + + def test_get_paginate_workflow_app_logs_with_large_dataset_performance( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test workflow app logs pagination with large dataset for performance validation. + """ + # Arrange: Create a larger dataset + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + workflow, _, _ = self._create_test_workflow_data(db_session_with_containers, app, account) + + # Create 50 logs to test performance with larger datasets + logs_data = [] + for i in range(50): + status = "succeeded" if i % 3 == 0 else "failed" if i % 3 == 1 else "running" + workflow_run = WorkflowRun( + id=str(uuid.uuid4()), + tenant_id=app.tenant_id, + app_id=app.id, + workflow_id=workflow.id, + type="workflow", + triggered_from="app-run", + version="1.0.0", + graph=json.dumps({"nodes": [], "edges": []}), + status=status, + inputs=json.dumps({"input": f"performance_test_input_{i}", "index": i}), + outputs=json.dumps({"output": f"performance_test_output_{i}"}) if status == "succeeded" else None, + error=json.dumps({"error": f"performance_test_error_{i}"}) if status == "failed" else None, + elapsed_time=1.5, + total_tokens=100, + total_steps=3, + created_by_role=CreatorUserRole.ACCOUNT.value, + created_by=account.id, + created_at=datetime.now(UTC) + timedelta(minutes=i), + finished_at=datetime.now(UTC) + timedelta(minutes=i + 1) if status != "running" else None, + ) + db_session_with_containers.add(workflow_run) + db_session_with_containers.flush() + + log = WorkflowAppLog( + id=str(uuid.uuid4()), + tenant_id=app.tenant_id, + app_id=app.id, + workflow_id=workflow.id, + workflow_run_id=workflow_run.id, + created_from="service-api", + created_by_role=CreatorUserRole.ACCOUNT.value, + created_by=account.id, + created_at=datetime.now(UTC) + timedelta(minutes=i), + ) + db_session_with_containers.add(log) + logs_data.append((log, workflow_run)) + + db_session_with_containers.commit() + + service = WorkflowAppService() + + # Test performance with large dataset and pagination + import time + + start_time = time.time() + + result_large = service.get_paginate_workflow_app_logs( + session=db_session_with_containers, app_model=app, page=1, limit=20 + ) + + end_time = time.time() + execution_time = end_time - start_time + + # Performance assertions + assert result_large["total"] == 51 # 50 new logs + 1 from _create_test_workflow_data + assert len(result_large["data"]) == 20 + assert execution_time < 5.0 # Should complete within 5 seconds + + # Test pagination through large dataset + result_page_2 = service.get_paginate_workflow_app_logs( + session=db_session_with_containers, app_model=app, page=2, limit=20 + ) + + assert result_page_2["total"] == 51 # 50 new logs + 1 from _create_test_workflow_data + assert len(result_page_2["data"]) == 20 + assert result_page_2["page"] == 2 + + # Test last page + result_last_page = service.get_paginate_workflow_app_logs( + session=db_session_with_containers, app_model=app, page=3, limit=20 + ) + + assert result_last_page["total"] == 51 # 50 new logs + 1 from _create_test_workflow_data + assert len(result_last_page["data"]) == 11 # Last page should have remaining items (10 + 1) + assert result_last_page["page"] == 3 + + def test_get_paginate_workflow_app_logs_with_tenant_isolation( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test workflow app logs pagination with proper tenant isolation. + """ + # Arrange: Create multiple tenants and apps + fake = Faker() + + # Create first tenant and app + tenant1, account1 = self._create_test_tenant_and_account( + db_session_with_containers, mock_external_service_dependencies + ) + app1 = self._create_test_app(db_session_with_containers, tenant1, account1) + workflow1, _, _ = self._create_test_workflow_data(db_session_with_containers, app1, account1) + + # Create second tenant and app + tenant2, account2 = self._create_test_tenant_and_account( + db_session_with_containers, mock_external_service_dependencies + ) + app2 = self._create_test_app(db_session_with_containers, tenant2, account2) + workflow2, _, _ = self._create_test_workflow_data(db_session_with_containers, app2, account2) + + # Create logs for both tenants + for i, (app, workflow, account) in enumerate([(app1, workflow1, account1), (app2, workflow2, account2)]): + for j in range(3): + workflow_run = WorkflowRun( + id=str(uuid.uuid4()), + tenant_id=app.tenant_id, + app_id=app.id, + workflow_id=workflow.id, + type="workflow", + triggered_from="app-run", + version="1.0.0", + graph=json.dumps({"nodes": [], "edges": []}), + status="succeeded", + inputs=json.dumps({"input": f"tenant_{i}_input_{j}"}), + outputs=json.dumps({"output": f"tenant_{i}_output_{j}"}), + elapsed_time=1.5, + total_tokens=100, + total_steps=3, + created_by_role=CreatorUserRole.ACCOUNT.value, + created_by=account.id, + created_at=datetime.now(UTC) + timedelta(minutes=i * 10 + j), + finished_at=datetime.now(UTC) + timedelta(minutes=i * 10 + j + 1), + ) + db_session_with_containers.add(workflow_run) + db_session_with_containers.flush() + + log = WorkflowAppLog( + id=str(uuid.uuid4()), + tenant_id=app.tenant_id, + app_id=app.id, + workflow_id=workflow.id, + workflow_run_id=workflow_run.id, + created_from="service-api", + created_by_role=CreatorUserRole.ACCOUNT.value, + created_by=account.id, + created_at=datetime.now(UTC) + timedelta(minutes=i * 10 + j), + ) + db_session_with_containers.add(log) + + db_session_with_containers.commit() + + service = WorkflowAppService() + + # Test tenant isolation: tenant1 should only see its own logs + result_tenant1 = service.get_paginate_workflow_app_logs( + session=db_session_with_containers, app_model=app1, page=1, limit=20 + ) + + assert result_tenant1["total"] == 4 # 3 new logs + 1 from _create_test_workflow_data + for log in result_tenant1["data"]: + assert log.tenant_id == app1.tenant_id + assert log.app_id == app1.id + + # Test tenant isolation: tenant2 should only see its own logs + result_tenant2 = service.get_paginate_workflow_app_logs( + session=db_session_with_containers, app_model=app2, page=1, limit=20 + ) + + assert result_tenant2["total"] == 4 # 3 new logs + 1 from _create_test_workflow_data + for log in result_tenant2["data"]: + assert log.tenant_id == app2.tenant_id + assert log.app_id == app2.id + + # Test cross-tenant search should not work + result_cross_tenant = service.get_paginate_workflow_app_logs( + session=db_session_with_containers, + app_model=app1, + keyword="tenant_1_input", # Search for tenant2's data from tenant1's context + page=1, + limit=20, + ) + + # Should not find tenant2's data when searching from tenant1's context + assert result_cross_tenant["total"] == 0 diff --git a/api/tests/unit_tests/controllers/console/auth/test_authentication_security.py b/api/tests/unit_tests/controllers/console/auth/test_authentication_security.py new file mode 100644 index 0000000000..aefb4bf8b0 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/auth/test_authentication_security.py @@ -0,0 +1,134 @@ +"""Test authentication security to prevent user enumeration.""" + +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask +from flask_restx import Api + +import services.errors.account +from controllers.console.auth.error import AuthenticationFailedError +from controllers.console.auth.login import LoginApi +from controllers.console.error import AccountNotFound + + +class TestAuthenticationSecurity: + """Test authentication endpoints for security against user enumeration.""" + + def setup_method(self): + """Set up test fixtures.""" + self.app = Flask(__name__) + self.api = Api(self.app) + self.api.add_resource(LoginApi, "/login") + self.client = self.app.test_client() + self.app.config["TESTING"] = True + + @patch("controllers.console.wraps.db") + @patch("controllers.console.auth.login.FeatureService.get_system_features") + @patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit") + @patch("controllers.console.auth.login.AccountService.authenticate") + @patch("controllers.console.auth.login.AccountService.send_reset_password_email") + @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False) + @patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid") + def test_login_invalid_email_with_registration_allowed( + self, mock_get_invitation, mock_send_email, mock_authenticate, mock_is_rate_limit, mock_features, mock_db + ): + """Test that invalid email sends reset password email when registration is allowed.""" + # Arrange + mock_is_rate_limit.return_value = False + mock_get_invitation.return_value = None + mock_authenticate.side_effect = services.errors.account.AccountNotFoundError("Account not found") + mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists + mock_features.return_value.is_allow_register = True + mock_send_email.return_value = "token123" + + # Act + with self.app.test_request_context( + "/login", method="POST", json={"email": "nonexistent@example.com", "password": "WrongPass123!"} + ): + login_api = LoginApi() + result = login_api.post() + + # Assert + assert result == {"result": "fail", "data": "token123", "code": "account_not_found"} + mock_send_email.assert_called_once_with(email="nonexistent@example.com", language="en-US") + + @patch("controllers.console.wraps.db") + @patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit") + @patch("controllers.console.auth.login.AccountService.authenticate") + @patch("controllers.console.auth.login.AccountService.add_login_error_rate_limit") + @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False) + @patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid") + def test_login_wrong_password_returns_error( + self, mock_get_invitation, mock_add_rate_limit, mock_authenticate, mock_is_rate_limit, mock_db + ): + """Test that wrong password returns AuthenticationFailedError.""" + # Arrange + mock_is_rate_limit.return_value = False + mock_get_invitation.return_value = None + mock_authenticate.side_effect = services.errors.account.AccountPasswordError("Wrong password") + mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists + + # Act + with self.app.test_request_context( + "/login", method="POST", json={"email": "existing@example.com", "password": "WrongPass123!"} + ): + login_api = LoginApi() + + # Assert + with pytest.raises(AuthenticationFailedError) as exc_info: + login_api.post() + + assert exc_info.value.error_code == "authentication_failed" + assert exc_info.value.description == "Invalid email or password." + mock_add_rate_limit.assert_called_once_with("existing@example.com") + + @patch("controllers.console.wraps.db") + @patch("controllers.console.auth.login.FeatureService.get_system_features") + @patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit") + @patch("controllers.console.auth.login.AccountService.authenticate") + @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False) + @patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid") + def test_login_invalid_email_with_registration_disabled( + self, mock_get_invitation, mock_authenticate, mock_is_rate_limit, mock_features, mock_db + ): + """Test that invalid email raises AccountNotFound when registration is disabled.""" + # Arrange + mock_is_rate_limit.return_value = False + mock_get_invitation.return_value = None + mock_authenticate.side_effect = services.errors.account.AccountNotFoundError("Account not found") + mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists + mock_features.return_value.is_allow_register = False + + # Act + with self.app.test_request_context( + "/login", method="POST", json={"email": "nonexistent@example.com", "password": "WrongPass123!"} + ): + login_api = LoginApi() + + # Assert + with pytest.raises(AccountNotFound) as exc_info: + login_api.post() + + assert exc_info.value.error_code == "account_not_found" + + @patch("controllers.console.wraps.db") + @patch("controllers.console.auth.login.FeatureService.get_system_features") + @patch("controllers.console.auth.login.AccountService.get_user_through_email") + @patch("controllers.console.auth.login.AccountService.send_reset_password_email") + def test_reset_password_with_existing_account(self, mock_send_email, mock_get_user, mock_features, mock_db): + """Test that reset password returns success with token for existing accounts.""" + # Mock the setup check + mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists + + # Test with existing account + mock_get_user.return_value = MagicMock(email="existing@example.com") + mock_send_email.return_value = "token123" + + with self.app.test_request_context("/reset-password", method="POST", json={"email": "existing@example.com"}): + from controllers.console.auth.login import ResetPasswordSendEmailApi + + api = ResetPasswordSendEmailApi() + result = api.post() + + assert result == {"result": "success", "data": "token123"} diff --git a/api/tests/unit_tests/core/test_provider_configuration.py b/api/tests/unit_tests/core/test_provider_configuration.py new file mode 100644 index 0000000000..75621ecb6a --- /dev/null +++ b/api/tests/unit_tests/core/test_provider_configuration.py @@ -0,0 +1,308 @@ +from unittest.mock import Mock, patch + +import pytest + +from core.entities.provider_configuration import ProviderConfiguration, SystemConfigurationStatus +from core.entities.provider_entities import ( + CustomConfiguration, + ModelSettings, + ProviderQuotaType, + QuotaConfiguration, + QuotaUnit, + RestrictModel, + SystemConfiguration, +) +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity +from models.provider import Provider, ProviderType + + +@pytest.fixture +def mock_provider_entity(): + """Mock provider entity with basic configuration""" + provider_entity = ProviderEntity( + provider="openai", + label=I18nObject(en_US="OpenAI", zh_Hans="OpenAI"), + description=I18nObject(en_US="OpenAI provider", zh_Hans="OpenAI 提供商"), + icon_small=I18nObject(en_US="icon.png", zh_Hans="icon.png"), + icon_large=I18nObject(en_US="icon.png", zh_Hans="icon.png"), + background="background.png", + help=None, + supported_model_types=[ModelType.LLM], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + provider_credential_schema=None, + model_credential_schema=None, + ) + + return provider_entity + + +@pytest.fixture +def mock_system_configuration(): + """Mock system configuration""" + quota_config = QuotaConfiguration( + quota_type=ProviderQuotaType.TRIAL, + quota_unit=QuotaUnit.TOKENS, + quota_limit=1000, + quota_used=0, + is_valid=True, + restrict_models=[RestrictModel(model="gpt-4", reason="Experimental", model_type=ModelType.LLM)], + ) + + system_config = SystemConfiguration( + enabled=True, + credentials={"openai_api_key": "test_key"}, + quota_configurations=[quota_config], + current_quota_type=ProviderQuotaType.TRIAL, + ) + + return system_config + + +@pytest.fixture +def mock_custom_configuration(): + """Mock custom configuration""" + custom_config = CustomConfiguration(provider=None, models=[]) + return custom_config + + +@pytest.fixture +def provider_configuration(mock_provider_entity, mock_system_configuration, mock_custom_configuration): + """Create a test provider configuration instance""" + with patch("core.entities.provider_configuration.original_provider_configurate_methods", {}): + return ProviderConfiguration( + tenant_id="test_tenant", + provider=mock_provider_entity, + preferred_provider_type=ProviderType.SYSTEM, + using_provider_type=ProviderType.SYSTEM, + system_configuration=mock_system_configuration, + custom_configuration=mock_custom_configuration, + model_settings=[], + ) + + +class TestProviderConfiguration: + """Test cases for ProviderConfiguration class""" + + def test_get_current_credentials_system_provider_success(self, provider_configuration): + """Test successfully getting credentials from system provider""" + # Arrange + provider_configuration.using_provider_type = ProviderType.SYSTEM + + # Act + credentials = provider_configuration.get_current_credentials(ModelType.LLM, "gpt-4") + + # Assert + assert credentials == {"openai_api_key": "test_key"} + + def test_get_current_credentials_model_disabled(self, provider_configuration): + """Test getting credentials when model is disabled""" + # Arrange + model_setting = ModelSettings( + model="gpt-4", + model_type=ModelType.LLM, + enabled=False, + load_balancing_configs=[], + has_invalid_load_balancing_configs=False, + ) + provider_configuration.model_settings = [model_setting] + + # Act & Assert + with pytest.raises(ValueError, match="Model gpt-4 is disabled"): + provider_configuration.get_current_credentials(ModelType.LLM, "gpt-4") + + def test_get_current_credentials_custom_provider_with_models(self, provider_configuration): + """Test getting credentials from custom provider with model configurations""" + # Arrange + provider_configuration.using_provider_type = ProviderType.CUSTOM + + mock_model_config = Mock() + mock_model_config.model_type = ModelType.LLM + mock_model_config.model = "gpt-4" + mock_model_config.credentials = {"openai_api_key": "custom_key"} + provider_configuration.custom_configuration.models = [mock_model_config] + + # Act + credentials = provider_configuration.get_current_credentials(ModelType.LLM, "gpt-4") + + # Assert + assert credentials == {"openai_api_key": "custom_key"} + + def test_get_system_configuration_status_active(self, provider_configuration): + """Test getting active system configuration status""" + # Arrange + provider_configuration.system_configuration.enabled = True + + # Act + status = provider_configuration.get_system_configuration_status() + + # Assert + assert status == SystemConfigurationStatus.ACTIVE + + def test_get_system_configuration_status_unsupported(self, provider_configuration): + """Test getting unsupported system configuration status""" + # Arrange + provider_configuration.system_configuration.enabled = False + + # Act + status = provider_configuration.get_system_configuration_status() + + # Assert + assert status == SystemConfigurationStatus.UNSUPPORTED + + def test_get_system_configuration_status_quota_exceeded(self, provider_configuration): + """Test getting quota exceeded system configuration status""" + # Arrange + provider_configuration.system_configuration.enabled = True + quota_config = provider_configuration.system_configuration.quota_configurations[0] + quota_config.is_valid = False + + # Act + status = provider_configuration.get_system_configuration_status() + + # Assert + assert status == SystemConfigurationStatus.QUOTA_EXCEEDED + + def test_is_custom_configuration_available_with_provider(self, provider_configuration): + """Test custom configuration availability with provider credentials""" + # Arrange + mock_provider = Mock() + mock_provider.available_credentials = ["openai_api_key"] + provider_configuration.custom_configuration.provider = mock_provider + provider_configuration.custom_configuration.models = [] + + # Act + result = provider_configuration.is_custom_configuration_available() + + # Assert + assert result is True + + def test_is_custom_configuration_available_with_models(self, provider_configuration): + """Test custom configuration availability with model configurations""" + # Arrange + provider_configuration.custom_configuration.provider = None + provider_configuration.custom_configuration.models = [Mock()] + + # Act + result = provider_configuration.is_custom_configuration_available() + + # Assert + assert result is True + + def test_is_custom_configuration_available_false(self, provider_configuration): + """Test custom configuration not available""" + # Arrange + provider_configuration.custom_configuration.provider = None + provider_configuration.custom_configuration.models = [] + + # Act + result = provider_configuration.is_custom_configuration_available() + + # Assert + assert result is False + + @patch("core.entities.provider_configuration.Session") + def test_get_provider_record_found(self, mock_session, provider_configuration): + """Test getting provider record successfully""" + # Arrange + mock_provider = Mock(spec=Provider) + mock_session_instance = Mock() + mock_session.return_value.__enter__.return_value = mock_session_instance + mock_session_instance.execute.return_value.scalar_one_or_none.return_value = mock_provider + + # Act + result = provider_configuration._get_provider_record(mock_session_instance) + + # Assert + assert result == mock_provider + + @patch("core.entities.provider_configuration.Session") + def test_get_provider_record_not_found(self, mock_session, provider_configuration): + """Test getting provider record when not found""" + # Arrange + mock_session_instance = Mock() + mock_session.return_value.__enter__.return_value = mock_session_instance + mock_session_instance.execute.return_value.scalar_one_or_none.return_value = None + + # Act + result = provider_configuration._get_provider_record(mock_session_instance) + + # Assert + assert result is None + + def test_init_with_customizable_model_only( + self, mock_provider_entity, mock_system_configuration, mock_custom_configuration + ): + """Test initialization with customizable model only configuration""" + # Arrange + mock_provider_entity.configurate_methods = [ConfigurateMethod.CUSTOMIZABLE_MODEL] + + # Act + with patch("core.entities.provider_configuration.original_provider_configurate_methods", {}): + config = ProviderConfiguration( + tenant_id="test_tenant", + provider=mock_provider_entity, + preferred_provider_type=ProviderType.SYSTEM, + using_provider_type=ProviderType.SYSTEM, + system_configuration=mock_system_configuration, + custom_configuration=mock_custom_configuration, + model_settings=[], + ) + + # Assert + assert ConfigurateMethod.PREDEFINED_MODEL in config.provider.configurate_methods + + def test_get_current_credentials_with_restricted_models(self, provider_configuration): + """Test getting credentials with model restrictions""" + # Arrange + provider_configuration.using_provider_type = ProviderType.SYSTEM + + # Act + credentials = provider_configuration.get_current_credentials(ModelType.LLM, "gpt-3.5-turbo") + + # Assert + assert credentials is not None + assert "openai_api_key" in credentials + + @patch("core.entities.provider_configuration.Session") + def test_get_specific_provider_credential_success(self, mock_session, provider_configuration): + """Test getting specific provider credential successfully""" + # Arrange + credential_id = "test_credential_id" + mock_credential = Mock() + mock_credential.encrypted_config = '{"openai_api_key": "encrypted_key"}' + + mock_session_instance = Mock() + mock_session.return_value.__enter__.return_value = mock_session_instance + mock_session_instance.execute.return_value.scalar_one_or_none.return_value = mock_credential + + # Act + with patch.object(provider_configuration, "_get_specific_provider_credential") as mock_get: + mock_get.return_value = {"openai_api_key": "test_key"} + result = provider_configuration._get_specific_provider_credential(credential_id) + + # Assert + assert result == {"openai_api_key": "test_key"} + + @patch("core.entities.provider_configuration.Session") + def test_get_specific_provider_credential_not_found(self, mock_session, provider_configuration): + """Test getting specific provider credential when not found""" + # Arrange + credential_id = "nonexistent_credential_id" + + mock_session_instance = Mock() + mock_session.return_value.__enter__.return_value = mock_session_instance + mock_session_instance.execute.return_value.scalar_one_or_none.return_value = None + + # Act & Assert + with patch.object(provider_configuration, "_get_specific_provider_credential") as mock_get: + mock_get.return_value = None + result = provider_configuration._get_specific_provider_credential(credential_id) + assert result is None + + # Act + credentials = provider_configuration.get_current_credentials(ModelType.LLM, "gpt-4") + + # Assert + assert credentials == {"openai_api_key": "test_key"} diff --git a/api/tests/unit_tests/core/test_provider_manager.py b/api/tests/unit_tests/core/test_provider_manager.py index 90d5a6f15b..2dab394029 100644 --- a/api/tests/unit_tests/core/test_provider_manager.py +++ b/api/tests/unit_tests/core/test_provider_manager.py @@ -1,190 +1,185 @@ -# from core.entities.provider_entities import ModelSettings -# from core.model_runtime.entities.model_entities import ModelType -# from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory -# from core.provider_manager import ProviderManager -# from models.provider import LoadBalancingModelConfig, ProviderModelSetting +import pytest + +from core.entities.provider_entities import ModelSettings +from core.model_runtime.entities.model_entities import ModelType +from core.provider_manager import ProviderManager +from models.provider import LoadBalancingModelConfig, ProviderModelSetting -# def test__to_model_settings(mocker): -# # Get all provider entities -# model_provider_factory = ModelProviderFactory("test_tenant") -# provider_entities = model_provider_factory.get_providers() +@pytest.fixture +def mock_provider_entity(mocker): + mock_entity = mocker.Mock() + mock_entity.provider = "openai" + mock_entity.configurate_methods = ["predefined-model"] + mock_entity.supported_model_types = [ModelType.LLM] -# provider_entity = None -# for provider in provider_entities: -# if provider.provider == "openai": -# provider_entity = provider + mock_entity.model_credential_schema = mocker.Mock() + mock_entity.model_credential_schema.credential_form_schemas = [] -# # Mocking the inputs -# provider_model_settings = [ -# ProviderModelSetting( -# id="id", -# tenant_id="tenant_id", -# provider_name="openai", -# model_name="gpt-4", -# model_type="text-generation", -# enabled=True, -# load_balancing_enabled=True, -# ) -# ] -# load_balancing_model_configs = [ -# LoadBalancingModelConfig( -# id="id1", -# tenant_id="tenant_id", -# provider_name="openai", -# model_name="gpt-4", -# model_type="text-generation", -# name="__inherit__", -# encrypted_config=None, -# enabled=True, -# ), -# LoadBalancingModelConfig( -# id="id2", -# tenant_id="tenant_id", -# provider_name="openai", -# model_name="gpt-4", -# model_type="text-generation", -# name="first", -# encrypted_config='{"openai_api_key": "fake_key"}', -# enabled=True, -# ), -# ] - -# mocker.patch( -# "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"} -# ) - -# provider_manager = ProviderManager() - -# # Running the method -# result = provider_manager._to_model_settings(provider_entity, -# provider_model_settings, load_balancing_model_configs) - -# # Asserting that the result is as expected -# assert len(result) == 1 -# assert isinstance(result[0], ModelSettings) -# assert result[0].model == "gpt-4" -# assert result[0].model_type == ModelType.LLM -# assert result[0].enabled is True -# assert len(result[0].load_balancing_configs) == 2 -# assert result[0].load_balancing_configs[0].name == "__inherit__" -# assert result[0].load_balancing_configs[1].name == "first" + return mock_entity -# def test__to_model_settings_only_one_lb(mocker): -# # Get all provider entities -# model_provider_factory = ModelProviderFactory("test_tenant") -# provider_entities = model_provider_factory.get_providers() +def test__to_model_settings(mocker, mock_provider_entity): + # Mocking the inputs + provider_model_settings = [ + ProviderModelSetting( + id="id", + tenant_id="tenant_id", + provider_name="openai", + model_name="gpt-4", + model_type="text-generation", + enabled=True, + load_balancing_enabled=True, + ) + ] + load_balancing_model_configs = [ + LoadBalancingModelConfig( + id="id1", + tenant_id="tenant_id", + provider_name="openai", + model_name="gpt-4", + model_type="text-generation", + name="__inherit__", + encrypted_config=None, + enabled=True, + ), + LoadBalancingModelConfig( + id="id2", + tenant_id="tenant_id", + provider_name="openai", + model_name="gpt-4", + model_type="text-generation", + name="first", + encrypted_config='{"openai_api_key": "fake_key"}', + enabled=True, + ), + ] -# provider_entity = None -# for provider in provider_entities: -# if provider.provider == "openai": -# provider_entity = provider + mocker.patch( + "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"} + ) -# # Mocking the inputs -# provider_model_settings = [ -# ProviderModelSetting( -# id="id", -# tenant_id="tenant_id", -# provider_name="openai", -# model_name="gpt-4", -# model_type="text-generation", -# enabled=True, -# load_balancing_enabled=True, -# ) -# ] -# load_balancing_model_configs = [ -# LoadBalancingModelConfig( -# id="id1", -# tenant_id="tenant_id", -# provider_name="openai", -# model_name="gpt-4", -# model_type="text-generation", -# name="__inherit__", -# encrypted_config=None, -# enabled=True, -# ) -# ] + provider_manager = ProviderManager() -# mocker.patch( -# "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"} -# ) + # Running the method + result = provider_manager._to_model_settings( + provider_entity=mock_provider_entity, + provider_model_settings=provider_model_settings, + load_balancing_model_configs=load_balancing_model_configs, + ) -# provider_manager = ProviderManager() - -# # Running the method -# result = provider_manager._to_model_settings( -# provider_entity, provider_model_settings, load_balancing_model_configs) - -# # Asserting that the result is as expected -# assert len(result) == 1 -# assert isinstance(result[0], ModelSettings) -# assert result[0].model == "gpt-4" -# assert result[0].model_type == ModelType.LLM -# assert result[0].enabled is True -# assert len(result[0].load_balancing_configs) == 0 + # Asserting that the result is as expected + assert len(result) == 1 + assert isinstance(result[0], ModelSettings) + assert result[0].model == "gpt-4" + assert result[0].model_type == ModelType.LLM + assert result[0].enabled is True + assert len(result[0].load_balancing_configs) == 2 + assert result[0].load_balancing_configs[0].name == "__inherit__" + assert result[0].load_balancing_configs[1].name == "first" -# def test__to_model_settings_lb_disabled(mocker): -# # Get all provider entities -# model_provider_factory = ModelProviderFactory("test_tenant") -# provider_entities = model_provider_factory.get_providers() +def test__to_model_settings_only_one_lb(mocker, mock_provider_entity): + # Mocking the inputs + provider_model_settings = [ + ProviderModelSetting( + id="id", + tenant_id="tenant_id", + provider_name="openai", + model_name="gpt-4", + model_type="text-generation", + enabled=True, + load_balancing_enabled=True, + ) + ] + load_balancing_model_configs = [ + LoadBalancingModelConfig( + id="id1", + tenant_id="tenant_id", + provider_name="openai", + model_name="gpt-4", + model_type="text-generation", + name="__inherit__", + encrypted_config=None, + enabled=True, + ) + ] -# provider_entity = None -# for provider in provider_entities: -# if provider.provider == "openai": -# provider_entity = provider + mocker.patch( + "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"} + ) -# # Mocking the inputs -# provider_model_settings = [ -# ProviderModelSetting( -# id="id", -# tenant_id="tenant_id", -# provider_name="openai", -# model_name="gpt-4", -# model_type="text-generation", -# enabled=True, -# load_balancing_enabled=False, -# ) -# ] -# load_balancing_model_configs = [ -# LoadBalancingModelConfig( -# id="id1", -# tenant_id="tenant_id", -# provider_name="openai", -# model_name="gpt-4", -# model_type="text-generation", -# name="__inherit__", -# encrypted_config=None, -# enabled=True, -# ), -# LoadBalancingModelConfig( -# id="id2", -# tenant_id="tenant_id", -# provider_name="openai", -# model_name="gpt-4", -# model_type="text-generation", -# name="first", -# encrypted_config='{"openai_api_key": "fake_key"}', -# enabled=True, -# ), -# ] + provider_manager = ProviderManager() -# mocker.patch( -# "core.helper.model_provider_cache.ProviderCredentialsCache.get", -# return_value={"openai_api_key": "fake_key"} -# ) + # Running the method + result = provider_manager._to_model_settings( + provider_entity=mock_provider_entity, + provider_model_settings=provider_model_settings, + load_balancing_model_configs=load_balancing_model_configs, + ) -# provider_manager = ProviderManager() + # Asserting that the result is as expected + assert len(result) == 1 + assert isinstance(result[0], ModelSettings) + assert result[0].model == "gpt-4" + assert result[0].model_type == ModelType.LLM + assert result[0].enabled is True + assert len(result[0].load_balancing_configs) == 0 -# # Running the method -# result = provider_manager._to_model_settings(provider_entity, -# provider_model_settings, load_balancing_model_configs) -# # Asserting that the result is as expected -# assert len(result) == 1 -# assert isinstance(result[0], ModelSettings) -# assert result[0].model == "gpt-4" -# assert result[0].model_type == ModelType.LLM -# assert result[0].enabled is True -# assert len(result[0].load_balancing_configs) == 0 +def test__to_model_settings_lb_disabled(mocker, mock_provider_entity): + # Mocking the inputs + provider_model_settings = [ + ProviderModelSetting( + id="id", + tenant_id="tenant_id", + provider_name="openai", + model_name="gpt-4", + model_type="text-generation", + enabled=True, + load_balancing_enabled=False, + ) + ] + load_balancing_model_configs = [ + LoadBalancingModelConfig( + id="id1", + tenant_id="tenant_id", + provider_name="openai", + model_name="gpt-4", + model_type="text-generation", + name="__inherit__", + encrypted_config=None, + enabled=True, + ), + LoadBalancingModelConfig( + id="id2", + tenant_id="tenant_id", + provider_name="openai", + model_name="gpt-4", + model_type="text-generation", + name="first", + encrypted_config='{"openai_api_key": "fake_key"}', + enabled=True, + ), + ] + + mocker.patch( + "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"} + ) + + provider_manager = ProviderManager() + + # Running the method + result = provider_manager._to_model_settings( + provider_entity=mock_provider_entity, + provider_model_settings=provider_model_settings, + load_balancing_model_configs=load_balancing_model_configs, + ) + + # Asserting that the result is as expected + assert len(result) == 1 + assert isinstance(result[0], ModelSettings) + assert result[0].model == "gpt-4" + assert result[0].model_type == ModelType.LLM + assert result[0].enabled is True + assert len(result[0].load_balancing_configs) == 0 diff --git a/dev/start-worker b/dev/start-worker index 66e446c831..a2af04c01c 100755 --- a/dev/start-worker +++ b/dev/start-worker @@ -8,4 +8,4 @@ cd "$SCRIPT_DIR/.." uv --directory api run \ celery -A app.celery worker \ - -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage + -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation diff --git a/docker/.env.example b/docker/.env.example index 711898016e..c6ed2acb35 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -1250,6 +1250,10 @@ QUEUE_MONITOR_ALERT_EMAILS= # Monitor interval in minutes, default is 30 minutes QUEUE_MONITOR_INTERVAL=30 +# Swagger UI configuration +SWAGGER_UI_ENABLED=true +SWAGGER_UI_PATH=/swagger-ui.html + # Celery schedule tasks configuration ENABLE_CLEAN_EMBEDDING_CACHE_TASK=false ENABLE_CLEAN_UNUSED_DATASETS_TASK=false diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index d3b75d93af..0b9de5fc43 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -566,6 +566,8 @@ x-shared-env: &shared-api-worker-env QUEUE_MONITOR_THRESHOLD: ${QUEUE_MONITOR_THRESHOLD:-200} QUEUE_MONITOR_ALERT_EMAILS: ${QUEUE_MONITOR_ALERT_EMAILS:-} QUEUE_MONITOR_INTERVAL: ${QUEUE_MONITOR_INTERVAL:-30} + SWAGGER_UI_ENABLED: ${SWAGGER_UI_ENABLED:-true} + SWAGGER_UI_PATH: ${SWAGGER_UI_PATH:-/swagger-ui.html} ENABLE_CLEAN_EMBEDDING_CACHE_TASK: ${ENABLE_CLEAN_EMBEDDING_CACHE_TASK:-false} ENABLE_CLEAN_UNUSED_DATASETS_TASK: ${ENABLE_CLEAN_UNUSED_DATASETS_TASK:-false} ENABLE_CREATE_TIDB_SERVERLESS_TASK: ${ENABLE_CREATE_TIDB_SERVERLESS_TASK:-false} diff --git a/web/Dockerfile b/web/Dockerfile index d284efca87..1376dec749 100644 --- a/web/Dockerfile +++ b/web/Dockerfile @@ -34,7 +34,7 @@ COPY --from=packages /app/web/ . COPY . . ENV NODE_OPTIONS="--max-old-space-size=4096" -RUN pnpm build +RUN pnpm build:docker # production stage diff --git a/web/app/account/account-page/AvatarWithEdit.tsx b/web/app/account/account-page/AvatarWithEdit.tsx index 88e3a7b343..0408d2ee34 100644 --- a/web/app/account/account-page/AvatarWithEdit.tsx +++ b/web/app/account/account-page/AvatarWithEdit.tsx @@ -30,6 +30,8 @@ const AvatarWithEdit = ({ onSave, ...props }: AvatarWithEditProps) => { const [isShowDeleteConfirm, setIsShowDeleteConfirm] = useState(false) const [hoverArea, setHoverArea] = useState('left') + const [onAvatarError, setOnAvatarError] = useState(false) + const handleImageInput: OnImageInput = useCallback(async (isCropped: boolean, fileOrTempUrl: string | File, croppedAreaPixels?: Area, fileName?: string) => { setInputImageInfo( isCropped @@ -98,10 +100,15 @@ const AvatarWithEdit = ({ onSave, ...props }: AvatarWithEditProps) => { <>
- + setOnAvatarError(x)} />
hoverArea === 'right' ? setIsShowDeleteConfirm(true) : setIsShowAvatarPicker(true)} + onClick={() => { + if (hoverArea === 'right' && !onAvatarError) + setIsShowDeleteConfirm(true) + else + setIsShowAvatarPicker(true) + }} onMouseMove={(e) => { const rect = e.currentTarget.getBoundingClientRect() const x = e.clientX - rect.left @@ -109,12 +116,15 @@ const AvatarWithEdit = ({ onSave, ...props }: AvatarWithEditProps) => { setHoverArea(isRight ? 'right' : 'left') }} > - {hoverArea === 'right' ? - - : - - } - + {hoverArea === 'right' && !onAvatarError ? ( + + + + ) : ( + + + + )}
diff --git a/web/app/components/base/app-icon-picker/index.tsx b/web/app/components/base/app-icon-picker/index.tsx index 8e66cd38cf..bc5f09c7a7 100644 --- a/web/app/components/base/app-icon-picker/index.tsx +++ b/web/app/components/base/app-icon-picker/index.tsx @@ -120,7 +120,7 @@ const AppIconPicker: FC = ({ - void onManageMetadata: () => void + statusFilter: Item + onStatusFilterChange: (filter: string) => void } /** @@ -469,6 +443,7 @@ const DocumentList: FC = ({ pagination, onUpdate, onManageMetadata, + statusFilter, }) => { const { t } = useTranslation() const { formatTime } = useTimestamp() @@ -480,6 +455,7 @@ const DocumentList: FC = ({ const [localDocs, setLocalDocs] = useState(documents) const [sortField, setSortField] = useState<'name' | 'word_count' | 'hit_count' | 'created_at' | null>('created_at') const [sortOrder, setSortOrder] = useState<'asc' | 'desc'>('desc') + const { isShowEditModal, showEditModal, @@ -494,12 +470,22 @@ const DocumentList: FC = ({ }) useEffect(() => { + let filteredDocs = documents + + if (statusFilter.value !== 'all') { + filteredDocs = filteredDocs.filter(doc => + typeof doc.display_status === 'string' + && typeof statusFilter.value === 'string' + && doc.display_status.toLowerCase() === statusFilter.value.toLowerCase(), + ) + } + if (!sortField) { - setLocalDocs(documents) + setLocalDocs(filteredDocs) return } - const sortedDocs = [...documents].sort((a, b) => { + const sortedDocs = [...filteredDocs].sort((a, b) => { let aValue: any let bValue: any @@ -535,7 +521,7 @@ const DocumentList: FC = ({ }) setLocalDocs(sortedDocs) - }, [documents, sortField, sortOrder]) + }, [documents, sortField, sortOrder, statusFilter]) const handleSort = (field: 'name' | 'word_count' | 'hit_count' | 'created_at') => { if (sortField === field) { @@ -692,7 +678,11 @@ const DocumentList: FC = ({ {doc?.data_source_type === DataSourceType.FILE && } {doc?.data_source_type === DataSourceType.WEB && } - {doc.name} + + {doc.name} +
-
-
+
+
diff --git a/web/app/components/header/account-setting/model-provider-page/declarations.ts b/web/app/components/header/account-setting/model-provider-page/declarations.ts index 1f5ced612c..74f47c9d1d 100644 --- a/web/app/components/header/account-setting/model-provider-page/declarations.ts +++ b/web/app/components/header/account-setting/model-provider-page/declarations.ts @@ -86,6 +86,7 @@ export enum ModelStatusEnum { quotaExceeded = 'quota-exceeded', noPermission = 'no-permission', disabled = 'disabled', + credentialRemoved = 'credential-removed', } export const MODEL_STATUS_TEXT: { [k: string]: TypeWithI18N } = { @@ -153,6 +154,7 @@ export type ModelItem = { model_properties: Record load_balancing_enabled: boolean deprecated?: boolean + has_invalid_load_balancing_configs?: boolean } export enum PreferredProviderTypeEnum { @@ -181,6 +183,29 @@ export type QuotaConfiguration = { is_valid: boolean } +export type Credential = { + credential_id: string + credential_name?: string + from_enterprise?: boolean + not_allowed_to_use?: boolean +} + +export type CustomModel = { + model: string + model_type: ModelTypeEnum +} + +export type CustomModelCredential = CustomModel & { + credentials?: Record + available_model_credentials?: Credential[] + current_credential_id?: string +} + +export type CredentialWithModel = Credential & { + model: string + model_type: ModelTypeEnum +} + export type ModelProvider = { provider: string label: TypeWithI18N @@ -207,12 +232,17 @@ export type ModelProvider = { preferred_provider_type: PreferredProviderTypeEnum custom_configuration: { status: CustomConfigurationStatusEnum + current_credential_id?: string + current_credential_name?: string + available_credentials?: Credential[] + custom_models?: CustomModelCredential[] } system_configuration: { enabled: boolean current_quota_type: CurrentSystemQuotaTypeEnum quota_configurations: QuotaConfiguration[] } + allow_custom_token?: boolean } export type Model = { @@ -272,9 +302,24 @@ export type ModelLoadBalancingConfigEntry = { in_cooldown?: boolean /** cooldown time (in seconds) */ ttl?: number + credential_id?: string } export type ModelLoadBalancingConfig = { enabled: boolean configs: ModelLoadBalancingConfigEntry[] } + +export type ProviderCredential = { + credentials: Record + name: string + credential_id: string +} + +export type ModelCredential = { + credentials: Record + load_balancing: ModelLoadBalancingConfig + available_credentials: Credential[] + current_credential_id?: string + current_credential_name?: string +} diff --git a/web/app/components/header/account-setting/model-provider-page/hooks.ts b/web/app/components/header/account-setting/model-provider-page/hooks.ts index 48acaeb64a..fa5130137a 100644 --- a/web/app/components/header/account-setting/model-provider-page/hooks.ts +++ b/web/app/components/header/account-setting/model-provider-page/hooks.ts @@ -7,7 +7,9 @@ import { import useSWR, { useSWRConfig } from 'swr' import { useContext } from 'use-context-selector' import type { + Credential, CustomConfigurationModelFixedFields, + CustomModel, DefaultModel, DefaultModelResponse, Model, @@ -77,16 +79,17 @@ export const useProviderCredentialsAndLoadBalancing = ( configurationMethod: ConfigurationMethodEnum, configured?: boolean, currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields, + credentialId?: string, ) => { - const { data: predefinedFormSchemasValue, mutate: mutatePredefined } = useSWR( - (configurationMethod === ConfigurationMethodEnum.predefinedModel && configured) - ? `/workspaces/current/model-providers/${provider}/credentials` + const { data: predefinedFormSchemasValue, mutate: mutatePredefined, isLoading: isPredefinedLoading } = useSWR( + (configurationMethod === ConfigurationMethodEnum.predefinedModel && configured && credentialId) + ? `/workspaces/current/model-providers/${provider}/credentials${credentialId ? `?credential_id=${credentialId}` : ''}` : null, fetchModelProviderCredentials, ) - const { data: customFormSchemasValue, mutate: mutateCustomized } = useSWR( - (configurationMethod === ConfigurationMethodEnum.customizableModel && currentCustomConfigurationModelFixedFields) - ? `/workspaces/current/model-providers/${provider}/models/credentials?model=${currentCustomConfigurationModelFixedFields?.__model_name}&model_type=${currentCustomConfigurationModelFixedFields?.__model_type}` + const { data: customFormSchemasValue, mutate: mutateCustomized, isLoading: isCustomizedLoading } = useSWR( + (configurationMethod === ConfigurationMethodEnum.customizableModel && currentCustomConfigurationModelFixedFields && credentialId) + ? `/workspaces/current/model-providers/${provider}/models/credentials?model=${currentCustomConfigurationModelFixedFields?.__model_name}&model_type=${currentCustomConfigurationModelFixedFields?.__model_type}${credentialId ? `&credential_id=${credentialId}` : ''}` : null, fetchModelProviderCredentials, ) @@ -102,6 +105,7 @@ export const useProviderCredentialsAndLoadBalancing = ( : undefined }, [ configurationMethod, + credentialId, currentCustomConfigurationModelFixedFields, customFormSchemasValue?.credentials, predefinedFormSchemasValue?.credentials, @@ -119,6 +123,7 @@ export const useProviderCredentialsAndLoadBalancing = ( : customFormSchemasValue )?.load_balancing, mutate, + isLoading: isPredefinedLoading || isCustomizedLoading, } // as ([Record | undefined, ModelLoadBalancingConfig | undefined]) } @@ -313,40 +318,59 @@ export const useMarketplaceAllPlugins = (providers: ModelProvider[], searchText: } } -export const useModelModalHandler = () => { - const setShowModelModal = useModalContextSelector(state => state.setShowModelModal) +export const useRefreshModel = () => { + const { eventEmitter } = useEventEmitterContextContext() const updateModelProviders = useUpdateModelProviders() const updateModelList = useUpdateModelList() - const { eventEmitter } = useEventEmitterContextContext() + const handleRefreshModel = useCallback((provider: ModelProvider, configurationMethod: ConfigurationMethodEnum, CustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields) => { + updateModelProviders() + + provider.supported_model_types.forEach((type) => { + updateModelList(type) + }) + + if (configurationMethod === ConfigurationMethodEnum.customizableModel + && provider.custom_configuration.status === CustomConfigurationStatusEnum.active) { + eventEmitter?.emit({ + type: UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST, + payload: provider.provider, + } as any) + + if (CustomConfigurationModelFixedFields?.__model_type) + updateModelList(CustomConfigurationModelFixedFields.__model_type) + } + }, [eventEmitter, updateModelList, updateModelProviders]) + + return { + handleRefreshModel, + } +} + +export const useModelModalHandler = () => { + const setShowModelModal = useModalContextSelector(state => state.setShowModelModal) + const { handleRefreshModel } = useRefreshModel() return ( provider: ModelProvider, configurationMethod: ConfigurationMethodEnum, CustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields, + isModelCredential?: boolean, + credential?: Credential, + model?: CustomModel, + onUpdate?: () => void, ) => { setShowModelModal({ payload: { currentProvider: provider, currentConfigurationMethod: configurationMethod, currentCustomConfigurationModelFixedFields: CustomConfigurationModelFixedFields, + isModelCredential, + credential, + model, }, onSaveCallback: () => { - updateModelProviders() - - provider.supported_model_types.forEach((type) => { - updateModelList(type) - }) - - if (configurationMethod === ConfigurationMethodEnum.customizableModel - && provider.custom_configuration.status === CustomConfigurationStatusEnum.active) { - eventEmitter?.emit({ - type: UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST, - payload: provider.provider, - } as any) - - if (CustomConfigurationModelFixedFields?.__model_type) - updateModelList(CustomConfigurationModelFixedFields.__model_type) - } + handleRefreshModel(provider, configurationMethod, CustomConfigurationModelFixedFields) + onUpdate?.() }, }) } diff --git a/web/app/components/header/account-setting/model-provider-page/index.tsx b/web/app/components/header/account-setting/model-provider-page/index.tsx index 4aa98daf66..35de29185f 100644 --- a/web/app/components/header/account-setting/model-provider-page/index.tsx +++ b/web/app/components/header/account-setting/model-provider-page/index.tsx @@ -8,8 +8,6 @@ import { import SystemModelSelector from './system-model-selector' import ProviderAddedCard from './provider-added-card' import type { - ConfigurationMethodEnum, - CustomConfigurationModelFixedFields, ModelProvider, } from './declarations' import { @@ -18,7 +16,6 @@ import { } from './declarations' import { useDefaultModel, - useModelModalHandler, } from './hooks' import InstallFromMarketplace from './install-from-marketplace' import { useProviderContext } from '@/context/provider-context' @@ -84,8 +81,6 @@ const ModelProviderPage = ({ searchText }: Props) => { return [filteredConfiguredProviders, filteredNotConfiguredProviders] }, [configuredProviders, debouncedSearchText, notConfiguredProviders]) - const handleOpenModal = useModelModalHandler() - return (
@@ -126,7 +121,6 @@ const ModelProviderPage = ({ searchText }: Props) => { handleOpenModal(provider, configurationMethod, currentCustomConfigurationModelFixedFields)} /> ))}
@@ -140,7 +134,6 @@ const ModelProviderPage = ({ searchText }: Props) => { notConfigured key={provider.provider} provider={provider} - onOpenModal={(configurationMethod: ConfigurationMethodEnum, currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields) => handleOpenModal(provider, configurationMethod, currentCustomConfigurationModelFixedFields)} /> ))}
diff --git a/web/app/components/header/account-setting/model-provider-page/model-auth/add-credential-in-load-balancing.tsx b/web/app/components/header/account-setting/model-provider-page/model-auth/add-credential-in-load-balancing.tsx new file mode 100644 index 0000000000..64e631614d --- /dev/null +++ b/web/app/components/header/account-setting/model-provider-page/model-auth/add-credential-in-load-balancing.tsx @@ -0,0 +1,115 @@ +import { + memo, + useCallback, + useMemo, +} from 'react' +import { RiAddLine } from '@remixicon/react' +import { useTranslation } from 'react-i18next' +import { Authorized } from '@/app/components/header/account-setting/model-provider-page/model-auth' +import cn from '@/utils/classnames' +import type { + Credential, + CustomModelCredential, + ModelCredential, + ModelProvider, +} from '@/app/components/header/account-setting/model-provider-page/declarations' +import { ConfigurationMethodEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' +import Tooltip from '@/app/components/base/tooltip' + +type AddCredentialInLoadBalancingProps = { + provider: ModelProvider + model: CustomModelCredential + configurationMethod: ConfigurationMethodEnum + modelCredential: ModelCredential + onSelectCredential: (credential: Credential) => void + onUpdate?: () => void +} +const AddCredentialInLoadBalancing = ({ + provider, + model, + configurationMethod, + modelCredential, + onSelectCredential, + onUpdate, +}: AddCredentialInLoadBalancingProps) => { + const { t } = useTranslation() + const { + available_credentials, + } = modelCredential + const customModel = configurationMethod === ConfigurationMethodEnum.customizableModel + const notAllowCustomCredential = provider.allow_custom_token === false + + const ButtonComponent = useMemo(() => { + const Item = ( +
+ + { + customModel + ? t('common.modelProvider.auth.addCredential') + : t('common.modelProvider.auth.addApiKey') + } +
+ ) + + if (notAllowCustomCredential) { + return ( + + {Item} + + ) + } + return Item + }, [notAllowCustomCredential, t, customModel]) + + const renderTrigger = useCallback((open?: boolean) => { + const Item = ( +
+ + { + customModel + ? t('common.modelProvider.auth.addCredential') + : t('common.modelProvider.auth.addApiKey') + } +
+ ) + + return Item + }, [t, customModel]) + + if (!available_credentials?.length) + return ButtonComponent + + return ( + + ) +} + +export default memo(AddCredentialInLoadBalancing) diff --git a/web/app/components/header/account-setting/model-provider-page/model-auth/add-custom-model.tsx b/web/app/components/header/account-setting/model-provider-page/model-auth/add-custom-model.tsx new file mode 100644 index 0000000000..0ec6fa45a0 --- /dev/null +++ b/web/app/components/header/account-setting/model-provider-page/model-auth/add-custom-model.tsx @@ -0,0 +1,111 @@ +import { + memo, + useCallback, + useMemo, +} from 'react' +import { useTranslation } from 'react-i18next' +import { + RiAddCircleFill, +} from '@remixicon/react' +import { + Button, +} from '@/app/components/base/button' +import type { + CustomConfigurationModelFixedFields, + ModelProvider, +} from '@/app/components/header/account-setting/model-provider-page/declarations' +import { ConfigurationMethodEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' +import Authorized from './authorized' +import { + useAuth, + useCustomModels, +} from './hooks' +import cn from '@/utils/classnames' +import Tooltip from '@/app/components/base/tooltip' + +type AddCustomModelProps = { + provider: ModelProvider, + configurationMethod: ConfigurationMethodEnum, + currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields, +} +const AddCustomModel = ({ + provider, + configurationMethod, + currentCustomConfigurationModelFixedFields, +}: AddCustomModelProps) => { + const { t } = useTranslation() + const customModels = useCustomModels(provider) + const noModels = !customModels.length + const { + handleOpenModal, + } = useAuth(provider, configurationMethod, currentCustomConfigurationModelFixedFields, true) + const notAllowCustomCredential = provider.allow_custom_token === false + const handleClick = useCallback(() => { + if (notAllowCustomCredential) + return + + handleOpenModal() + }, [handleOpenModal, notAllowCustomCredential]) + const ButtonComponent = useMemo(() => { + const Item = ( + + ) + if (notAllowCustomCredential) { + return ( + + {Item} + + ) + } + return Item + }, [handleClick, notAllowCustomCredential, t]) + + const renderTrigger = useCallback((open?: boolean) => { + const Item = ( + + ) + return Item + }, [t]) + + if (noModels) + return ButtonComponent + + return ( + ({ + model, + credentials: model.available_model_credentials ?? [], + }))} + renderTrigger={renderTrigger} + isModelCredential + enableAddModelCredential + bottomAddModelCredentialText={t('common.modelProvider.auth.addNewModel')} + /> + ) +} + +export default memo(AddCustomModel) diff --git a/web/app/components/header/account-setting/model-provider-page/model-auth/authorized/authorized-item.tsx b/web/app/components/header/account-setting/model-provider-page/model-auth/authorized/authorized-item.tsx new file mode 100644 index 0000000000..4f4c30bc9b --- /dev/null +++ b/web/app/components/header/account-setting/model-provider-page/model-auth/authorized/authorized-item.tsx @@ -0,0 +1,101 @@ +import { + memo, + useCallback, +} from 'react' +import { RiAddLine } from '@remixicon/react' +import { useTranslation } from 'react-i18next' +import CredentialItem from './credential-item' +import type { + Credential, + CustomModel, + CustomModelCredential, +} from '../../declarations' +import Button from '@/app/components/base/button' +import Tooltip from '@/app/components/base/tooltip' + +type AuthorizedItemProps = { + model?: CustomModelCredential + title?: string + disabled?: boolean + onDelete?: (credential?: Credential, model?: CustomModel) => void + onEdit?: (credential?: Credential, model?: CustomModel) => void + showItemSelectedIcon?: boolean + selectedCredentialId?: string + credentials: Credential[] + onItemClick?: (credential: Credential, model?: CustomModel) => void + enableAddModelCredential?: boolean + notAllowCustomCredential?: boolean +} +export const AuthorizedItem = ({ + model, + title, + credentials, + disabled, + onDelete, + onEdit, + showItemSelectedIcon, + selectedCredentialId, + onItemClick, + enableAddModelCredential, + notAllowCustomCredential, +}: AuthorizedItemProps) => { + const { t } = useTranslation() + const handleEdit = useCallback((credential?: Credential) => { + onEdit?.(credential, model) + }, [onEdit, model]) + const handleDelete = useCallback((credential?: Credential) => { + onDelete?.(credential, model) + }, [onDelete, model]) + const handleItemClick = useCallback((credential: Credential) => { + onItemClick?.(credential, model) + }, [onItemClick, model]) + + return ( +
+
+
+
+ {title ?? model?.model} +
+ { + enableAddModelCredential && !notAllowCustomCredential && ( + + + + ) + } +
+ { + credentials.map(credential => ( + + )) + } +
+ ) +} + +export default memo(AuthorizedItem) diff --git a/web/app/components/header/account-setting/model-provider-page/model-auth/authorized/credential-item.tsx b/web/app/components/header/account-setting/model-provider-page/model-auth/authorized/credential-item.tsx new file mode 100644 index 0000000000..6596e64e0d --- /dev/null +++ b/web/app/components/header/account-setting/model-provider-page/model-auth/authorized/credential-item.tsx @@ -0,0 +1,137 @@ +import { + memo, + useMemo, +} from 'react' +import { useTranslation } from 'react-i18next' +import { + RiCheckLine, + RiDeleteBinLine, + RiEqualizer2Line, +} from '@remixicon/react' +import Indicator from '@/app/components/header/indicator' +import ActionButton from '@/app/components/base/action-button' +import Tooltip from '@/app/components/base/tooltip' +import cn from '@/utils/classnames' +import type { Credential } from '../../declarations' +import Badge from '@/app/components/base/badge' + +type CredentialItemProps = { + credential: Credential + disabled?: boolean + onDelete?: (credential: Credential) => void + onEdit?: (credential?: Credential) => void + onItemClick?: (credential: Credential) => void + disableRename?: boolean + disableEdit?: boolean + disableDelete?: boolean + showSelectedIcon?: boolean + selectedCredentialId?: string +} +const CredentialItem = ({ + credential, + disabled, + onDelete, + onEdit, + onItemClick, + disableRename, + disableEdit, + disableDelete, + showSelectedIcon, + selectedCredentialId, +}: CredentialItemProps) => { + const { t } = useTranslation() + const showAction = useMemo(() => { + return !(disableRename && disableEdit && disableDelete) + }, [disableRename, disableEdit, disableDelete]) + + const Item = ( +
{ + if (disabled || credential.not_allowed_to_use) + return + onItemClick?.(credential) + }} + > +
+ { + showSelectedIcon && ( +
+ { + selectedCredentialId === credential.credential_id && ( + + ) + } +
+ ) + } + +
+ {credential.credential_name} +
+
+ { + credential.from_enterprise && ( + + Enterprise + + ) + } + { + showAction && ( +
+ { + !disableEdit && !credential.not_allowed_to_use && !credential.from_enterprise && ( + + { + e.stopPropagation() + onEdit?.(credential) + }} + > + + + + ) + } + { + !disableDelete && !credential.from_enterprise && ( + + { + e.stopPropagation() + onDelete?.(credential) + }} + > + + + + ) + } +
+ ) + } +
+ ) + + if (credential.not_allowed_to_use) { + return ( + + {Item} + + ) + } + return Item +} + +export default memo(CredentialItem) diff --git a/web/app/components/header/account-setting/model-provider-page/model-auth/authorized/index.tsx b/web/app/components/header/account-setting/model-provider-page/model-auth/authorized/index.tsx new file mode 100644 index 0000000000..3e7c04a0f2 --- /dev/null +++ b/web/app/components/header/account-setting/model-provider-page/model-auth/authorized/index.tsx @@ -0,0 +1,222 @@ +import { + memo, + useCallback, + useMemo, + useState, +} from 'react' +import { + RiAddLine, + RiEqualizer2Line, +} from '@remixicon/react' +import { useTranslation } from 'react-i18next' +import { + PortalToFollowElem, + PortalToFollowElemContent, + PortalToFollowElemTrigger, +} from '@/app/components/base/portal-to-follow-elem' +import type { + PortalToFollowElemOptions, +} from '@/app/components/base/portal-to-follow-elem' +import Button from '@/app/components/base/button' +import cn from '@/utils/classnames' +import Confirm from '@/app/components/base/confirm' +import type { + ConfigurationMethodEnum, + Credential, + CustomConfigurationModelFixedFields, + CustomModel, + ModelProvider, +} from '../../declarations' +import { useAuth } from '../hooks' +import AuthorizedItem from './authorized-item' + +type AuthorizedProps = { + provider: ModelProvider, + configurationMethod: ConfigurationMethodEnum, + currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields, + isModelCredential?: boolean + items: { + title?: string + model?: CustomModel + credentials: Credential[] + }[] + selectedCredential?: Credential + disabled?: boolean + renderTrigger?: (open?: boolean) => React.ReactNode + isOpen?: boolean + onOpenChange?: (open: boolean) => void + offset?: PortalToFollowElemOptions['offset'] + placement?: PortalToFollowElemOptions['placement'] + triggerPopupSameWidth?: boolean + popupClassName?: string + showItemSelectedIcon?: boolean + onUpdate?: () => void + onItemClick?: (credential: Credential, model?: CustomModel) => void + enableAddModelCredential?: boolean + bottomAddModelCredentialText?: string +} +const Authorized = ({ + provider, + configurationMethod, + currentCustomConfigurationModelFixedFields, + items, + isModelCredential, + selectedCredential, + disabled, + renderTrigger, + isOpen, + onOpenChange, + offset = 8, + placement = 'bottom-end', + triggerPopupSameWidth = false, + popupClassName, + showItemSelectedIcon, + onUpdate, + onItemClick, + enableAddModelCredential, + bottomAddModelCredentialText, +}: AuthorizedProps) => { + const { t } = useTranslation() + const [isLocalOpen, setIsLocalOpen] = useState(false) + const mergedIsOpen = isOpen ?? isLocalOpen + const setMergedIsOpen = useCallback((open: boolean) => { + if (onOpenChange) + onOpenChange(open) + + setIsLocalOpen(open) + }, [onOpenChange]) + const { + openConfirmDelete, + closeConfirmDelete, + doingAction, + handleActiveCredential, + handleConfirmDelete, + deleteCredentialId, + handleOpenModal, + } = useAuth(provider, configurationMethod, currentCustomConfigurationModelFixedFields, isModelCredential, onUpdate) + + const handleEdit = useCallback((credential?: Credential, model?: CustomModel) => { + handleOpenModal(credential, model) + setMergedIsOpen(false) + }, [handleOpenModal, setMergedIsOpen]) + + const handleItemClick = useCallback((credential: Credential, model?: CustomModel) => { + if (onItemClick) + onItemClick(credential, model) + else + handleActiveCredential(credential, model) + + setMergedIsOpen(false) + }, [handleActiveCredential, onItemClick, setMergedIsOpen]) + const notAllowCustomCredential = provider.allow_custom_token === false + + const Trigger = useMemo(() => { + const Item = ( + + ) + return Item + }, [t]) + + return ( + <> + + { + setMergedIsOpen(!mergedIsOpen) + }} + asChild + > + { + renderTrigger + ? renderTrigger(mergedIsOpen) + : Trigger + } + + +
+
+ { + items.map((item, index) => ( + + )) + } +
+
+ { + isModelCredential && !notAllowCustomCredential && ( +
handleEdit( + undefined, + currentCustomConfigurationModelFixedFields + ? { + model: currentCustomConfigurationModelFixedFields.__model_name, + model_type: currentCustomConfigurationModelFixedFields.__model_type, + } + : undefined, + )} + className='system-xs-medium flex h-[30px] cursor-pointer items-center px-3 text-text-accent-light-mode-only' + > + + {bottomAddModelCredentialText ?? t('common.modelProvider.auth.addModelCredential')} +
+ ) + } + { + !isModelCredential && !notAllowCustomCredential && ( +
+ +
+ ) + } +
+
+
+ { + deleteCredentialId && ( + + ) + } + + ) +} + +export default memo(Authorized) diff --git a/web/app/components/header/account-setting/model-provider-page/model-auth/config-model.tsx b/web/app/components/header/account-setting/model-provider-page/model-auth/config-model.tsx new file mode 100644 index 0000000000..02d9eb2742 --- /dev/null +++ b/web/app/components/header/account-setting/model-provider-page/model-auth/config-model.tsx @@ -0,0 +1,76 @@ +import { memo } from 'react' +import { + RiEqualizer2Line, + RiScales3Line, +} from '@remixicon/react' +import { useTranslation } from 'react-i18next' +import Button from '@/app/components/base/button' +import Indicator from '@/app/components/header/indicator' +import cn from '@/utils/classnames' + +type ConfigModelProps = { + onClick?: () => void + loadBalancingEnabled?: boolean + loadBalancingInvalid?: boolean + credentialRemoved?: boolean +} +const ConfigModel = ({ + onClick, + loadBalancingEnabled, + loadBalancingInvalid, + credentialRemoved, +}: ConfigModelProps) => { + const { t } = useTranslation() + + if (loadBalancingInvalid) { + return ( +
+ + {t('common.modelProvider.auth.authorizationError')} + +
+ ) + } + + return ( + + ) +} + +export default memo(ConfigModel) diff --git a/web/app/components/header/account-setting/model-provider-page/model-auth/config-provider.tsx b/web/app/components/header/account-setting/model-provider-page/model-auth/config-provider.tsx new file mode 100644 index 0000000000..ba9049a83e --- /dev/null +++ b/web/app/components/header/account-setting/model-provider-page/model-auth/config-provider.tsx @@ -0,0 +1,96 @@ +import { + memo, + useCallback, + useMemo, +} from 'react' +import { useTranslation } from 'react-i18next' +import { + RiEqualizer2Line, +} from '@remixicon/react' +import { + Button, +} from '@/app/components/base/button' +import type { + CustomConfigurationModelFixedFields, + ModelProvider, +} from '@/app/components/header/account-setting/model-provider-page/declarations' +import { ConfigurationMethodEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' +import Authorized from './authorized' +import { useAuth, useCredentialStatus } from './hooks' +import Tooltip from '@/app/components/base/tooltip' +import cn from '@/utils/classnames' + +type ConfigProviderProps = { + provider: ModelProvider, + configurationMethod: ConfigurationMethodEnum, + currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields, +} +const ConfigProvider = ({ + provider, + configurationMethod, + currentCustomConfigurationModelFixedFields, +}: ConfigProviderProps) => { + const { t } = useTranslation() + const { + handleOpenModal, + } = useAuth(provider, configurationMethod, currentCustomConfigurationModelFixedFields) + const { + hasCredential, + authorized, + current_credential_id, + current_credential_name, + available_credentials, + } = useCredentialStatus(provider) + const notAllowCustomCredential = provider.allow_custom_token === false + const handleClick = useCallback(() => { + if (!hasCredential && !notAllowCustomCredential) + handleOpenModal() + }, [handleOpenModal, hasCredential, notAllowCustomCredential]) + const ButtonComponent = useMemo(() => { + const Item = ( + + ) + if (notAllowCustomCredential) { + return ( + + {Item} + + ) + } + return Item + }, [handleClick, authorized, notAllowCustomCredential, t]) + + if (!hasCredential) + return ButtonComponent + + return ( + + ) +} + +export default memo(ConfigProvider) diff --git a/web/app/components/header/account-setting/model-provider-page/model-auth/hooks/index.ts b/web/app/components/header/account-setting/model-provider-page/model-auth/hooks/index.ts new file mode 100644 index 0000000000..fd0bee512f --- /dev/null +++ b/web/app/components/header/account-setting/model-provider-page/model-auth/hooks/index.ts @@ -0,0 +1,6 @@ +export * from './use-model-form-schemas' +export * from './use-credential-status' +export * from './use-custom-models' +export * from './use-auth' +export * from './use-auth-service' +export * from './use-credential-data' diff --git a/web/app/components/header/account-setting/model-provider-page/model-auth/hooks/use-auth-service.ts b/web/app/components/header/account-setting/model-provider-page/model-auth/hooks/use-auth-service.ts new file mode 100644 index 0000000000..317a1fe1a9 --- /dev/null +++ b/web/app/components/header/account-setting/model-provider-page/model-auth/hooks/use-auth-service.ts @@ -0,0 +1,57 @@ +import { useCallback } from 'react' +import { + useActiveModelCredential, + useActiveProviderCredential, + useAddModelCredential, + useAddProviderCredential, + useDeleteModelCredential, + useDeleteProviderCredential, + useEditModelCredential, + useEditProviderCredential, + useGetModelCredential, + useGetProviderCredential, +} from '@/service/use-models' +import type { + CustomModel, +} from '@/app/components/header/account-setting/model-provider-page/declarations' + +export const useGetCredential = (provider: string, isModelCredential?: boolean, credentialId?: string, model?: CustomModel, configFrom?: string) => { + const providerData = useGetProviderCredential(!isModelCredential && !!credentialId, provider, credentialId) + const modelData = useGetModelCredential(!!isModelCredential && !!credentialId, provider, credentialId, model?.model, model?.model_type, configFrom) + return isModelCredential ? modelData : providerData +} + +export const useAuthService = (provider: string) => { + const { mutateAsync: addProviderCredential } = useAddProviderCredential(provider) + const { mutateAsync: editProviderCredential } = useEditProviderCredential(provider) + const { mutateAsync: deleteProviderCredential } = useDeleteProviderCredential(provider) + const { mutateAsync: activeProviderCredential } = useActiveProviderCredential(provider) + + const { mutateAsync: addModelCredential } = useAddModelCredential(provider) + const { mutateAsync: activeModelCredential } = useActiveModelCredential(provider) + const { mutateAsync: deleteModelCredential } = useDeleteModelCredential(provider) + const { mutateAsync: editModelCredential } = useEditModelCredential(provider) + + const getAddCredentialService = useCallback((isModel: boolean) => { + return isModel ? addModelCredential : addProviderCredential + }, [addModelCredential, addProviderCredential]) + + const getEditCredentialService = useCallback((isModel: boolean) => { + return isModel ? editModelCredential : editProviderCredential + }, [editModelCredential, editProviderCredential]) + + const getDeleteCredentialService = useCallback((isModel: boolean) => { + return isModel ? deleteModelCredential : deleteProviderCredential + }, [deleteModelCredential, deleteProviderCredential]) + + const getActiveCredentialService = useCallback((isModel: boolean) => { + return isModel ? activeModelCredential : activeProviderCredential + }, [activeModelCredential, activeProviderCredential]) + + return { + getAddCredentialService, + getEditCredentialService, + getDeleteCredentialService, + getActiveCredentialService, + } +} diff --git a/web/app/components/header/account-setting/model-provider-page/model-auth/hooks/use-auth.ts b/web/app/components/header/account-setting/model-provider-page/model-auth/hooks/use-auth.ts new file mode 100644 index 0000000000..d4a0417a44 --- /dev/null +++ b/web/app/components/header/account-setting/model-provider-page/model-auth/hooks/use-auth.ts @@ -0,0 +1,158 @@ +import { + useCallback, + useRef, + useState, +} from 'react' +import { useTranslation } from 'react-i18next' +import { useToastContext } from '@/app/components/base/toast' +import { useAuthService } from './use-auth-service' +import type { + ConfigurationMethodEnum, + Credential, + CustomConfigurationModelFixedFields, + CustomModel, + ModelProvider, +} from '../../declarations' +import { + useModelModalHandler, + useRefreshModel, +} from '@/app/components/header/account-setting/model-provider-page/hooks' + +export const useAuth = ( + provider: ModelProvider, + configurationMethod: ConfigurationMethodEnum, + currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields, + isModelCredential?: boolean, + onUpdate?: () => void, +) => { + const { t } = useTranslation() + const { notify } = useToastContext() + const { + getDeleteCredentialService, + getActiveCredentialService, + getEditCredentialService, + getAddCredentialService, + } = useAuthService(provider.provider) + const handleOpenModelModal = useModelModalHandler() + const { handleRefreshModel } = useRefreshModel() + const pendingOperationCredentialId = useRef(null) + const pendingOperationModel = useRef(null) + const [deleteCredentialId, setDeleteCredentialId] = useState(null) + const openConfirmDelete = useCallback((credential?: Credential, model?: CustomModel) => { + if (credential) + pendingOperationCredentialId.current = credential.credential_id + if (model) + pendingOperationModel.current = model + + setDeleteCredentialId(pendingOperationCredentialId.current) + }, []) + const closeConfirmDelete = useCallback(() => { + setDeleteCredentialId(null) + pendingOperationCredentialId.current = null + }, []) + const [doingAction, setDoingAction] = useState(false) + const doingActionRef = useRef(doingAction) + const handleSetDoingAction = useCallback((doing: boolean) => { + doingActionRef.current = doing + setDoingAction(doing) + }, []) + const handleActiveCredential = useCallback(async (credential: Credential, model?: CustomModel) => { + if (doingActionRef.current) + return + try { + handleSetDoingAction(true) + await getActiveCredentialService(!!model)({ + credential_id: credential.credential_id, + model: model?.model, + model_type: model?.model_type, + }) + notify({ + type: 'success', + message: t('common.api.actionSuccess'), + }) + onUpdate?.() + handleRefreshModel(provider, configurationMethod, undefined) + } + finally { + handleSetDoingAction(false) + } + }, [getActiveCredentialService, onUpdate, notify, t, handleSetDoingAction]) + const handleConfirmDelete = useCallback(async () => { + if (doingActionRef.current) + return + if (!pendingOperationCredentialId.current) { + setDeleteCredentialId(null) + return + } + try { + handleSetDoingAction(true) + await getDeleteCredentialService(!!isModelCredential)({ + credential_id: pendingOperationCredentialId.current, + model: pendingOperationModel.current?.model, + model_type: pendingOperationModel.current?.model_type, + }) + notify({ + type: 'success', + message: t('common.api.actionSuccess'), + }) + onUpdate?.() + handleRefreshModel(provider, configurationMethod, undefined) + setDeleteCredentialId(null) + pendingOperationCredentialId.current = null + pendingOperationModel.current = null + } + finally { + handleSetDoingAction(false) + } + }, [onUpdate, notify, t, handleSetDoingAction, getDeleteCredentialService, isModelCredential]) + const handleAddCredential = useCallback((model?: CustomModel) => { + if (model) + pendingOperationModel.current = model + }, []) + const handleSaveCredential = useCallback(async (payload: Record) => { + if (doingActionRef.current) + return + try { + handleSetDoingAction(true) + + let res: { result?: string } = {} + if (payload.credential_id) + res = await getEditCredentialService(!!isModelCredential)(payload as any) + else + res = await getAddCredentialService(!!isModelCredential)(payload as any) + + if (res.result === 'success') { + notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') }) + onUpdate?.() + } + } + finally { + handleSetDoingAction(false) + } + }, [onUpdate, notify, t, handleSetDoingAction, getEditCredentialService, getAddCredentialService]) + const handleOpenModal = useCallback((credential?: Credential, model?: CustomModel) => { + handleOpenModelModal( + provider, + configurationMethod, + currentCustomConfigurationModelFixedFields, + isModelCredential, + credential, + model, + onUpdate, + ) + }, [handleOpenModelModal, provider, configurationMethod, currentCustomConfigurationModelFixedFields, isModelCredential, onUpdate]) + + return { + pendingOperationCredentialId, + pendingOperationModel, + openConfirmDelete, + closeConfirmDelete, + doingAction, + handleActiveCredential, + handleConfirmDelete, + handleAddCredential, + deleteCredentialId, + handleSaveCredential, + handleOpenModal, + } +} diff --git a/web/app/components/header/account-setting/model-provider-page/model-auth/hooks/use-credential-data.ts b/web/app/components/header/account-setting/model-provider-page/model-auth/hooks/use-credential-data.ts new file mode 100644 index 0000000000..2fbc8b1033 --- /dev/null +++ b/web/app/components/header/account-setting/model-provider-page/model-auth/hooks/use-credential-data.ts @@ -0,0 +1,24 @@ +import { useMemo } from 'react' +import { useGetCredential } from './use-auth-service' +import type { + Credential, + CustomModelCredential, + ModelProvider, +} from '@/app/components/header/account-setting/model-provider-page/declarations' + +export const useCredentialData = (provider: ModelProvider, providerFormSchemaPredefined: boolean, isModelCredential?: boolean, credential?: Credential, model?: CustomModelCredential) => { + const configFrom = useMemo(() => { + if (providerFormSchemaPredefined) + return 'predefined-model' + return 'custom-model' + }, [providerFormSchemaPredefined]) + const { + isLoading, + data: credentialData = {}, + } = useGetCredential(provider.provider, isModelCredential, credential?.credential_id, model, configFrom) + + return { + isLoading, + credentialData, + } +} diff --git a/web/app/components/header/account-setting/model-provider-page/model-auth/hooks/use-credential-status.ts b/web/app/components/header/account-setting/model-provider-page/model-auth/hooks/use-credential-status.ts new file mode 100644 index 0000000000..3fa3877b3f --- /dev/null +++ b/web/app/components/header/account-setting/model-provider-page/model-auth/hooks/use-credential-status.ts @@ -0,0 +1,26 @@ +import { useMemo } from 'react' +import type { + ModelProvider, +} from '../../declarations' + +export const useCredentialStatus = (provider: ModelProvider) => { + const { + current_credential_id, + current_credential_name, + available_credentials, + } = provider.custom_configuration + const hasCredential = !!available_credentials?.length + const authorized = current_credential_id && current_credential_name + const authRemoved = hasCredential && !current_credential_id && !current_credential_name + const currentCredential = available_credentials?.find(credential => credential.credential_id === current_credential_id) + + return useMemo(() => ({ + hasCredential, + authorized, + authRemoved, + current_credential_id, + current_credential_name, + available_credentials, + notAllowedToUse: currentCredential?.not_allowed_to_use, + }), [hasCredential, authorized, authRemoved, current_credential_id, current_credential_name, available_credentials]) +} diff --git a/web/app/components/header/account-setting/model-provider-page/model-auth/hooks/use-custom-models.ts b/web/app/components/header/account-setting/model-provider-page/model-auth/hooks/use-custom-models.ts new file mode 100644 index 0000000000..f3b50f3f49 --- /dev/null +++ b/web/app/components/header/account-setting/model-provider-page/model-auth/hooks/use-custom-models.ts @@ -0,0 +1,9 @@ +import type { + ModelProvider, +} from '../../declarations' + +export const useCustomModels = (provider: ModelProvider) => { + const { custom_models } = provider.custom_configuration + + return custom_models || [] +} diff --git a/web/app/components/header/account-setting/model-provider-page/model-auth/hooks/use-model-form-schemas.ts b/web/app/components/header/account-setting/model-provider-page/model-auth/hooks/use-model-form-schemas.ts new file mode 100644 index 0000000000..eafbedfddf --- /dev/null +++ b/web/app/components/header/account-setting/model-provider-page/model-auth/hooks/use-model-form-schemas.ts @@ -0,0 +1,83 @@ +import { useMemo } from 'react' +import { useTranslation } from 'react-i18next' +import type { + Credential, + CustomModelCredential, + ModelLoadBalancingConfig, + ModelProvider, +} from '../../declarations' +import { + genModelNameFormSchema, + genModelTypeFormSchema, +} from '../../utils' +import { FormTypeEnum } from '@/app/components/base/form/types' + +export const useModelFormSchemas = ( + provider: ModelProvider, + providerFormSchemaPredefined: boolean, + credentials?: Record, + credential?: Credential, + model?: CustomModelCredential, + draftConfig?: ModelLoadBalancingConfig, +) => { + const { t } = useTranslation() + const { + provider_credential_schema, + supported_model_types, + model_credential_schema, + } = provider + const formSchemas = useMemo(() => { + const modelTypeSchema = genModelTypeFormSchema(supported_model_types) + const modelNameSchema = genModelNameFormSchema(model_credential_schema?.model) + if (!!model) { + modelTypeSchema.disabled = true + modelNameSchema.disabled = true + } + return providerFormSchemaPredefined + ? provider_credential_schema.credential_form_schemas + : [ + modelTypeSchema, + modelNameSchema, + ...(draftConfig?.enabled ? [] : model_credential_schema.credential_form_schemas), + ] + }, [ + providerFormSchemaPredefined, + provider_credential_schema?.credential_form_schemas, + supported_model_types, + model_credential_schema?.credential_form_schemas, + model_credential_schema?.model, + draftConfig?.enabled, + model, + ]) + + const formSchemasWithAuthorizationName = useMemo(() => { + const authorizationNameSchema = { + type: FormTypeEnum.textInput, + variable: '__authorization_name__', + label: t('plugin.auth.authorizationName'), + required: true, + } + + return [ + authorizationNameSchema, + ...formSchemas, + ] + }, [formSchemas, t]) + + const formValues = useMemo(() => { + let result = {} + if (credential) { + result = { ...result, __authorization_name__: credential?.credential_name } + if (credentials) + result = { ...result, ...credentials } + } + if (model) + result = { ...result, __model_name: model?.model, __model_type: model?.model_type } + return result + }, [credentials, credential, model]) + + return { + formSchemas: formSchemasWithAuthorizationName, + formValues, + } +} diff --git a/web/app/components/header/account-setting/model-provider-page/model-auth/index.tsx b/web/app/components/header/account-setting/model-provider-page/model-auth/index.tsx new file mode 100644 index 0000000000..05effcea7c --- /dev/null +++ b/web/app/components/header/account-setting/model-provider-page/model-auth/index.tsx @@ -0,0 +1,6 @@ +export { default as Authorized } from './authorized' +export { default as SwitchCredentialInLoadBalancing } from './switch-credential-in-load-balancing' +export { default as AddCredentialInLoadBalancing } from './add-credential-in-load-balancing' +export { default as AddCustomModel } from './add-custom-model' +export { default as ConfigProvider } from './config-provider' +export { default as ConfigModel } from './config-model' diff --git a/web/app/components/header/account-setting/model-provider-page/model-auth/switch-credential-in-load-balancing.tsx b/web/app/components/header/account-setting/model-provider-page/model-auth/switch-credential-in-load-balancing.tsx new file mode 100644 index 0000000000..8f81107bb2 --- /dev/null +++ b/web/app/components/header/account-setting/model-provider-page/model-auth/switch-credential-in-load-balancing.tsx @@ -0,0 +1,122 @@ +import type { Dispatch, SetStateAction } from 'react' +import { + memo, + useCallback, +} from 'react' +import { useTranslation } from 'react-i18next' +import { RiArrowDownSLine } from '@remixicon/react' +import Button from '@/app/components/base/button' +import Indicator from '@/app/components/header/indicator' +import Authorized from './authorized' +import type { + Credential, + CustomModel, + ModelProvider, +} from '../declarations' +import { ConfigurationMethodEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' +import cn from '@/utils/classnames' +import Tooltip from '@/app/components/base/tooltip' +import Badge from '@/app/components/base/badge' + +type SwitchCredentialInLoadBalancingProps = { + provider: ModelProvider + model: CustomModel + credentials?: Credential[] + customModelCredential?: Credential + setCustomModelCredential: Dispatch> +} +const SwitchCredentialInLoadBalancing = ({ + provider, + model, + customModelCredential, + setCustomModelCredential, + credentials, +}: SwitchCredentialInLoadBalancingProps) => { + const { t } = useTranslation() + + const handleItemClick = useCallback((credential: Credential) => { + setCustomModelCredential(credential) + }, [setCustomModelCredential]) + + const renderTrigger = useCallback(() => { + const selectedCredentialId = customModelCredential?.credential_id + const authRemoved = !selectedCredentialId && !!credentials?.length + let color = 'green' + if (authRemoved && !customModelCredential?.not_allowed_to_use) + color = 'red' + if (customModelCredential?.not_allowed_to_use) + color = 'gray' + + const Item = ( + + ) + if (customModelCredential?.not_allowed_to_use) { + return ( + + {Item} + + ) + } + return Item + }, [customModelCredential, t, credentials]) + + return ( + + ) +} + +export default memo(SwitchCredentialInLoadBalancing) diff --git a/web/app/components/header/account-setting/model-provider-page/model-icon/index.tsx b/web/app/components/header/account-setting/model-provider-page/model-icon/index.tsx index f6fb1dc6f6..02c7c404ab 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-icon/index.tsx +++ b/web/app/components/header/account-setting/model-provider-page/model-icon/index.tsx @@ -13,12 +13,14 @@ type ModelIconProps = { provider?: Model | ModelProvider modelName?: string className?: string + iconClassName?: string isDeprecated?: boolean } const ModelIcon: FC = ({ provider, className, modelName, + iconClassName, isDeprecated = false, }) => { const language = useLanguage() @@ -34,7 +36,7 @@ const ModelIcon: FC = ({ if (provider?.icon_small) { return (
- model-icon + model-icon
) } @@ -44,7 +46,7 @@ const ModelIcon: FC = ({ 'flex h-5 w-5 items-center justify-center rounded-md border-[0.5px] border-components-panel-border-subtle bg-background-default-subtle', className, )}> -
+
diff --git a/web/app/components/header/account-setting/model-provider-page/model-modal/index.tsx b/web/app/components/header/account-setting/model-provider-page/model-modal/index.tsx index bc98081dfa..e9050e4837 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-modal/index.tsx +++ b/web/app/components/header/account-setting/model-provider-page/model-modal/index.tsx @@ -2,43 +2,22 @@ import type { FC } from 'react' import { memo, useCallback, - useEffect, useMemo, - useState, + useRef, } from 'react' +import { RiCloseLine } from '@remixicon/react' import { useTranslation } from 'react-i18next' -import { - RiErrorWarningFill, -} from '@remixicon/react' import type { - CredentialFormSchema, - CredentialFormSchemaRadio, - CredentialFormSchemaSelect, CustomConfigurationModelFixedFields, - FormValue, - ModelLoadBalancingConfig, - ModelLoadBalancingConfigEntry, ModelProvider, } from '../declarations' import { ConfigurationMethodEnum, - CustomConfigurationStatusEnum, FormTypeEnum, } from '../declarations' -import { - genModelNameFormSchema, - genModelTypeFormSchema, - removeCredentials, - saveCredentials, -} from '../utils' import { useLanguage, - useProviderCredentialsAndLoadBalancing, } from '../hooks' -import { useValidate } from '../../key-validator/hooks' -import { ValidatedStatus } from '../../key-validator/declarations' -import ModelLoadBalancingConfigs from '../provider-added-card/model-load-balancing-configs' -import Form from './Form' import Button from '@/app/components/base/button' import { Lock01 } from '@/app/components/base/icons/src/vender/solid/security' import { LinkExternal02 } from '@/app/components/base/icons/src/vender/line/general' @@ -46,9 +25,26 @@ import { PortalToFollowElem, PortalToFollowElemContent, } from '@/app/components/base/portal-to-follow-elem' -import { useToastContext } from '@/app/components/base/toast' import Confirm from '@/app/components/base/confirm' import { useAppContext } from '@/context/app-context' +import AuthForm from '@/app/components/base/form/form-scenarios/auth' +import type { + FormRefObject, + FormSchema, +} from '@/app/components/base/form/types' +import { useModelFormSchemas } from '../model-auth/hooks' +import type { + Credential, + CustomModel, +} from '../declarations' +import Loading from '@/app/components/base/loading' +import { + useAuth, + useCredentialData, +} from '@/app/components/header/account-setting/model-provider-page/model-auth/hooks' +import ModelIcon from '@/app/components/header/account-setting/model-provider-page/model-icon' +import Badge from '@/app/components/base/badge' +import { useRenderI18nObject } from '@/hooks/use-i18n' type ModelModalProps = { provider: ModelProvider @@ -56,6 +52,9 @@ type ModelModalProps = { currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields onCancel: () => void onSave: () => void + model?: CustomModel + credential?: Credential + isModelCredential?: boolean } const ModelModal: FC = ({ @@ -64,244 +63,173 @@ const ModelModal: FC = ({ currentCustomConfigurationModelFixedFields, onCancel, onSave, + model, + credential, + isModelCredential, }) => { + const renderI18nObject = useRenderI18nObject() const providerFormSchemaPredefined = configurateMethod === ConfigurationMethodEnum.predefinedModel + const { + isLoading, + credentialData, + } = useCredentialData(provider, providerFormSchemaPredefined, isModelCredential, credential, model) + const { + handleSaveCredential, + handleConfirmDelete, + deleteCredentialId, + closeConfirmDelete, + openConfirmDelete, + doingAction, + } = useAuth(provider, configurateMethod, currentCustomConfigurationModelFixedFields, isModelCredential, onSave) const { credentials: formSchemasValue, - loadBalancing: originalConfig, - mutate, - } = useProviderCredentialsAndLoadBalancing( - provider.provider, - configurateMethod, - providerFormSchemaPredefined && provider.custom_configuration.status === CustomConfigurationStatusEnum.active, - currentCustomConfigurationModelFixedFields, - ) + } = credentialData as any + const { isCurrentWorkspaceManager } = useAppContext() const isEditMode = !!formSchemasValue && isCurrentWorkspaceManager const { t } = useTranslation() - const { notify } = useToastContext() const language = useLanguage() - const [loading, setLoading] = useState(false) - const [showConfirm, setShowConfirm] = useState(false) + const { + formSchemas, + formValues, + } = useModelFormSchemas(provider, providerFormSchemaPredefined, formSchemasValue, credential, model) + const formRef = useRef(null) - const [draftConfig, setDraftConfig] = useState() - const originalConfigMap = useMemo(() => { - if (!originalConfig) - return {} - return originalConfig?.configs.reduce((prev, config) => { - if (config.id) - prev[config.id] = config - return prev - }, {} as Record) - }, [originalConfig]) - useEffect(() => { - if (originalConfig && !draftConfig) - setDraftConfig(originalConfig) - }, [draftConfig, originalConfig]) + const handleSave = useCallback(async () => { + const { + isCheckValidated, + values, + } = formRef.current?.getFormValues({ + needCheckValidatedValues: true, + needTransformWhenSecretFieldIsPristine: true, + }) || { isCheckValidated: false, values: {} } + if (!isCheckValidated) + return - const formSchemas = useMemo(() => { - return providerFormSchemaPredefined - ? provider.provider_credential_schema.credential_form_schemas - : [ - genModelTypeFormSchema(provider.supported_model_types), - genModelNameFormSchema(provider.model_credential_schema?.model), - ...(draftConfig?.enabled ? [] : provider.model_credential_schema.credential_form_schemas), - ] - }, [ - providerFormSchemaPredefined, - provider.provider_credential_schema?.credential_form_schemas, - provider.supported_model_types, - provider.model_credential_schema?.credential_form_schemas, - provider.model_credential_schema?.model, - draftConfig?.enabled, - ]) - const [ - requiredFormSchemas, - defaultFormSchemaValue, - showOnVariableMap, - ] = useMemo(() => { - const requiredFormSchemas: CredentialFormSchema[] = [] - const defaultFormSchemaValue: Record = {} - const showOnVariableMap: Record = {} + const { + __authorization_name__, + __model_name, + __model_type, + ...rest + } = values + if (__model_name && __model_type) { + handleSaveCredential({ + credential_id: credential?.credential_id, + credentials: rest, + name: __authorization_name__, + model: __model_name, + model_type: __model_type, + }) + } + else { + handleSaveCredential({ + credential_id: credential?.credential_id, + credentials: rest, + name: __authorization_name__, + }) + } + }, [handleSaveCredential, credential?.credential_id, model]) - formSchemas.forEach((formSchema) => { - if (formSchema.required) - requiredFormSchemas.push(formSchema) - - if (formSchema.default) - defaultFormSchemaValue[formSchema.variable] = formSchema.default - - if (formSchema.show_on.length) { - formSchema.show_on.forEach((showOnItem) => { - if (!showOnVariableMap[showOnItem.variable]) - showOnVariableMap[showOnItem.variable] = [] - - if (!showOnVariableMap[showOnItem.variable].includes(formSchema.variable)) - showOnVariableMap[showOnItem.variable].push(formSchema.variable) - }) - } - - if (formSchema.type === FormTypeEnum.select || formSchema.type === FormTypeEnum.radio) { - (formSchema as (CredentialFormSchemaRadio | CredentialFormSchemaSelect)).options.forEach((option) => { - if (option.show_on.length) { - option.show_on.forEach((showOnItem) => { - if (!showOnVariableMap[showOnItem.variable]) - showOnVariableMap[showOnItem.variable] = [] - - if (!showOnVariableMap[showOnItem.variable].includes(formSchema.variable)) - showOnVariableMap[showOnItem.variable].push(formSchema.variable) - }) - } - }) - } - }) - - return [ - requiredFormSchemas, - defaultFormSchemaValue, - showOnVariableMap, - ] - }, [formSchemas]) - const initialFormSchemasValue: Record = useMemo(() => { - return { - ...defaultFormSchemaValue, - ...formSchemasValue, - } as unknown as Record - }, [formSchemasValue, defaultFormSchemaValue]) - const [value, setValue] = useState(initialFormSchemasValue) - useEffect(() => { - setValue(initialFormSchemasValue) - }, [initialFormSchemasValue]) - const [_, validating, validatedStatusState] = useValidate(value) - const filteredRequiredFormSchemas = requiredFormSchemas.filter((requiredFormSchema) => { - if (requiredFormSchema.show_on.length && requiredFormSchema.show_on.every(showOnItem => value[showOnItem.variable] === showOnItem.value)) - return true - - if (!requiredFormSchema.show_on.length) - return true - - return false - }) - - const handleValueChange = (v: FormValue) => { - setValue(v) - } - - const extendedSecretFormSchemas = useMemo( - () => - (providerFormSchemaPredefined - ? provider.provider_credential_schema.credential_form_schemas - : [ - genModelTypeFormSchema(provider.supported_model_types), - genModelNameFormSchema(provider.model_credential_schema?.model), - ...provider.model_credential_schema.credential_form_schemas, - ]).filter(({ type }) => type === FormTypeEnum.secretInput), - [ - provider.model_credential_schema?.credential_form_schemas, - provider.model_credential_schema?.model, - provider.provider_credential_schema?.credential_form_schemas, - provider.supported_model_types, - providerFormSchemaPredefined, - ], - ) - - const encodeSecretValues = useCallback((v: FormValue) => { - const result = { ...v } - extendedSecretFormSchemas.forEach(({ variable }) => { - if (result[variable] === formSchemasValue?.[variable] && result[variable] !== undefined) - result[variable] = '[__HIDDEN__]' - }) - return result - }, [extendedSecretFormSchemas, formSchemasValue]) - - const encodeConfigEntrySecretValues = useCallback((entry: ModelLoadBalancingConfigEntry) => { - const result = { ...entry } - extendedSecretFormSchemas.forEach(({ variable }) => { - if (entry.id && result.credentials[variable] === originalConfigMap[entry.id]?.credentials?.[variable]) - result.credentials[variable] = '[__HIDDEN__]' - }) - return result - }, [extendedSecretFormSchemas, originalConfigMap]) - - const handleSave = async () => { - try { - setLoading(true) - const res = await saveCredentials( - providerFormSchemaPredefined, - provider.provider, - encodeSecretValues(value), - { - ...draftConfig, - enabled: Boolean(draftConfig?.enabled), - configs: draftConfig?.configs.map(encodeConfigEntrySecretValues) || [], - }, + const modalTitle = useMemo(() => { + if (!providerFormSchemaPredefined && !model) { + return ( +
+ +
+
{t('common.modelProvider.auth.apiKeyModal.addModel')}
+
{renderI18nObject(provider.label)}
+
+
) - if (res.result === 'success') { - notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') }) - mutate() - onSave() - onCancel() - } } - finally { - setLoading(false) - } - } + let label = t('common.modelProvider.auth.apiKeyModal.title') - const handleRemove = async () => { - try { - setLoading(true) + if (model) + label = t('common.modelProvider.auth.addModelCredential') - const res = await removeCredentials( - providerFormSchemaPredefined, - provider.provider, - value, + return ( +
+ {label} +
+ ) + }, [providerFormSchemaPredefined, t, model, renderI18nObject]) + + const modalDesc = useMemo(() => { + if (providerFormSchemaPredefined) { + return ( +
+ {t('common.modelProvider.auth.apiKeyModal.desc')} +
) - if (res.result === 'success') { - notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') }) - mutate() - onSave() - onCancel() - } } - finally { - setLoading(false) - } - } - const renderTitlePrefix = () => { - const prefix = isEditMode ? t('common.operation.setup') : t('common.operation.add') - return `${prefix} ${provider.label[language] || provider.label.en_US}` - } + return null + }, [providerFormSchemaPredefined, t]) + + const modalModel = useMemo(() => { + if (model) { + return ( +
+ +
{model.model}
+ {model.model_type} +
+ ) + } + + return null + }, [model, provider]) return (
-
-
-
-
{renderTitlePrefix()}
+
+
+ +
+
+
+ {modalTitle} + {modalDesc} + {modalModel}
-
-
- + { + isLoading && ( +
+ +
+ ) + } + { + !isLoading && ( + { + return { + ...formSchema, + name: formSchema.variable, + showRadioUI: formSchema.type === FormTypeEnum.radio, + } + }) as FormSchema[]} + defaultValues={formValues} + inputClassName='justify-start' + ref={formRef} + /> + ) + }
@@ -327,7 +255,7 @@ const ModelModal: FC = ({ variant='warning' size='large' className='mr-2' - onClick={() => setShowConfirm(true)} + onClick={() => openConfirmDelete(credential, model)} > {t('common.operation.remove')} @@ -344,12 +272,7 @@ const ModelModal: FC = ({ size='large' variant='primary' onClick={handleSave} - disabled={ - loading - || filteredRequiredFormSchemas.some(item => value[item.variable] === undefined) - || (draftConfig?.enabled && (draftConfig?.configs.filter(config => config.enabled).length ?? 0) < 2) - } - + disabled={isLoading || doingAction} > {t('common.operation.save')} @@ -357,38 +280,28 @@ const ModelModal: FC = ({
- { - (validatedStatusState.status === ValidatedStatus.Error && validatedStatusState.message) - ? ( -
- - {validatedStatusState.message} -
- ) - : ( -
- - {t('common.modelProvider.encrypted.front')} - - PKCS1_OAEP - - {t('common.modelProvider.encrypted.back')} -
- ) - } +
+ + {t('common.modelProvider.encrypted.front')} + + PKCS1_OAEP + + {t('common.modelProvider.encrypted.back')} +
{ - showConfirm && ( + deleteCredentialId && ( setShowConfirm(false)} - onConfirm={handleRemove} + isDisabled={doingAction} + onCancel={closeConfirmDelete} + onConfirm={handleConfirmDelete} /> ) } diff --git a/web/app/components/header/account-setting/model-provider-page/model-modal/model-load-balancing-entry-modal.tsx b/web/app/components/header/account-setting/model-provider-page/model-modal/model-load-balancing-entry-modal.tsx deleted file mode 100644 index d6285a784b..0000000000 --- a/web/app/components/header/account-setting/model-provider-page/model-modal/model-load-balancing-entry-modal.tsx +++ /dev/null @@ -1,348 +0,0 @@ -import type { FC } from 'react' -import { - memo, - useCallback, - useEffect, - useMemo, - useState, -} from 'react' -import { useTranslation } from 'react-i18next' -import { - RiErrorWarningFill, -} from '@remixicon/react' -import type { - CredentialFormSchema, - CredentialFormSchemaRadio, - CredentialFormSchemaSelect, - CredentialFormSchemaTextInput, - CustomConfigurationModelFixedFields, - FormValue, - ModelLoadBalancingConfigEntry, - ModelProvider, -} from '../declarations' -import { - ConfigurationMethodEnum, - FormTypeEnum, -} from '../declarations' - -import { - useLanguage, -} from '../hooks' -import { useValidate } from '../../key-validator/hooks' -import { ValidatedStatus } from '../../key-validator/declarations' -import { validateLoadBalancingCredentials } from '../utils' -import Form from './Form' -import Button from '@/app/components/base/button' -import { Lock01 } from '@/app/components/base/icons/src/vender/solid/security' -import { LinkExternal02 } from '@/app/components/base/icons/src/vender/line/general' -import { - PortalToFollowElem, - PortalToFollowElemContent, -} from '@/app/components/base/portal-to-follow-elem' -import { useToastContext } from '@/app/components/base/toast' -import Confirm from '@/app/components/base/confirm' - -type ModelModalProps = { - provider: ModelProvider - configurationMethod: ConfigurationMethodEnum - currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields - entry?: ModelLoadBalancingConfigEntry - onCancel: () => void - onSave: (entry: ModelLoadBalancingConfigEntry) => void - onRemove: () => void -} - -const ModelLoadBalancingEntryModal: FC = ({ - provider, - configurationMethod, - currentCustomConfigurationModelFixedFields, - entry, - onCancel, - onSave, - onRemove, -}) => { - const providerFormSchemaPredefined = configurationMethod === ConfigurationMethodEnum.predefinedModel - // const { credentials: formSchemasValue } = useProviderCredentialsAndLoadBalancing( - // provider.provider, - // configurationMethod, - // providerFormSchemaPredefined && provider.custom_configuration.status === CustomConfigurationStatusEnum.active, - // currentCustomConfigurationModelFixedFields, - // ) - const isEditMode = !!entry - const { t } = useTranslation() - const { notify } = useToastContext() - const language = useLanguage() - const [loading, setLoading] = useState(false) - const [showConfirm, setShowConfirm] = useState(false) - const formSchemas = useMemo(() => { - return [ - { - type: FormTypeEnum.textInput, - label: { - en_US: 'Config Name', - zh_Hans: '配置名称', - }, - variable: 'name', - required: true, - show_on: [], - placeholder: { - en_US: 'Enter your Config Name here', - zh_Hans: '输入配置名称', - }, - } as CredentialFormSchemaTextInput, - ...( - providerFormSchemaPredefined - ? provider.provider_credential_schema.credential_form_schemas - : provider.model_credential_schema.credential_form_schemas - ), - ] - }, [ - providerFormSchemaPredefined, - provider.provider_credential_schema?.credential_form_schemas, - provider.model_credential_schema?.credential_form_schemas, - ]) - - const [ - requiredFormSchemas, - secretFormSchemas, - defaultFormSchemaValue, - showOnVariableMap, - ] = useMemo(() => { - const requiredFormSchemas: CredentialFormSchema[] = [] - const secretFormSchemas: CredentialFormSchema[] = [] - const defaultFormSchemaValue: Record = {} - const showOnVariableMap: Record = {} - - formSchemas.forEach((formSchema) => { - if (formSchema.required) - requiredFormSchemas.push(formSchema) - - if (formSchema.type === FormTypeEnum.secretInput) - secretFormSchemas.push(formSchema) - - if (formSchema.default) - defaultFormSchemaValue[formSchema.variable] = formSchema.default - - if (formSchema.show_on.length) { - formSchema.show_on.forEach((showOnItem) => { - if (!showOnVariableMap[showOnItem.variable]) - showOnVariableMap[showOnItem.variable] = [] - - if (!showOnVariableMap[showOnItem.variable].includes(formSchema.variable)) - showOnVariableMap[showOnItem.variable].push(formSchema.variable) - }) - } - - if (formSchema.type === FormTypeEnum.select || formSchema.type === FormTypeEnum.radio) { - (formSchema as (CredentialFormSchemaRadio | CredentialFormSchemaSelect)).options.forEach((option) => { - if (option.show_on.length) { - option.show_on.forEach((showOnItem) => { - if (!showOnVariableMap[showOnItem.variable]) - showOnVariableMap[showOnItem.variable] = [] - - if (!showOnVariableMap[showOnItem.variable].includes(formSchema.variable)) - showOnVariableMap[showOnItem.variable].push(formSchema.variable) - }) - } - }) - } - }) - - return [ - requiredFormSchemas, - secretFormSchemas, - defaultFormSchemaValue, - showOnVariableMap, - ] - }, [formSchemas]) - const [initialValue, setInitialValue] = useState() - useEffect(() => { - if (entry && !initialValue) { - setInitialValue({ - ...defaultFormSchemaValue, - ...entry.credentials, - id: entry.id, - name: entry.name, - } as Record) - } - }, [entry, defaultFormSchemaValue, initialValue]) - const formSchemasValue = useMemo(() => ({ - ...currentCustomConfigurationModelFixedFields, - ...initialValue, - }), [currentCustomConfigurationModelFixedFields, initialValue]) - const initialFormSchemasValue: Record = useMemo(() => { - return { - ...defaultFormSchemaValue, - ...formSchemasValue, - } as Record - }, [formSchemasValue, defaultFormSchemaValue]) - const [value, setValue] = useState(initialFormSchemasValue) - useEffect(() => { - setValue(initialFormSchemasValue) - }, [initialFormSchemasValue]) - const [_, validating, validatedStatusState] = useValidate(value) - const filteredRequiredFormSchemas = requiredFormSchemas.filter((requiredFormSchema) => { - if (requiredFormSchema.show_on.length && requiredFormSchema.show_on.every(showOnItem => value[showOnItem.variable] === showOnItem.value)) - return true - - if (!requiredFormSchema.show_on.length) - return true - - return false - }) - const getSecretValues = useCallback((v: FormValue) => { - return secretFormSchemas.reduce((prev, next) => { - if (isEditMode && v[next.variable] && v[next.variable] === initialFormSchemasValue[next.variable]) - prev[next.variable] = '[__HIDDEN__]' - - return prev - }, {} as Record) - }, [initialFormSchemasValue, isEditMode, secretFormSchemas]) - - // const handleValueChange = ({ __model_type, __model_name, ...v }: FormValue) => { - const handleValueChange = (v: FormValue) => { - setValue(v) - } - const handleSave = async () => { - try { - setLoading(true) - - const res = await validateLoadBalancingCredentials( - providerFormSchemaPredefined, - provider.provider, - { - ...value, - ...getSecretValues(value), - }, - entry?.id, - ) - if (res.status === ValidatedStatus.Success) { - // notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') }) - const { __model_type, __model_name, name, ...credentials } = value - onSave({ - ...(entry || {}), - name: name as string, - credentials: credentials as Record, - }) - // onCancel() - } - else { - notify({ type: 'error', message: res.message || '' }) - } - } - finally { - setLoading(false) - } - } - - const handleRemove = () => { - onRemove?.() - } - - return ( - - -
-
-
-
-
{t(isEditMode ? 'common.modelProvider.editConfig' : 'common.modelProvider.addConfig')}
-
- -
- { - (provider.help && (provider.help.title || provider.help.url)) - ? ( - !provider.help.url && e.preventDefault()} - > - {provider.help.title?.[language] || provider.help.url[language] || provider.help.title?.en_US || provider.help.url.en_US} - - - ) - :
- } -
- { - isEditMode && ( - - ) - } - - -
-
-
-
- { - (validatedStatusState.status === ValidatedStatus.Error && validatedStatusState.message) - ? ( -
- - {validatedStatusState.message} -
- ) - : ( -
- - {t('common.modelProvider.encrypted.front')} - - PKCS1_OAEP - - {t('common.modelProvider.encrypted.back')} -
- ) - } -
-
- { - showConfirm && ( - setShowConfirm(false)} - onConfirm={handleRemove} - /> - ) - } -
- - - ) -} - -export default memo(ModelLoadBalancingEntryModal) diff --git a/web/app/components/header/account-setting/model-provider-page/provider-added-card/credential-panel.tsx b/web/app/components/header/account-setting/model-provider-page/provider-added-card/credential-panel.tsx index 822df5f726..d57288db3f 100644 --- a/web/app/components/header/account-setting/model-provider-page/provider-added-card/credential-panel.tsx +++ b/web/app/components/header/account-setting/model-provider-page/provider-added-card/credential-panel.tsx @@ -1,7 +1,8 @@ -import type { FC } from 'react' +import { useMemo } from 'react' import { useTranslation } from 'react-i18next' -import { RiEqualizer2Line } from '@remixicon/react' -import type { ModelProvider } from '../declarations' +import type { + ModelProvider, +} from '../declarations' import { ConfigurationMethodEnum, CustomConfigurationStatusEnum, @@ -15,19 +16,19 @@ import PrioritySelector from './priority-selector' import PriorityUseTip from './priority-use-tip' import { UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST } from './index' import Indicator from '@/app/components/header/indicator' -import Button from '@/app/components/base/button' import { changeModelProviderPriority } from '@/service/common' import { useToastContext } from '@/app/components/base/toast' import { useEventEmitterContextContext } from '@/context/event-emitter' +import cn from '@/utils/classnames' +import { useCredentialStatus } from '@/app/components/header/account-setting/model-provider-page/model-auth/hooks' +import { ConfigProvider } from '@/app/components/header/account-setting/model-provider-page/model-auth' type CredentialPanelProps = { provider: ModelProvider - onSetup: () => void } -const CredentialPanel: FC = ({ +const CredentialPanel = ({ provider, - onSetup, -}) => { +}: CredentialPanelProps) => { const { t } = useTranslation() const { notify } = useToastContext() const { eventEmitter } = useEventEmitterContextContext() @@ -38,6 +39,13 @@ const CredentialPanel: FC = ({ const priorityUseType = provider.preferred_provider_type const isCustomConfigured = customConfig.status === CustomConfigurationStatusEnum.active const configurateMethods = provider.configurate_methods + const { + hasCredential, + authorized, + authRemoved, + current_credential_name, + notAllowedToUse, + } = useCredentialStatus(provider) const handleChangePriority = async (key: PreferredProviderTypeEnum) => { const res = await changeModelProviderPriority({ @@ -61,25 +69,50 @@ const CredentialPanel: FC = ({ } as any) } } + const credentialLabel = useMemo(() => { + if (!hasCredential) + return t('common.modelProvider.auth.unAuthorized') + if (authorized) + return current_credential_name + if (authRemoved) + return t('common.modelProvider.auth.authRemoved') + + return '' + }, [authorized, authRemoved, current_credential_name, hasCredential]) + + const color = useMemo(() => { + if (authRemoved) + return 'red' + if (notAllowedToUse) + return 'gray' + return 'green' + }, [authRemoved, notAllowedToUse]) return ( <> { provider.provider_credential_schema && ( -
-
- API-KEY - +
+
+
+ {credentialLabel} +
+
- + { systemConfig.enabled && isCustomConfigured && ( void } const ProviderAddedCard: FC = ({ notConfigured, provider, - onOpenModal, }) => { const { t } = useTranslation() const { eventEmitter } = useEventEmitterContextContext() @@ -114,7 +111,6 @@ const ProviderAddedCard: FC = ({ { showCredential && ( onOpenModal(ConfigurationMethodEnum.predefinedModel)} provider={provider} /> ) @@ -159,9 +155,9 @@ const ProviderAddedCard: FC = ({ )} { configurationMethods.includes(ConfigurationMethodEnum.customizableModel) && isCurrentWorkspaceManager && ( - onOpenModal(ConfigurationMethodEnum.customizableModel)} - className='flex' + ) } @@ -174,7 +170,6 @@ const ProviderAddedCard: FC = ({ provider={provider} models={modelList} onCollapse={() => setCollapsed(true)} - onConfig={currentCustomConfigurationModelFixedFields => onOpenModal(ConfigurationMethodEnum.customizableModel, currentCustomConfigurationModelFixedFields)} onChange={(provider: string) => getModelList(provider)} /> ) diff --git a/web/app/components/header/account-setting/model-provider-page/provider-added-card/model-list-item.tsx b/web/app/components/header/account-setting/model-provider-page/provider-added-card/model-list-item.tsx index 8908d9a039..bcd4832443 100644 --- a/web/app/components/header/account-setting/model-provider-page/provider-added-card/model-list-item.tsx +++ b/web/app/components/header/account-setting/model-provider-page/provider-added-card/model-list-item.tsx @@ -1,31 +1,29 @@ import { memo, useCallback } from 'react' import { useTranslation } from 'react-i18next' import { useDebounceFn } from 'ahooks' -import type { CustomConfigurationModelFixedFields, ModelItem, ModelProvider } from '../declarations' -import { ConfigurationMethodEnum, ModelStatusEnum } from '../declarations' -import ModelBadge from '../model-badge' +import type { ModelItem, ModelProvider } from '../declarations' +import { ModelStatusEnum } from '../declarations' import ModelIcon from '../model-icon' import ModelName from '../model-name' import classNames from '@/utils/classnames' -import Button from '@/app/components/base/button' import { Balance } from '@/app/components/base/icons/src/vender/line/financeAndECommerce' -import { Settings01 } from '@/app/components/base/icons/src/vender/line/general' import Switch from '@/app/components/base/switch' import Tooltip from '@/app/components/base/tooltip' import { useProviderContext, useProviderContextSelector } from '@/context/provider-context' import { disableModel, enableModel } from '@/service/common' import { Plan } from '@/app/components/billing/type' import { useAppContext } from '@/context/app-context' +import { ConfigModel } from '../model-auth' +import Badge from '@/app/components/base/badge' export type ModelListItemProps = { model: ModelItem provider: ModelProvider isConfigurable: boolean - onConfig: (currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields) => void onModifyLoadBalancing?: (model: ModelItem) => void } -const ModelListItem = ({ model, provider, isConfigurable, onConfig, onModifyLoadBalancing }: ModelListItemProps) => { +const ModelListItem = ({ model, provider, isConfigurable, onModifyLoadBalancing }: ModelListItemProps) => { const { t } = useTranslation() const { plan } = useProviderContext() const modelLoadBalancingEnabled = useProviderContextSelector(state => state.modelLoadBalancingEnabled) @@ -46,7 +44,7 @@ const ModelListItem = ({ model, provider, isConfigurable, onConfig, onModifyLoad return (
- {modelLoadBalancingEnabled && !model.deprecated && model.load_balancing_enabled && ( - - - {t('common.modelProvider.loadBalancingHeadline')} - - )}
+ {modelLoadBalancingEnabled && !model.deprecated && model.load_balancing_enabled && !model.has_invalid_load_balancing_configs && ( + + + + )} { - model.fetch_from === ConfigurationMethodEnum.customizableModel - ? (isCurrentWorkspaceManager && ( - - )) - : (isCurrentWorkspaceManager && (modelLoadBalancingEnabled || plan.type === Plan.sandbox) && !model.deprecated && [ModelStatusEnum.active, ModelStatusEnum.disabled].includes(model.status)) - ? ( - - ) - : null + (isCurrentWorkspaceManager && (modelLoadBalancingEnabled || plan.type === Plan.sandbox) && !model.deprecated && [ModelStatusEnum.active, ModelStatusEnum.disabled].includes(model.status)) && ( + onModifyLoadBalancing?.(model)} + loadBalancingEnabled={model.load_balancing_enabled} + loadBalancingInvalid={model.has_invalid_load_balancing_configs} + credentialRemoved={model.status === ModelStatusEnum.credentialRemoved} + /> + ) } { model.deprecated diff --git a/web/app/components/header/account-setting/model-provider-page/provider-added-card/model-list.tsx b/web/app/components/header/account-setting/model-provider-page/provider-added-card/model-list.tsx index 699be6edda..8d902043ff 100644 --- a/web/app/components/header/account-setting/model-provider-page/provider-added-card/model-list.tsx +++ b/web/app/components/header/account-setting/model-provider-page/provider-added-card/model-list.tsx @@ -5,7 +5,7 @@ import { RiArrowRightSLine, } from '@remixicon/react' import type { - CustomConfigurationModelFixedFields, + Credential, ModelItem, ModelProvider, } from '../declarations' @@ -13,34 +13,33 @@ import { ConfigurationMethodEnum, } from '../declarations' // import Tab from './tab' -import AddModelButton from './add-model-button' import ModelListItem from './model-list-item' import { useModalContextSelector } from '@/context/modal-context' import { useAppContext } from '@/context/app-context' +import { AddCustomModel } from '@/app/components/header/account-setting/model-provider-page/model-auth' type ModelListProps = { provider: ModelProvider models: ModelItem[] onCollapse: () => void - onConfig: (currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields) => void onChange?: (provider: string) => void } const ModelList: FC = ({ provider, models, onCollapse, - onConfig, onChange, }) => { const { t } = useTranslation() const configurativeMethods = provider.configurate_methods.filter(method => method !== ConfigurationMethodEnum.fetchFromRemote) const { isCurrentWorkspaceManager } = useAppContext() const isConfigurable = configurativeMethods.includes(ConfigurationMethodEnum.customizableModel) - const setShowModelLoadBalancingModal = useModalContextSelector(state => state.setShowModelLoadBalancingModal) - const onModifyLoadBalancing = useCallback((model: ModelItem) => { + const onModifyLoadBalancing = useCallback((model: ModelItem, credential?: Credential) => { setShowModelLoadBalancingModal({ provider, + credential, + configurateMethod: model.fetch_from, model: model!, open: !!model, onClose: () => setShowModelLoadBalancingModal(null), @@ -65,17 +64,14 @@ const ModelList: FC = ({ - {/* { - isConfigurable && canSystemConfig && ( - - {}} /> - - ) - } */} { isConfigurable && isCurrentWorkspaceManager && (
- onConfig()} /> +
) } @@ -83,12 +79,11 @@ const ModelList: FC = ({ { models.map(model => ( diff --git a/web/app/components/header/account-setting/model-provider-page/provider-added-card/model-load-balancing-configs.tsx b/web/app/components/header/account-setting/model-provider-page/provider-added-card/model-load-balancing-configs.tsx index 1a3039659a..f92c188aa7 100644 --- a/web/app/components/header/account-setting/model-provider-page/provider-added-card/model-load-balancing-configs.tsx +++ b/web/app/components/header/account-setting/model-provider-page/provider-added-card/model-load-balancing-configs.tsx @@ -1,24 +1,35 @@ import type { Dispatch, SetStateAction } from 'react' -import { useCallback } from 'react' +import { useCallback, useMemo } from 'react' import { useTranslation } from 'react-i18next' import { RiDeleteBinLine, + RiEqualizer2Line, } from '@remixicon/react' -import type { ConfigurationMethodEnum, CustomConfigurationModelFixedFields, ModelLoadBalancingConfig, ModelLoadBalancingConfigEntry, ModelProvider } from '../declarations' +import type { + Credential, + CustomConfigurationModelFixedFields, + CustomModelCredential, + ModelCredential, + ModelLoadBalancingConfig, + ModelLoadBalancingConfigEntry, + ModelProvider, +} from '../declarations' +import { ConfigurationMethodEnum } from '../declarations' import Indicator from '../../../indicator' import CooldownTimer from './cooldown-timer' import classNames from '@/utils/classnames' import Tooltip from '@/app/components/base/tooltip' import Switch from '@/app/components/base/switch' import { Balance } from '@/app/components/base/icons/src/vender/line/financeAndECommerce' -import { Edit02, Plus02 } from '@/app/components/base/icons/src/vender/line/general' import { AlertTriangle } from '@/app/components/base/icons/src/vender/solid/alertsAndFeedback' -import { useModalContextSelector } from '@/context/modal-context' import UpgradeBtn from '@/app/components/billing/upgrade-btn' import s from '@/app/components/custom/style.module.css' import GridMask from '@/app/components/base/grid-mask' import { useProviderContextSelector } from '@/context/provider-context' import { IS_CE_EDITION } from '@/config' +import { AddCredentialInLoadBalancing } from '@/app/components/header/account-setting/model-provider-page/model-auth' +import { useModelModalHandler } from '@/app/components/header/account-setting/model-provider-page/hooks' +import Badge from '@/app/components/base/badge/index' export type ModelLoadBalancingConfigsProps = { draftConfig?: ModelLoadBalancingConfig @@ -28,19 +39,27 @@ export type ModelLoadBalancingConfigsProps = { currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields withSwitch?: boolean className?: string + modelCredential: ModelCredential + onUpdate?: () => void + model: CustomModelCredential } const ModelLoadBalancingConfigs = ({ draftConfig, setDraftConfig, provider, + model, configurationMethod, currentCustomConfigurationModelFixedFields, withSwitch = false, className, + modelCredential, + onUpdate, }: ModelLoadBalancingConfigsProps) => { const { t } = useTranslation() + const providerFormSchemaPredefined = configurationMethod === ConfigurationMethodEnum.predefinedModel const modelLoadBalancingEnabled = useProviderContextSelector(state => state.modelLoadBalancingEnabled) + const handleOpenModal = useModelModalHandler() const updateConfigEntry = useCallback( ( @@ -65,6 +84,21 @@ const ModelLoadBalancingConfigs = ({ [setDraftConfig], ) + const addConfigEntry = useCallback((credential: Credential) => { + setDraftConfig((prev: any) => { + if (!prev) + return prev + return { + ...prev, + configs: [...prev.configs, { + credential_id: credential.credential_id, + enabled: true, + name: credential.credential_name, + }], + } + }) + }, [setDraftConfig]) + const toggleModalBalancing = useCallback((enabled: boolean) => { if ((modelLoadBalancingEnabled || !enabled) && draftConfig) { setDraftConfig({ @@ -81,54 +115,6 @@ const ModelLoadBalancingConfigs = ({ })) }, [updateConfigEntry]) - const setShowModelLoadBalancingEntryModal = useModalContextSelector(state => state.setShowModelLoadBalancingEntryModal) - - const toggleEntryModal = useCallback((index?: number, entry?: ModelLoadBalancingConfigEntry) => { - setShowModelLoadBalancingEntryModal({ - payload: { - currentProvider: provider, - currentConfigurationMethod: configurationMethod, - currentCustomConfigurationModelFixedFields, - entry, - index, - }, - onSaveCallback: ({ entry: result }) => { - if (entry) { - // edit - setDraftConfig(prev => ({ - ...prev, - enabled: !!prev?.enabled, - configs: prev?.configs.map((config, i) => i === index ? result! : config) || [], - })) - } - else { - // add - setDraftConfig(prev => ({ - ...prev, - enabled: !!prev?.enabled, - configs: (prev?.configs || []).concat([{ ...result!, enabled: true }]), - })) - } - }, - onRemoveCallback: ({ index }) => { - if (index !== undefined && (draftConfig?.configs?.length ?? 0) > index) { - setDraftConfig(prev => ({ - ...prev, - enabled: !!prev?.enabled, - configs: prev?.configs.filter((_, i) => i !== index) || [], - })) - } - }, - }) - }, [ - configurationMethod, - currentCustomConfigurationModelFixedFields, - draftConfig?.configs?.length, - provider, - setDraftConfig, - setShowModelLoadBalancingEntryModal, - ]) - const clearCountdown = useCallback((index: number) => { updateConfigEntry(index, ({ ttl: _, ...entry }) => { return { @@ -138,6 +124,12 @@ const ModelLoadBalancingConfigs = ({ }) }, [updateConfigEntry]) + const validDraftConfigList = useMemo(() => { + if (!draftConfig) + return [] + return draftConfig.configs + }, [draftConfig]) + if (!draftConfig) return null @@ -181,8 +173,9 @@ const ModelLoadBalancingConfigs = ({
{draftConfig.enabled && (
- {draftConfig.configs.map((config, index) => { + {validDraftConfigList.map((config, index) => { const isProviderManaged = config.name === '__inherit__' + const credential = modelCredential.available_credentials.find(c => c.credential_id === config.credential_id) return (
@@ -200,54 +193,81 @@ const ModelLoadBalancingConfigs = ({
{isProviderManaged ? t('common.modelProvider.defaultConfig') : config.name}
- {isProviderManaged && ( - {t('common.modelProvider.providerManaged')} + {isProviderManaged && providerFormSchemaPredefined && ( + {t('common.modelProvider.providerManaged')} )} + { + credential?.from_enterprise && ( + Enterprise + ) + }
{!isProviderManaged && ( <>
- toggleEntryModal(index, config)} - > - - + { + config.credential_id && !credential?.not_allowed_to_use && !credential?.from_enterprise && ( + { + handleOpenModal( + provider, + configurationMethod, + currentCustomConfigurationModelFixedFields, + configurationMethod === ConfigurationMethodEnum.customizableModel, + (config.credential_id && config.name) ? { + credential_id: config.credential_id, + credential_name: config.name, + } : undefined, + model, + ) + }} + > + + + ) + } updateConfigEntry(index, () => undefined)} > -
)} - toggleConfigEntryEnabled(index, value)} - /> + { + (config.credential_id || config.name === '__inherit__') && ( + <> + + toggleConfigEntryEnabled(index, value)} + disabled={credential?.not_allowed_to_use} + /> + + ) + }
) })} - -
toggleEntryModal()} - > -
- {t('common.modelProvider.addConfig')} -
-
+
)} { - draftConfig.enabled && draftConfig.configs.length < 2 && ( -
+ draftConfig.enabled && validDraftConfigList.length < 2 && ( +
{t('common.modelProvider.loadBalancingLeastKeyWarning')}
diff --git a/web/app/components/header/account-setting/model-provider-page/provider-added-card/model-load-balancing-modal.tsx b/web/app/components/header/account-setting/model-provider-page/provider-added-card/model-load-balancing-modal.tsx index 9fb07401f7..1d6db30c4c 100644 --- a/web/app/components/header/account-setting/model-provider-page/provider-added-card/model-load-balancing-modal.tsx +++ b/web/app/components/header/account-setting/model-provider-page/provider-added-card/model-load-balancing-modal.tsx @@ -1,40 +1,69 @@ import { memo, useCallback, useEffect, useMemo, useState } from 'react' import { useTranslation } from 'react-i18next' -import useSWR from 'swr' -import type { ModelItem, ModelLoadBalancingConfig, ModelLoadBalancingConfigEntry, ModelProvider } from '../declarations' -import { FormTypeEnum } from '../declarations' +import type { + Credential, + ModelItem, + ModelLoadBalancingConfig, + ModelLoadBalancingConfigEntry, + ModelProvider, +} from '../declarations' +import { + ConfigurationMethodEnum, + FormTypeEnum, +} from '../declarations' import ModelIcon from '../model-icon' import ModelName from '../model-name' -import { savePredefinedLoadBalancingConfig } from '../utils' import ModelLoadBalancingConfigs from './model-load-balancing-configs' import classNames from '@/utils/classnames' import Modal from '@/app/components/base/modal' import Button from '@/app/components/base/button' -import { fetchModelLoadBalancingConfig } from '@/service/common' import Loading from '@/app/components/base/loading' import { useToastContext } from '@/app/components/base/toast' +import { SwitchCredentialInLoadBalancing } from '@/app/components/header/account-setting/model-provider-page/model-auth' +import { + useGetModelCredential, + useUpdateModelLoadBalancingConfig, +} from '@/service/use-models' export type ModelLoadBalancingModalProps = { provider: ModelProvider + configurateMethod: ConfigurationMethodEnum model: ModelItem + credential?: Credential open?: boolean onClose?: () => void onSave?: (provider: string) => void } // model balancing config modal -const ModelLoadBalancingModal = ({ provider, model, open = false, onClose, onSave }: ModelLoadBalancingModalProps) => { +const ModelLoadBalancingModal = ({ + provider, + configurateMethod, + model, + credential, + open = false, + onClose, + onSave, +}: ModelLoadBalancingModalProps) => { const { t } = useTranslation() const { notify } = useToastContext() const [loading, setLoading] = useState(false) - - const { data, mutate } = useSWR( - `/workspaces/current/model-providers/${provider.provider}/models/credentials?model=${model.model}&model_type=${model.model_type}`, - fetchModelLoadBalancingConfig, - ) - - const originalConfig = data?.load_balancing + const providerFormSchemaPredefined = configurateMethod === ConfigurationMethodEnum.predefinedModel + const configFrom = providerFormSchemaPredefined ? 'predefined-model' : 'custom-model' + const { + isLoading, + data, + refetch, + } = useGetModelCredential(true, provider.provider, credential?.credential_id, model.model, model.model_type, configFrom) + const modelCredential = data + const { + load_balancing, + current_credential_id, + available_credentials, + current_credential_name, + } = modelCredential ?? {} + const originalConfig = load_balancing const [draftConfig, setDraftConfig] = useState() const originalConfigMap = useMemo(() => { if (!originalConfig) @@ -60,10 +89,17 @@ const ModelLoadBalancingModal = ({ provider, model, open = false, onClose, onSav }, [draftConfig]) const extendedSecretFormSchemas = useMemo( - () => provider.provider_credential_schema.credential_form_schemas.filter( - ({ type }) => type === FormTypeEnum.secretInput, - ), - [provider.provider_credential_schema.credential_form_schemas], + () => { + if (providerFormSchemaPredefined) { + return provider?.provider_credential_schema?.credential_form_schemas?.filter( + ({ type }) => type === FormTypeEnum.secretInput, + ) ?? [] + } + return provider?.model_credential_schema?.credential_form_schemas?.filter( + ({ type }) => type === FormTypeEnum.secretInput, + ) ?? [] + }, + [provider?.model_credential_schema?.credential_form_schemas, provider?.provider_credential_schema?.credential_form_schemas, providerFormSchemaPredefined], ) const encodeConfigEntrySecretValues = useCallback((entry: ModelLoadBalancingConfigEntry) => { @@ -75,25 +111,34 @@ const ModelLoadBalancingModal = ({ provider, model, open = false, onClose, onSav return result }, [extendedSecretFormSchemas, originalConfigMap]) + const { mutateAsync: updateModelLoadBalancingConfig } = useUpdateModelLoadBalancingConfig(provider.provider) + const initialCustomModelCredential = useMemo(() => { + if (!current_credential_id) + return undefined + return { + credential_id: current_credential_id, + credential_name: current_credential_name, + } + }, [current_credential_id, current_credential_name]) + const [customModelCredential, setCustomModelCredential] = useState(initialCustomModelCredential) const handleSave = async () => { try { setLoading(true) - const res = await savePredefinedLoadBalancingConfig( - provider.provider, - ({ - ...(data?.credentials ?? {}), - __model_type: model.model_type, - __model_name: model.model, - }), + const res = await updateModelLoadBalancingConfig( { - ...draftConfig, - enabled: Boolean(draftConfig?.enabled), - configs: draftConfig!.configs.map(encodeConfigEntrySecretValues), + credential_id: customModelCredential?.credential_id || current_credential_id, + config_from: configFrom, + model: model.model, + model_type: model.model_type, + load_balancing: { + ...draftConfig, + configs: draftConfig!.configs.map(encodeConfigEntrySecretValues), + enabled: Boolean(draftConfig?.enabled), + }, }, ) if (res.result === 'success') { notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') }) - mutate() onSave?.(provider.provider) onClose?.() } @@ -110,7 +155,11 @@ const ModelLoadBalancingModal = ({ provider, model, open = false, onClose, onSav className='w-[640px] max-w-none px-8 pt-8' title={
-
{t('common.modelProvider.configLoadBalancing')}
+
{ + draftConfig?.enabled + ? t('common.modelProvider.auth.configLoadBalancing') + : t('common.modelProvider.auth.configModel') + }
{Boolean(model) && (
-
{t('common.modelProvider.providerManaged')}
-
{t('common.modelProvider.providerManagedDescription')}
+
{ + providerFormSchemaPredefined + ? t('common.modelProvider.auth.providerManaged') + : t('common.modelProvider.auth.specifyModelCredential') + }
+
{ + providerFormSchemaPredefined + ? t('common.modelProvider.auth.providerManagedTip') + : t('common.modelProvider.auth.specifyModelCredentialTip') + }
+ { + !providerFormSchemaPredefined && ( + + ) + }
- - + { + modelCredential && ( + + ) + }
@@ -176,6 +253,7 @@ const ModelLoadBalancingModal = ({ provider, model, open = false, onClose, onSav disabled={ loading || (draftConfig?.enabled && (draftConfig?.configs.filter(config => config.enabled).length ?? 0) < 2) + || isLoading } >{t('common.operation.save')}
diff --git a/web/app/components/header/account-setting/model-provider-page/utils.ts b/web/app/components/header/account-setting/model-provider-page/utils.ts index 9056afe69b..f577a536dc 100644 --- a/web/app/components/header/account-setting/model-provider-page/utils.ts +++ b/web/app/components/header/account-setting/model-provider-page/utils.ts @@ -1,6 +1,5 @@ import { ValidatedStatus } from '../key-validator/declarations' import type { - CredentialFormSchemaRadio, CredentialFormSchemaTextInput, FormValue, ModelLoadBalancingConfig, @@ -82,12 +81,14 @@ export const saveCredentials = async (predefined: boolean, provider: string, v: let body, url if (predefined) { + const { __authorization_name__, ...rest } = v body = { config_from: ConfigurationMethodEnum.predefinedModel, - credentials: v, + credentials: rest, load_balancing: loadBalancing, + name: __authorization_name__, } - url = `/workspaces/current/model-providers/${provider}` + url = `/workspaces/current/model-providers/${provider}/credentials` } else { const { __model_name, __model_type, ...credentials } = v @@ -117,12 +118,17 @@ export const savePredefinedLoadBalancingConfig = async (provider: string, v: For return setModelProvider({ url, body }) } -export const removeCredentials = async (predefined: boolean, provider: string, v: FormValue) => { +export const removeCredentials = async (predefined: boolean, provider: string, v: FormValue, credentialId?: string) => { let url = '' let body if (predefined) { - url = `/workspaces/current/model-providers/${provider}` + url = `/workspaces/current/model-providers/${provider}/credentials` + if (credentialId) { + body = { + credential_id: credentialId, + } + } } else { if (v) { @@ -174,7 +180,7 @@ export const genModelTypeFormSchema = (modelTypes: ModelTypeEnum[]) => { show_on: [], } }), - } as CredentialFormSchemaRadio + } as any } export const genModelNameFormSchema = (model?: Pick) => { @@ -191,5 +197,5 @@ export const genModelNameFormSchema = (model?: Pick void + notAllowCustomCredential?: boolean } const Authorize = ({ pluginPayload, @@ -26,6 +29,7 @@ const Authorize = ({ canApiKey, disabled, onUpdate, + notAllowCustomCredential, }: AuthorizeProps) => { const { t } = useTranslation() const oAuthButtonProps: AddOAuthButtonProps = useMemo(() => { @@ -62,18 +66,54 @@ const Authorize = ({ } }, [canOAuth, theme, pluginPayload, t]) + const OAuthButton = useMemo(() => { + const Item = ( +
+ +
+ ) + + if (notAllowCustomCredential) { + return ( + + {Item} + + ) + } + return Item + }, [notAllowCustomCredential, oAuthButtonProps, disabled, onUpdate, t]) + + const ApiKeyButton = useMemo(() => { + const Item = ( +
+ +
+ ) + + if (notAllowCustomCredential) { + return ( + + {Item} + + ) + } + return Item + }, [notAllowCustomCredential, apiKeyButtonProps, disabled, onUpdate, t]) + return ( <>
{ canOAuth && ( -
- -
+ OAuthButton ) } { @@ -87,13 +127,7 @@ const Authorize = ({ } { canApiKey && ( -
- -
+ ApiKeyButton ) }
diff --git a/web/app/components/plugins/plugin-auth/authorized-in-node.tsx b/web/app/components/plugins/plugin-auth/authorized-in-node.tsx index 79189fa585..79eef66451 100644 --- a/web/app/components/plugins/plugin-auth/authorized-in-node.tsx +++ b/web/app/components/plugins/plugin-auth/authorized-in-node.tsx @@ -35,10 +35,13 @@ const AuthorizedInNode = ({ credentials, disabled, invalidPluginCredentialInfo, + notAllowCustomCredential, } = usePluginAuth(pluginPayload, isOpen || !!credentialId) const renderTrigger = useCallback((open?: boolean) => { let label = '' let removed = false + let unavailable = false + let color = 'green' if (!credentialId) { label = t('plugin.auth.workspaceDefault') } @@ -46,6 +49,12 @@ const AuthorizedInNode = ({ const credential = credentials.find(c => c.id === credentialId) label = credential ? credential.name : t('plugin.auth.authRemoved') removed = !credential + unavailable = !!credential?.not_allowed_to_use && !credential?.from_enterprise + + if (removed) + color = 'red' + else if (unavailable) + color = 'gray' } return ( ) @@ -294,18 +302,24 @@ const Authorized = ({ ) }
-
-
- -
+ { + !notAllowCustomCredential && ( + <> +
+
+ +
+ + ) + }
diff --git a/web/app/components/plugins/plugin-auth/authorized/item.tsx b/web/app/components/plugins/plugin-auth/authorized/item.tsx index 5508bcc324..f8a1033de7 100644 --- a/web/app/components/plugins/plugin-auth/authorized/item.tsx +++ b/web/app/components/plugins/plugin-auth/authorized/item.tsx @@ -61,14 +61,19 @@ const Item = ({ return !(disableRename && disableEdit && disableDelete && disableSetDefault) }, [disableRename, disableEdit, disableDelete, disableSetDefault]) - return ( + const CredentialItem = (
onItemClick?.(credential.id === '__workspace_default__' ? '' : credential.id)} + onClick={() => { + if (credential.not_allowed_to_use || disabled) + return + onItemClick?.(credential.id === '__workspace_default__' ? '' : credential.id) + }} > { renaming && ( @@ -121,7 +126,10 @@ const Item = ({
) } - +
) } + { + credential.from_enterprise && ( + + Enterprise + + ) + } { showAction && !renaming && (
{ - !credential.is_default && !disableSetDefault && ( + !credential.is_default && !disableSetDefault && !credential.not_allowed_to_use && ( ) @@ -93,6 +104,7 @@ const PluginAuthInAgent = ({ canApiKey={canApiKey} disabled={disabled} onUpdate={invalidPluginCredentialInfo} + notAllowCustomCredential={notAllowCustomCredential} /> ) } @@ -113,6 +125,7 @@ const PluginAuthInAgent = ({ onOpenChange={setIsOpen} selectedCredentialId={credentialId || '__workspace_default__'} onUpdate={invalidPluginCredentialInfo} + notAllowCustomCredential={notAllowCustomCredential} /> ) } diff --git a/web/app/components/plugins/plugin-auth/plugin-auth.tsx b/web/app/components/plugins/plugin-auth/plugin-auth.tsx index 76b405a750..a9bb287cdf 100644 --- a/web/app/components/plugins/plugin-auth/plugin-auth.tsx +++ b/web/app/components/plugins/plugin-auth/plugin-auth.tsx @@ -22,6 +22,7 @@ const PluginAuth = ({ credentials, disabled, invalidPluginCredentialInfo, + notAllowCustomCredential, } = usePluginAuth(pluginPayload, !!pluginPayload.provider) return ( @@ -34,6 +35,7 @@ const PluginAuth = ({ canApiKey={canApiKey} disabled={disabled} onUpdate={invalidPluginCredentialInfo} + notAllowCustomCredential={notAllowCustomCredential} /> ) } @@ -46,6 +48,7 @@ const PluginAuth = ({ canApiKey={canApiKey} disabled={disabled} onUpdate={invalidPluginCredentialInfo} + notAllowCustomCredential={notAllowCustomCredential} /> ) } diff --git a/web/app/components/plugins/plugin-auth/types.ts b/web/app/components/plugins/plugin-auth/types.ts index ad41733bde..1fb2c1a531 100644 --- a/web/app/components/plugins/plugin-auth/types.ts +++ b/web/app/components/plugins/plugin-auth/types.ts @@ -22,4 +22,6 @@ export type Credential = { is_default: boolean credentials?: Record isWorkspaceDefault?: boolean + from_enterprise?: boolean + not_allowed_to_use?: boolean } diff --git a/web/context/modal-context.tsx b/web/context/modal-context.tsx index f1e5bb044f..dac9ef30d5 100644 --- a/web/context/modal-context.tsx +++ b/web/context/modal-context.tsx @@ -6,7 +6,9 @@ import { createContext, useContext, useContextSelector } from 'use-context-selec import { useRouter, useSearchParams } from 'next/navigation' import type { ConfigurationMethodEnum, + Credential, CustomConfigurationModelFixedFields, + CustomModel, ModelLoadBalancingConfigEntry, ModelProvider, } from '@/app/components/header/account-setting/model-provider-page/declarations' @@ -55,9 +57,6 @@ const ExternalAPIModal = dynamic(() => import('@/app/components/datasets/externa const ModelLoadBalancingModal = dynamic(() => import('@/app/components/header/account-setting/model-provider-page/provider-added-card/model-load-balancing-modal'), { ssr: false, }) -const ModelLoadBalancingEntryModal = dynamic(() => import('@/app/components/header/account-setting/model-provider-page/model-modal/model-load-balancing-entry-modal'), { - ssr: false, -}) const OpeningSettingModal = dynamic(() => import('@/app/components/base/features/new-feature-panel/conversation-opener/modal'), { ssr: false, }) @@ -84,6 +83,9 @@ export type ModelModalType = { currentProvider: ModelProvider currentConfigurationMethod: ConfigurationMethodEnum currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields + isModelCredential?: boolean + credential?: Credential + model?: CustomModel } export type LoadBalancingEntryModalType = ModelModalType & { entry?: ModelLoadBalancingConfigEntry @@ -100,7 +102,6 @@ export type ModalContextState = { setShowModelModal: Dispatch | null>> setShowExternalKnowledgeAPIModal: Dispatch | null>> setShowModelLoadBalancingModal: Dispatch> - setShowModelLoadBalancingEntryModal: Dispatch | null>> setShowOpeningModal: Dispatch({ setShowModelModal: noop, setShowExternalKnowledgeAPIModal: noop, setShowModelLoadBalancingModal: noop, - setShowModelLoadBalancingEntryModal: noop, setShowOpeningModal: noop, setShowUpdatePluginModal: noop, setShowEducationExpireNoticeModal: noop, @@ -145,7 +145,6 @@ export const ModalContextProvider = ({ const [showModelModal, setShowModelModal] = useState | null>(null) const [showExternalKnowledgeAPIModal, setShowExternalKnowledgeAPIModal] = useState | null>(null) const [showModelLoadBalancingModal, setShowModelLoadBalancingModal] = useState(null) - const [showModelLoadBalancingEntryModal, setShowModelLoadBalancingEntryModal] = useState | null>(null) const [showOpeningModal, setShowOpeningModal] = useState { - showModelLoadBalancingEntryModal?.onCancelCallback?.() - setShowModelLoadBalancingEntryModal(null) - }, [showModelLoadBalancingEntryModal]) - const handleCancelOpeningModal = useCallback(() => { setShowOpeningModal(null) if (showOpeningModal?.onCancelCallback) showOpeningModal.onCancelCallback() }, [showOpeningModal]) - const handleSaveModelLoadBalancingEntryModal = useCallback((entry: ModelLoadBalancingConfigEntry) => { - showModelLoadBalancingEntryModal?.onSaveCallback?.({ - ...showModelLoadBalancingEntryModal.payload, - entry, - }) - setShowModelLoadBalancingEntryModal(null) - }, [showModelLoadBalancingEntryModal]) - - const handleRemoveModelLoadBalancingEntry = useCallback(() => { - showModelLoadBalancingEntryModal?.onRemoveCallback?.(showModelLoadBalancingEntryModal.payload) - setShowModelLoadBalancingEntryModal(null) - }, [showModelLoadBalancingEntryModal]) - const handleSaveApiBasedExtension = (newApiBasedExtension: ApiBasedExtension) => { if (showApiBasedExtensionModal?.onSaveCallback) showApiBasedExtensionModal.onSaveCallback(newApiBasedExtension) @@ -277,7 +258,6 @@ export const ModalContextProvider = ({ setShowModelModal, setShowExternalKnowledgeAPIModal, setShowModelLoadBalancingModal, - setShowModelLoadBalancingEntryModal, setShowOpeningModal, setShowUpdatePluginModal, setShowEducationExpireNoticeModal, @@ -346,6 +326,9 @@ export const ModalContextProvider = ({ provider={showModelModal.payload.currentProvider} configurateMethod={showModelModal.payload.currentConfigurationMethod} currentCustomConfigurationModelFixedFields={showModelModal.payload.currentCustomConfigurationModelFixedFields} + isModelCredential={showModelModal.payload.isModelCredential} + credential={showModelModal.payload.credential} + model={showModelModal.payload.model} onCancel={handleCancelModelModal} onSave={handleSaveModelModal} /> @@ -368,19 +351,6 @@ export const ModalContextProvider = ({ ) } - { - !!showModelLoadBalancingEntryModal && ( - - ) - } {showOpeningModal && ( = ({ children }) => { // Compute shareCode directly const shareCode = getShareCodeFromRedirectUrl(redirectUrlParam) || getShareCodeFromPathname(pathname) - updateShareCode(shareCode) + useEffect(() => { + updateShareCode(shareCode) + }, [shareCode, updateShareCode]) const { isFetching, data: accessModeResult } = useGetWebAppAccessModeByCode(shareCode) const [isFetchingAccessToken, setIsFetchingAccessToken] = useState(false) diff --git a/web/i18n/de-DE/common.ts b/web/i18n/de-DE/common.ts index d8e010ce0f..57abe75f87 100644 --- a/web/i18n/de-DE/common.ts +++ b/web/i18n/de-DE/common.ts @@ -60,6 +60,7 @@ const translation = { format: 'Format', selectAll: 'Alles auswählen', deSelectAll: 'Alle abwählen', + config: 'Konfiguration', }, placeholder: { input: 'Bitte eingeben', @@ -468,6 +469,28 @@ const translation = { installProvider: 'Installieren von Modellanbietern', toBeConfigured: 'Zu konfigurieren', emptyProviderTitle: 'Modellanbieter nicht eingerichtet', + auth: { + apiKeyModal: { + addModel: 'Modell hinzufügen', + title: 'API-Schlüssel-Autorisierungskonfiguration', + desc: 'Nachdem die Anmeldeinformationen konfiguriert wurden, können alle Mitglieder des Arbeitsbereichs dieses Modell beim Orchestrieren von Anwendungen verwenden.', + }, + specifyModelCredential: 'Angeben von Modellanmeldeinformationen', + addNewModel: 'Neues Modell hinzufügen', + addCredential: 'Anmeldeinformationen hinzufügen', + providerManaged: 'Anbieter verwaltet', + addApiKey: 'API-Schlüssel hinzufügen', + apiKeys: 'API-Schlüssel', + unAuthorized: 'Unbefugt', + authorizationError: 'Autorisierungsfehler', + modelCredentials: 'Modellanmeldeinformationen', + configModel: 'Konfigurationsmodell', + authRemoved: 'Die Authentifizierung wurde entfernt.', + addModelCredential: 'Modellberechtigungen hinzufügen', + providerManagedTip: 'Die aktuelle Konfiguration wird vom Anbieter gehostet.', + configLoadBalancing: 'Konfiguration Lastenverteilung', + specifyModelCredentialTip: 'Verwenden Sie ein konfiguriertes Modellzugang.', + }, }, dataSource: { add: 'Eine Datenquelle hinzufügen', diff --git a/web/i18n/de-DE/plugin.ts b/web/i18n/de-DE/plugin.ts index aa136528e2..b2617eae38 100644 --- a/web/i18n/de-DE/plugin.ts +++ b/web/i18n/de-DE/plugin.ts @@ -246,6 +246,9 @@ const translation = { clientInfo: 'Da keine System-Client-Geheimnisse für diesen Tool-Anbieter gefunden wurden, ist eine manuelle Einrichtung erforderlich. Bitte verwenden Sie für redirect_uri', useApiAuthDesc: 'Nachdem die Anmeldeinformationen konfiguriert wurden, können alle Mitglieder des Arbeitsbereichs dieses Tool beim Orchestrieren von Anwendungen verwenden.', authRemoved: 'Die Authentifizierung wurde entfernt.', + unavailable: 'Nicht verfügbar', + credentialUnavailable: 'Anmeldeinformationen derzeit nicht verfügbar. Bitte kontaktieren Sie den Administrator.', + customCredentialUnavailable: 'Benutzerdefinierte Anmeldeinformationen derzeit nicht verfügbar', }, deprecated: 'Abgelehnt', autoUpdate: { diff --git a/web/i18n/en-US/common.ts b/web/i18n/en-US/common.ts index cd585dcf8f..8cbc05d1f4 100644 --- a/web/i18n/en-US/common.ts +++ b/web/i18n/en-US/common.ts @@ -40,6 +40,7 @@ const translation = { deleteApp: 'Delete App', settings: 'Settings', setup: 'Setup', + config: 'Config', getForFree: 'Get for free', reload: 'Reload', ok: 'OK', @@ -466,7 +467,7 @@ const translation = { loadPresets: 'Load Presets', parameters: 'PARAMETERS', loadBalancing: 'Load balancing', - loadBalancingDescription: 'Reduce pressure with multiple sets of credentials.', + loadBalancingDescription: 'Configure multiple credentials for the model and invoke them automatically. ', loadBalancingHeadline: 'Load Balancing', configLoadBalancing: 'Config Load Balancing', modelHasBeenDeprecated: 'This model has been deprecated', @@ -486,6 +487,28 @@ const translation = { discoverMore: 'Discover more in ', emptyProviderTitle: 'Model provider not set up', emptyProviderTip: 'Please install a model provider first.', + auth: { + unAuthorized: 'Unauthorized', + authRemoved: 'Auth removed', + apiKeys: 'API Keys', + addApiKey: 'Add API Key', + addNewModel: 'Add new model', + addCredential: 'Add credential', + addModelCredential: 'Add model credential', + modelCredentials: 'Model credentials', + configModel: 'Config model', + configLoadBalancing: 'Config Load Balancing', + authorizationError: 'Authorization error', + specifyModelCredential: 'Specify model credential', + specifyModelCredentialTip: 'Use a configured model credential.', + providerManaged: 'Provider managed', + providerManagedTip: 'The current configuration is hosted by the provider.', + apiKeyModal: { + title: 'API Key Authorization Configuration', + desc: 'After configuring credentials, all members within the workspace can use this model when orchestrating applications.', + addModel: 'Add model', + }, + }, }, dataSource: { add: 'Add a data source', diff --git a/web/i18n/en-US/dataset-documents.ts b/web/i18n/en-US/dataset-documents.ts index 292270f931..5fdf185350 100644 --- a/web/i18n/en-US/dataset-documents.ts +++ b/web/i18n/en-US/dataset-documents.ts @@ -32,7 +32,6 @@ const translation = { sync: 'Sync', pause: 'Pause', resume: 'Resume', - download: 'Download File', }, index: { enable: 'Enable', diff --git a/web/i18n/en-US/plugin.ts b/web/i18n/en-US/plugin.ts index aa127eaf13..85bbd44bd5 100644 --- a/web/i18n/en-US/plugin.ts +++ b/web/i18n/en-US/plugin.ts @@ -297,6 +297,9 @@ const translation = { authRemoved: 'Auth removed', clientInfo: 'As no system client secrets found for this tool provider, setup it manually is required, for redirect_uri, please use', oauthClient: 'OAuth Client', + credentialUnavailable: 'Credentials currently unavailable. Please contact admin.', + customCredentialUnavailable: 'Custom credentials currently unavailable', + unavailable: 'Unavailable', }, } diff --git a/web/i18n/es-ES/common.ts b/web/i18n/es-ES/common.ts index 9cd576b21b..a77705ecf2 100644 --- a/web/i18n/es-ES/common.ts +++ b/web/i18n/es-ES/common.ts @@ -60,6 +60,7 @@ const translation = { format: 'Formato', deSelectAll: 'Deseleccionar todo', selectAll: 'Seleccionar todo', + config: 'Config', }, errorMsg: { fieldRequired: '{{field}} es requerido', @@ -472,6 +473,28 @@ const translation = { emptyProviderTip: 'Instale primero un proveedor de modelos.', installProvider: 'Instalación de proveedores de modelos', emptyProviderTitle: 'Proveedor de modelos no configurado', + auth: { + apiKeyModal: { + addModel: 'Agregar modelo', + title: 'Configuración de Autorización de Clave API', + desc: 'Después de configurar las credenciales, todos los miembros dentro del espacio de trabajo pueden usar este modelo al orquestar aplicaciones.', + }, + configModel: 'Modelo de configuración', + authorizationError: 'Error de autorización', + specifyModelCredential: 'Especificar las credenciales del modelo', + addModelCredential: 'Agregar credenciales del modelo', + authRemoved: 'Autorización retirada', + unAuthorized: 'No autorizado', + addApiKey: 'Agregar clave API', + apiKeys: 'Claves de API', + providerManagedTip: 'La configuración actual es hospedada por el proveedor.', + providerManaged: 'Proveedor gestionado', + specifyModelCredentialTip: 'Utiliza una credencial de modelo configurada.', + addNewModel: 'Agregar nuevo modelo', + modelCredentials: 'Credenciales del modelo', + addCredential: 'Agregar credencial', + configLoadBalancing: 'Configuración de balanceo de carga', + }, }, dataSource: { add: 'Agregar una fuente de datos', diff --git a/web/i18n/es-ES/plugin.ts b/web/i18n/es-ES/plugin.ts index e937db7a02..9e952a1838 100644 --- a/web/i18n/es-ES/plugin.ts +++ b/web/i18n/es-ES/plugin.ts @@ -246,6 +246,9 @@ const translation = { clientInfo: 'Como no se encontraron secretos de cliente del sistema para este proveedor de herramientas, se requiere configurarlo manualmente. Para redirect_uri, por favor utiliza', oauthClientSettings: 'Configuración del cliente OAuth', default: 'Predeterminado', + customCredentialUnavailable: 'Las credenciales personalizadas no están disponibles actualmente.', + unavailable: 'No disponible', + credentialUnavailable: 'Credenciales actualmente no disponibles. Por favor, contacte al administrador.', }, deprecated: 'Obsoleto', autoUpdate: { diff --git a/web/i18n/fa-IR/common.ts b/web/i18n/fa-IR/common.ts index c195a0a959..5ca5468ebf 100644 --- a/web/i18n/fa-IR/common.ts +++ b/web/i18n/fa-IR/common.ts @@ -60,6 +60,7 @@ const translation = { downloadSuccess: 'دانلود کامل شد.', selectAll: 'انتخاب همه', deSelectAll: 'همه را انتخاب نکنید', + config: 'تنظیمات', }, errorMsg: { fieldRequired: '{{field}} الزامی است', @@ -473,6 +474,28 @@ const translation = { installProvider: 'نصب ارائه دهندگان مدل', discoverMore: 'اطلاعات بیشتر در', emptyProviderTip: 'لطفا ابتدا یک ارائه دهنده مدل نصب کنید.', + auth: { + apiKeyModal: { + title: 'پیکربندی مجوز کلید API', + addModel: 'مدل اضافه کنید', + desc: 'پس از پیکربندی اعتبارنامه‌ها، تمامی اعضای درون فضای کاری می‌توانند از این مدل هنگام نظم‌دهی به برنامه‌ها استفاده کنند.', + }, + authorizationError: 'خطای مجوز', + unAuthorized: 'بدون مجوز', + configModel: 'مدل پیکربندی', + apiKeys: 'کلیدهای API', + addCredential: 'مدرک اضافه کنید', + addNewModel: 'مدل جدید اضافه کن', + addApiKey: 'کلید API را اضافه کنید', + authRemoved: 'منبع حذف شد', + configLoadBalancing: 'پیکربندی بارگذاری متوازن', + specifyModelCredential: 'مدل اعتبارنامه را مشخص کنید', + providerManaged: 'مدیریت شده توسط ارائه‌دهنده', + addModelCredential: 'مدرک مدل را اضافه کنید', + specifyModelCredentialTip: 'از اعتبارنامه مدل پیکربندی شده استفاده کنید.', + providerManagedTip: 'تنظیمات فعلی توسط ارائه‌دهنده میزبانی می‌شود.', + modelCredentials: 'مدل اعتبارنامه', + }, }, dataSource: { add: 'افزودن منبع داده', diff --git a/web/i18n/fa-IR/plugin.ts b/web/i18n/fa-IR/plugin.ts index 1ba3a714a3..2636aa7192 100644 --- a/web/i18n/fa-IR/plugin.ts +++ b/web/i18n/fa-IR/plugin.ts @@ -246,6 +246,9 @@ const translation = { oauthClientSettings: 'تنظیمات کلاینت اوتور', clientInfo: 'از آنجایی که هیچ راز مشتری سیستم برای این ارائه‌دهنده ابزار پیدا نشد، تنظیم دستی آن ضروری است، لطفاً برای redirect_uri از', useApiAuthDesc: 'پس از پیکربندی اعتبارنامه‌ها، تمامی اعضای درون فضای کاری می‌توانند از این ابزار هنگام نظم‌دهی به برنامه‌ها استفاده کنند.', + unavailable: 'در دسترس نیست', + credentialUnavailable: 'دسترسی به مدارک در حال حاضر امکان‌پذیر نیست. لطفاً با مدیر تماس بگیرید.', + customCredentialUnavailable: 'اعتبارنامه‌های سفارشی در حال حاضر در دسترس نیستند', }, deprecated: 'منسوخ شده', autoUpdate: { diff --git a/web/i18n/fr-FR/common.ts b/web/i18n/fr-FR/common.ts index 5bd262fae8..ac6f4c025f 100644 --- a/web/i18n/fr-FR/common.ts +++ b/web/i18n/fr-FR/common.ts @@ -60,6 +60,7 @@ const translation = { downloadSuccess: 'Téléchargement terminé.', deSelectAll: 'Désélectionner tout', selectAll: 'Sélectionner tout', + config: 'Config', }, placeholder: { input: 'Veuillez entrer', @@ -469,6 +470,28 @@ const translation = { installProvider: 'Installer des fournisseurs de modèles', discoverMore: 'Découvrez-en plus dans', emptyProviderTip: 'Veuillez d’abord installer un fournisseur de modèles.', + auth: { + apiKeyModal: { + addModel: 'Ajouter un modèle', + title: 'Configuration de l\'autorisation de clé API', + desc: 'Après avoir configuré les identifiants, tous les membres de l\'espace de travail peuvent utiliser ce modèle lors de l\'orchestration des applications.', + }, + addModelCredential: 'Ajouter des informations d’identification de modèle', + configModel: 'Configurer le modèle', + addNewModel: 'Ajouter un nouveau modèle', + apiKeys: 'Clés API', + providerManaged: 'Fournisseur géré', + configLoadBalancing: 'Configuration de l\'équilibrage de charge', + modelCredentials: 'Informations d\'identification du modèle', + addApiKey: 'Ajouter une clé API', + specifyModelCredential: 'Spécifiez les identifiants du modèle', + authorizationError: 'Erreur d\'autorisation', + authRemoved: 'Autorisation retirée', + addCredential: 'Ajouter un identifiant', + unAuthorized: 'Non autorisé', + specifyModelCredentialTip: 'Utilisez un identifiant de modèle configuré.', + providerManagedTip: 'La configuration actuelle est hébergée par le fournisseur.', + }, }, dataSource: { add: 'Ajouter une source de données', diff --git a/web/i18n/fr-FR/plugin.ts b/web/i18n/fr-FR/plugin.ts index ae6e8c068b..b0ecab7689 100644 --- a/web/i18n/fr-FR/plugin.ts +++ b/web/i18n/fr-FR/plugin.ts @@ -246,6 +246,9 @@ const translation = { setDefault: 'Définir comme par défaut', authorization: 'Autorisation', useApi: 'Utilisez la clé API', + customCredentialUnavailable: 'Les identifiants personnalisés ne sont actuellement pas disponibles.', + credentialUnavailable: 'Les informations d\'identification ne sont actuellement pas disponibles. Veuillez contacter l\'administrateur.', + unavailable: 'Non disponible', }, deprecated: 'Obsolète', autoUpdate: { diff --git a/web/i18n/hi-IN/common.ts b/web/i18n/hi-IN/common.ts index 6b84950b74..eea8168f43 100644 --- a/web/i18n/hi-IN/common.ts +++ b/web/i18n/hi-IN/common.ts @@ -60,6 +60,7 @@ const translation = { format: 'फॉर्मेट', selectAll: 'सभी चुनें', deSelectAll: 'सभी चयन हटाएँ', + config: 'कॉन्फ़िगरेशन', }, errorMsg: { fieldRequired: '{{field}} आवश्यक है', @@ -489,6 +490,28 @@ const translation = { toBeConfigured: 'कॉन्फ़िगर किया जाना है', emptyProviderTitle: 'मॉडल प्रदाता सेट नहीं किया गया', emptyProviderTip: 'कृपया पहले एक मॉडल प्रदाता स्थापित करें।', + auth: { + apiKeyModal: { + addModel: 'मॉडल जोड़ें', + title: 'एपीआई कुंजी प्राधिकरण कॉन्फ़िगरेशन', + desc: 'क्रेडेंशियल्स कॉन्फ़िगर करने के बाद, कार्यक्षेत्र के सभी सदस्यों को एप्लिकेशन को व्यवस्थित करते समय इस मॉडल का उपयोग करने की अनुमति होती है।', + }, + apiKeys: 'एपीआई कुंजी', + addNewModel: 'नया मॉडल जोड़ें', + authorizationError: 'अनु autorización त्रुटि', + unAuthorized: 'अअनधिकारित', + modelCredentials: 'मॉडल क्रेडेंशियल्स', + addCredential: 'क्रेडेंशियल जोड़ें', + addApiKey: 'एपीआई कुंजी जोड़ें', + authRemoved: 'प्राधिकरण हटाया गया', + providerManaged: 'प्रदाता द्वारा प्रबंधित', + configModel: 'कॉन्फ़िग मॉडल', + configLoadBalancing: 'कॉन्फ़िग लोड बैलेंसिंग', + addModelCredential: 'मॉडल क्रेडेंशियल जोड़ें', + specifyModelCredential: 'मॉडल की क्रेडेंशियल निर्दिष्ट करें', + specifyModelCredentialTip: 'कॉन्फ़िगर की गई मॉडल क्रेडेंशियल का उपयोग करें।', + providerManagedTip: 'वर्तमान कॉन्फ़िगरेशन प्रदाता द्वारा होस्ट किया गया है।', + }, }, dataSource: { add: 'डेटा स्रोत जोड़ें', diff --git a/web/i18n/hi-IN/plugin.ts b/web/i18n/hi-IN/plugin.ts index e15b6a85a7..b9ad0cea59 100644 --- a/web/i18n/hi-IN/plugin.ts +++ b/web/i18n/hi-IN/plugin.ts @@ -246,6 +246,9 @@ const translation = { authorization: 'अधिकार', useApiAuthDesc: 'क्रेडेंशियल्स कॉन्फ़िगर करने के बाद, कार्यक्षेत्र के सभी सदस्यों को एप्लिकेशन को व्यवस्थित करते समय इस उपकरण का उपयोग करने की अनुमति होती है।', clientInfo: 'चूंकि इस टूल प्रदाता के लिए कोई सिस्टम क्लाइंट रहस्य नहीं पाए गए हैं, इसलिए इसे मैन्युअल रूप से सेटअप करना आवश्यक है, कृपया redirect_uri का उपयोग करें', + unavailable: 'अप्राप्त', + customCredentialUnavailable: 'कस्टम क्रेडेंशियल वर्तमान में उपलब्ध नहीं हैं', + credentialUnavailable: 'वर्तमान में क्रेडेंशियल्स उपलब्ध नहीं हैं। कृपया प्रशासन से संपर्क करें।', }, deprecated: 'अनुशंसित नहीं', autoUpdate: { diff --git a/web/i18n/it-IT/common.ts b/web/i18n/it-IT/common.ts index 11120f14be..5b8ece7559 100644 --- a/web/i18n/it-IT/common.ts +++ b/web/i18n/it-IT/common.ts @@ -60,6 +60,7 @@ const translation = { format: 'Formato', selectAll: 'Seleziona tutto', deSelectAll: 'Deseleziona tutto', + config: 'Config', }, errorMsg: { fieldRequired: '{{field}} è obbligatorio', @@ -496,6 +497,28 @@ const translation = { emptyProviderTip: 'Si prega di installare prima un fornitore di modelli.', discoverMore: 'Scopri di più in', emptyProviderTitle: 'Provider di modelli non configurato', + auth: { + apiKeyModal: { + addModel: 'Aggiungi modello', + title: 'Configurazione dell\'autorizzazione della chiave API', + desc: 'Dopo aver configurato le credenziali, tutti i membri all\'interno dello spazio di lavoro possono utilizzare questo modello quando orchestrano applicazioni.', + }, + modelCredentials: 'Credenziali del modello', + providerManaged: 'Fornitore gestito', + apiKeys: 'Chiavi API', + authRemoved: 'Autore rimosso', + specifyModelCredential: 'Specifica le credenziali del modello', + addApiKey: 'Aggiungi la chiave API', + addModelCredential: 'Aggiungi le credenziali del modello', + addNewModel: 'Aggiungi un nuovo modello', + providerManagedTip: 'La configurazione attuale è ospitata dal fornitore.', + addCredential: 'Aggiungi credenziali', + authorizationError: 'Errore di autorizzazione', + configLoadBalancing: 'Configurazione del bilanciamento del carico', + unAuthorized: 'Non autorizzato', + specifyModelCredentialTip: 'Usa una credenziale di modello configurato.', + configModel: 'Configura modello', + }, }, dataSource: { add: 'Aggiungi una fonte di dati', diff --git a/web/i18n/it-IT/plugin.ts b/web/i18n/it-IT/plugin.ts index 616e199906..43d135bfe3 100644 --- a/web/i18n/it-IT/plugin.ts +++ b/web/i18n/it-IT/plugin.ts @@ -246,6 +246,9 @@ const translation = { oauthClientSettings: 'Impostazioni del client OAuth', useApiAuth: 'Configurazione dell\'autorizzazione della chiave API', clientInfo: 'Poiché non sono stati trovati segreti client di sistema per questo fornitore di strumenti, è necessario configurarlo manualmente. Per redirect_uri, si prega di utilizzare', + unavailable: 'Non disponibile', + customCredentialUnavailable: 'Le credenziali personalizzate attualmente non sono disponibili', + credentialUnavailable: 'Credenziali attualmente non disponibili. Si prega di contattare l\'amministratore.', }, deprecated: 'Deprecato', autoUpdate: { diff --git a/web/i18n/ja-JP/common.ts b/web/i18n/ja-JP/common.ts index 6159ffdaec..f8e5643b37 100644 --- a/web/i18n/ja-JP/common.ts +++ b/web/i18n/ja-JP/common.ts @@ -66,6 +66,7 @@ const translation = { more: 'もっと', selectAll: 'すべて選択', deSelectAll: 'すべて選択解除', + config: 'コンフィグ', }, errorMsg: { fieldRequired: '{{field}}は必要です', @@ -486,6 +487,28 @@ const translation = { configureTip: 'API キーを設定するか、使用するモデルを追加してください', toBeConfigured: '設定中', emptyProviderTip: '最初にモデルプロバイダーをインストールしてください。', + auth: { + apiKeyModal: { + title: 'APIキー認証設定', + addModel: 'モデルを追加する', + desc: '認証情報を設定した後、ワークスペース内のすべてのメンバーは、アプリケーションを調整する際にこのモデルを使用できます。', + }, + authorizationError: '認証エラー', + apiKeys: 'APIキー', + unAuthorized: '無許可', + configModel: 'モデルを構成する', + addApiKey: 'APIキーを追加してください', + addCredential: '認証情報を追加する', + authRemoved: '認証が削除されました', + modelCredentials: 'モデルの資格情報', + providerManaged: 'プロバイダーが管理しました', + addNewModel: '新しいモデルを追加する', + configLoadBalancing: '構成ロードバランシング', + addModelCredential: 'モデルの資格情報を追加', + providerManagedTip: '現在の設定はプロバイダーによってホストされています。', + specifyModelCredential: 'モデルの資格情報を指定してください', + specifyModelCredentialTip: '構成されたモデルの認証情報を使用してください。', + }, }, dataSource: { add: 'データソースの追加', diff --git a/web/i18n/ja-JP/plugin.ts b/web/i18n/ja-JP/plugin.ts index b202b404b3..d704a346b2 100644 --- a/web/i18n/ja-JP/plugin.ts +++ b/web/i18n/ja-JP/plugin.ts @@ -247,6 +247,9 @@ const translation = { addOAuth: 'OAuthを追加する', useApiAuthDesc: '認証情報を設定した後、ワークスペース内のすべてのメンバーは、アプリケーションをオーケストレーションする際にこのツールを使用できます。', clientInfo: 'このツールプロバイダーにシステムクライアントシークレットが見つからないため、手動で設定する必要があります。redirect_uriには、次を使用してください。', + unavailable: '利用できません', + customCredentialUnavailable: 'カスタム資格情報は現在利用できません', + credentialUnavailable: '現在、資格情報は利用できません。管理者にご連絡ください。', }, autoUpdate: { strategy: { diff --git a/web/i18n/ko-KR/common.ts b/web/i18n/ko-KR/common.ts index e9f44d384b..8b854fe050 100644 --- a/web/i18n/ko-KR/common.ts +++ b/web/i18n/ko-KR/common.ts @@ -60,6 +60,7 @@ const translation = { downloadSuccess: '다운로드 완료.', selectAll: '모두 선택', deSelectAll: '모두 선택 해제', + config: '구성', }, placeholder: { input: '입력해주세요', @@ -464,6 +465,28 @@ const translation = { configureTip: 'api-key 설정 또는 사용할 모델 추가', emptyProviderTip: '먼저 모델 공급자를 설치하십시오.', toBeConfigured: '구성 예정', + auth: { + apiKeyModal: { + addModel: '모델 추가', + title: 'API 키 인증 구성', + desc: '자격증명을 구성한 후에는 작업 공간 내의 모든 구성원이 애플리케이션을 조정할 때 이 모델을 사용할 수 있습니다.', + }, + addApiKey: 'API 키 추가', + apiKeys: 'API 키', + unAuthorized: '무단', + configModel: '구성 모델', + authorizationError: '권한 오류', + configLoadBalancing: '구성 로드 밸런싱', + addNewModel: '새 모델 추가하기', + specifyModelCredentialTip: '구성된 모델 자격 증명을 사용합니다.', + modelCredentials: '모델 자격 증명', + addCredential: '자격 증명을 추가하다', + authRemoved: '인증이 제거되었습니다.', + providerManaged: '제공자가 관리하는', + addModelCredential: '모델 자격 증명 추가', + specifyModelCredential: '모델 자격 증명을 명시하세요.', + providerManagedTip: '현재 구성은 제공업체에 의해 호스팅되고 있습니다.', + }, }, dataSource: { add: '데이터 소스 추가하기', diff --git a/web/i18n/ko-KR/plugin.ts b/web/i18n/ko-KR/plugin.ts index 815a30d3bb..04b6e54b49 100644 --- a/web/i18n/ko-KR/plugin.ts +++ b/web/i18n/ko-KR/plugin.ts @@ -246,6 +246,9 @@ const translation = { useOAuthAuth: 'OAuth 인증 사용하기', useApiAuthDesc: '자격증명을 구성한 후에는 작업 공간 내의 모든 구성원이 애플리케이션을 조정할 때 이 도구를 사용할 수 있습니다.', clientInfo: '이 도구 공급자에 대한 시스템 클라이언트 비밀이 발견되지 않았으므로 수동으로 설정해야 하며, redirect_uri는 다음을 사용하십시오.', + unavailable: '사용할 수 없음', + credentialUnavailable: '현재 자격 증명이 사용 불가능합니다. 관리자에게 문의하십시오.', + customCredentialUnavailable: '현재 사용자 정의 자격 증명이 사용 불가능합니다.', }, deprecated: '사용 중단됨', autoUpdate: { diff --git a/web/i18n/pl-PL/common.ts b/web/i18n/pl-PL/common.ts index 2830b8a4cb..fa98146903 100644 --- a/web/i18n/pl-PL/common.ts +++ b/web/i18n/pl-PL/common.ts @@ -60,6 +60,7 @@ const translation = { downloadSuccess: 'Pobieranie zakończone.', deSelectAll: 'Odznacz wszystkie', selectAll: 'Zaznacz wszystkie', + config: 'Konfiguracja', }, placeholder: { input: 'Proszę wprowadzić', @@ -482,6 +483,28 @@ const translation = { toBeConfigured: 'Do skonfigurowania', configureTip: 'Konfigurowanie klucza interfejsu API lub dodawanie modelu do użycia', emptyProviderTitle: 'Dostawca modelu nie jest skonfigurowany', + auth: { + apiKeyModal: { + addModel: 'Dodaj model', + title: 'Konfiguracja autoryzacji klucza API', + desc: 'Po skonfigurowaniu poświadczeń wszyscy członkowie w przestrzeni roboczej mogą korzystać z tego modelu podczas orkiestracji aplikacji.', + }, + addApiKey: 'Dodaj klucz API', + configModel: 'Skonfiguruj model', + modelCredentials: 'Uprawnienia modelu', + configLoadBalancing: 'Konfiguracja równoważenia obciążenia', + unAuthorized: 'Nieautoryzowany', + specifyModelCredentialTip: 'Użyj skonfigurowanych poświadczeń modelu.', + addCredential: 'Dodaj dane uwierzytelniające', + providerManagedTip: 'Bieżąca konfiguracja jest hostowana przez dostawcę.', + specifyModelCredential: 'Określ dane uwierzytelniające modelu', + authorizationError: 'Błąd autoryzacji', + apiKeys: 'Klucze API', + providerManaged: 'Zarządzane przez dostawcę', + addNewModel: 'Dodaj nowy model', + authRemoved: 'Autoryzacja usunięta', + addModelCredential: 'Dodaj dane uwierzytelniające modelu', + }, }, dataSource: { add: 'Dodaj źródło danych', diff --git a/web/i18n/pl-PL/plugin.ts b/web/i18n/pl-PL/plugin.ts index 5badeafe27..c957ca5641 100644 --- a/web/i18n/pl-PL/plugin.ts +++ b/web/i18n/pl-PL/plugin.ts @@ -246,6 +246,9 @@ const translation = { addOAuth: 'Dodaj OAuth', useApiAuthDesc: 'Po skonfigurowaniu poświadczeń wszyscy członkowie w przestrzeni roboczej mogą korzystać z tego narzędzia podczas orkiestracji aplikacji.', clientInfo: 'Ponieważ nie znaleziono tajemnic klientów systemu dla tego dostawcy narzędzi, wymagane jest ręczne skonfigurowanie, dla redirect_uri proszę użyć', + unavailable: 'Niedostępny', + customCredentialUnavailable: 'Niestandardowe dane logowania są obecnie niedostępne', + credentialUnavailable: 'Kredencje są obecnie niedostępne. Proszę skontaktować się z administratorem.', }, deprecated: 'Nieaktualny', autoUpdate: { diff --git a/web/i18n/pt-BR/common.ts b/web/i18n/pt-BR/common.ts index 3d1b4e002a..b555c2c2b0 100644 --- a/web/i18n/pt-BR/common.ts +++ b/web/i18n/pt-BR/common.ts @@ -60,6 +60,7 @@ const translation = { format: 'Formato', deSelectAll: 'Desmarcar tudo', selectAll: 'Selecionar tudo', + config: 'Configuração', }, placeholder: { input: 'Por favor, insira', @@ -469,6 +470,28 @@ const translation = { configureTip: 'Configure a chave de API ou adicione o modelo a ser usado', emptyProviderTitle: 'Provedor de modelo não configurado', toBeConfigured: 'A ser configurado', + auth: { + apiKeyModal: { + addModel: 'Adicionar modelo', + title: 'Configuração de Autorização de Chave da API', + desc: 'Após configurar as credenciais, todos os membros dentro do espaço de trabalho podem usar este modelo ao orquestrar aplicações.', + }, + addCredential: 'Adicionar credencial', + configModel: 'Configurar modelo', + apiKeys: 'Chaves de API', + unAuthorized: 'Não autorizado', + modelCredentials: 'Credenciais do modelo', + providerManaged: 'Provedor gerenciado', + addApiKey: 'Adicionar chave da API', + authorizationError: 'Erro de autorização', + addNewModel: 'Adicionar novo modelo', + specifyModelCredential: 'Especifique as credenciais do modelo', + providerManagedTip: 'A configuração atual é hospedada pelo provedor.', + authRemoved: 'Autorização removida', + addModelCredential: 'Adicionar credenciais do modelo', + configLoadBalancing: 'Configuração de Balanceamento de Carga', + specifyModelCredentialTip: 'Use uma credencial de modelo configurada.', + }, }, dataSource: { add: 'Adicionar uma fonte de dados', diff --git a/web/i18n/pt-BR/plugin.ts b/web/i18n/pt-BR/plugin.ts index 9b31f5e190..3300ddde56 100644 --- a/web/i18n/pt-BR/plugin.ts +++ b/web/i18n/pt-BR/plugin.ts @@ -246,6 +246,9 @@ const translation = { addOAuth: 'Adicionar OAuth', useApiAuthDesc: 'Após configurar as credenciais, todos os membros dentro do espaço de trabalho podem usar esta ferramenta ao orquestrar aplicações.', clientInfo: 'Como não foram encontrados segredos de cliente do sistema para este provedor de ferramentas, é necessário configurá-lo manualmente. Para redirect_uri, use', + customCredentialUnavailable: 'Credenciais personalizadas atualmente indisponíveis', + unavailable: 'Indisponível', + credentialUnavailable: 'Credenciais atualmente indisponíveis. Por favor, contate o administrador.', }, deprecated: 'Obsoleto', autoUpdate: { diff --git a/web/i18n/ro-RO/common.ts b/web/i18n/ro-RO/common.ts index 62469d9bd1..473a349784 100644 --- a/web/i18n/ro-RO/common.ts +++ b/web/i18n/ro-RO/common.ts @@ -60,6 +60,7 @@ const translation = { more: 'Mai mult', deSelectAll: 'Deselectați tot', selectAll: 'Selectați tot', + config: 'Configurație', }, placeholder: { input: 'Vă rugăm să introduceți', @@ -469,6 +470,28 @@ const translation = { discoverMore: 'Descoperă mai multe în', emptyProviderTip: 'Vă rugăm să instalați mai întâi un furnizor de modele.', toBeConfigured: 'De configurat', + auth: { + apiKeyModal: { + addModel: 'Adăugați model', + title: 'Configurarea autorizării cheii API', + desc: 'După configurarea acreditivelor, toți membrii din spațiul de lucru pot folosi acest model atunci când orchestran aplicații.', + }, + unAuthorized: 'Neautorizat', + addApiKey: 'Adăugați cheia API', + apiKeys: 'Chei API', + addCredential: 'Adăugați acreditive', + configModel: 'Configurați modelul', + addNewModel: 'Adăugați un nou model', + authRemoved: 'Autentificare eliminată', + specifyModelCredential: 'Specificați acreditivele modelului', + providerManaged: 'Gestionat de furnizor', + authorizationError: 'Eroare de autorizare', + configLoadBalancing: 'Configurare echilibrare a încărcării', + addModelCredential: 'Adăugați acreditivele modelului', + providerManagedTip: 'Configurarea curentă este găzduită de furnizor.', + modelCredentials: 'Credențiale model', + specifyModelCredentialTip: 'Utilizați un acreditiv de model configurat.', + }, }, dataSource: { add: 'Adăugați o sursă de date', diff --git a/web/i18n/ro-RO/plugin.ts b/web/i18n/ro-RO/plugin.ts index d65dc829f8..00d4d88eac 100644 --- a/web/i18n/ro-RO/plugin.ts +++ b/web/i18n/ro-RO/plugin.ts @@ -246,6 +246,9 @@ const translation = { setupOAuth: 'Configurați clientul OAuth', useApiAuthDesc: 'După configurarea acreditivelor, toți membrii din spațiul de lucru pot folosi acest instrument atunci când orchestran aplicații.', clientInfo: 'Deoarece nu s-au găsit secretele clientului sistemului pentru acest furnizor de instrumente, este necesară configurarea manuală; pentru redirect_uri, vă rugăm să folosiți', + unavailable: 'Necesar', + customCredentialUnavailable: 'Credentialele personalizate sunt în prezent indisponibile', + credentialUnavailable: 'Credențialele nu sunt disponibile în acest moment. Vă rugăm să contactați administratorul.', }, deprecated: 'Încetat de a mai fi utilizat', autoUpdate: { diff --git a/web/i18n/ru-RU/common.ts b/web/i18n/ru-RU/common.ts index e5b912857f..02bd415dc5 100644 --- a/web/i18n/ru-RU/common.ts +++ b/web/i18n/ru-RU/common.ts @@ -60,6 +60,7 @@ const translation = { downloadSuccess: 'Загрузка завершена.', selectAll: 'Выбрать все', deSelectAll: 'Снять выделение со всех', + config: 'Конфигурация', }, errorMsg: { fieldRequired: '{{field}} обязательно', @@ -473,6 +474,28 @@ const translation = { emptyProviderTip: 'Сначала установите поставщик модели.', discoverMore: 'Узнайте больше в', installProvider: 'Установка поставщиков моделей', + auth: { + apiKeyModal: { + addModel: 'Добавить модель', + title: 'Конфигурация авторизации ключа API', + desc: 'После настройки учетных данных все члены рабочей области могут использовать эту модель при оркестрации приложений.', + }, + authRemoved: 'Удалена аутентификация', + addApiKey: 'Добавьте API-ключ', + addCredential: 'Добавить учетные данные', + apiKeys: 'API ключи', + authorizationError: 'Ошибка авторизации', + modelCredentials: 'Учетные данные модели', + configModel: 'Настройка модели', + providerManaged: 'Управляемый провайдером', + unAuthorized: 'Неавторизованный', + specifyModelCredential: 'Укажите учетные данные модели', + addNewModel: 'Добавить новую модель', + addModelCredential: 'Добавить учетные данные модели', + configLoadBalancing: 'Конфигурация балансировки нагрузки', + providerManagedTip: 'Текущая конфигурация размещена у провайдера.', + specifyModelCredentialTip: 'Используйте конфигурированные учетные данные модели.', + }, }, dataSource: { add: 'Добавить источник данных', diff --git a/web/i18n/ru-RU/plugin.ts b/web/i18n/ru-RU/plugin.ts index 9bbb3c4852..7a6870a236 100644 --- a/web/i18n/ru-RU/plugin.ts +++ b/web/i18n/ru-RU/plugin.ts @@ -246,6 +246,9 @@ const translation = { saveAndAuth: 'Сохранить и авторизовать', useApiAuthDesc: 'После настройки учетных данных все члены рабочей области могут использовать этот инструмент при оркестрации приложений.', clientInfo: 'Поскольку не найдены секреты клиентской системы для этого поставщика инструментов, необходимо настроить его вручную, для redirect_uri, пожалуйста, используйте', + unavailable: 'Недоступно', + customCredentialUnavailable: 'Кастомные учетные данные в настоящее время недоступны', + credentialUnavailable: 'Учетные данные в настоящее время недоступны. Пожалуйста, свяжитесь с администратором.', }, deprecated: 'Устаревший', autoUpdate: { diff --git a/web/i18n/sl-SI/common.ts b/web/i18n/sl-SI/common.ts index ed092c903a..d3acc5f47f 100644 --- a/web/i18n/sl-SI/common.ts +++ b/web/i18n/sl-SI/common.ts @@ -60,6 +60,7 @@ const translation = { format: 'Format', selectAll: 'Izberi vse', deSelectAll: 'Odberi vse', + config: 'Konfiguracija', }, errorMsg: { fieldRequired: '{{field}} je obvezno', @@ -671,6 +672,28 @@ const translation = { emptyProviderTip: 'Najprej namestite ponudnika modelov.', toBeConfigured: 'Za konfiguracijo', configureTip: 'Nastavitev tipke API ali dodajanje modela za uporabo', + auth: { + apiKeyModal: { + addModel: 'Dodaj model', + title: 'Konfiguracija avtorizacije ključev API', + desc: 'Po konfiguraciji poverilnic lahko vsi člani v delovnem prostoru uporabljajo ta model pri usklajevanju aplikacij.', + }, + apiKeys: 'API ključi', + authRemoved: 'Avtor odstranjen', + unAuthorized: 'Neavtorizirano', + addNewModel: 'Dodaj nov model', + addModelCredential: 'Dodajte poverilnice modela', + addCredential: 'Dodaj akreditiv', + modelCredentials: 'Model akreditivi', + configLoadBalancing: 'Nastavitve uravnoteženja obremenitve', + providerManagedTip: 'Trenutna konfiguracija je gostovana pri ponudniku.', + providerManaged: 'Zagotavlja upravljano', + specifyModelCredentialTip: 'Uporabite konfigurirane poverilnice modela.', + specifyModelCredential: 'Določite poverilnice modela', + addApiKey: 'Dodajte API ključ', + configModel: 'Konfiguriraj model', + authorizationError: 'Napaka pri avtorizaciji', + }, }, dataSource: { notion: { diff --git a/web/i18n/sl-SI/plugin.ts b/web/i18n/sl-SI/plugin.ts index dc435f2302..db5c8f1572 100644 --- a/web/i18n/sl-SI/plugin.ts +++ b/web/i18n/sl-SI/plugin.ts @@ -246,6 +246,9 @@ const translation = { oauthClientSettings: 'Nastavitve odjemalca OAuth', clientInfo: 'Ker za tega ponudnika orodij niso bili najdeni klientski skrivnosti sistema, je potrebna ročna nastavitev, za redirect_uri prosimo uporabite', useApiAuthDesc: 'Po konfiguraciji poverilnic lahko vsi člani v delovnem prostoru uporabljajo to orodje pri orkestraciji aplikacij.', + unavailable: 'Nedostopno', + customCredentialUnavailable: 'Trenutno niso na voljo prilagojene prijave.', + credentialUnavailable: 'Trenutno niso na voljo poverilnice. Prosimo, kontaktirajte administratorja.', }, deprecated: 'Zastaran', autoUpdate: { diff --git a/web/i18n/th-TH/common.ts b/web/i18n/th-TH/common.ts index ca26eada6f..b8d01880ff 100644 --- a/web/i18n/th-TH/common.ts +++ b/web/i18n/th-TH/common.ts @@ -60,6 +60,7 @@ const translation = { downloadSuccess: 'ดาวน์โหลดเสร็จสิ้นแล้ว.', selectAll: 'เลือกทั้งหมด', deSelectAll: 'ยกเลิกการเลือกทั้งหมด', + config: 'การตั้งค่า', }, errorMsg: { fieldRequired: '{{field}} เป็นสิ่งจําเป็น', @@ -468,6 +469,28 @@ const translation = { toBeConfigured: 'ต้องกําหนดค่า', installProvider: 'ติดตั้งผู้ให้บริการโมเดล', configureTip: 'ตั้งค่า api-key หรือเพิ่มโมเดลเพื่อใช้', + auth: { + apiKeyModal: { + addModel: 'เพิ่มโมเดล', + title: 'การกำหนดค่าการอนุญาตคีย์ API', + desc: 'หลังจากตั้งค่าข้อมูลประจำตัวแล้ว สมาชิกทุกคนภายในพื้นที่ทำงานสามารถใช้โมเดลนี้เมื่อจัดการแอปพลิเคชันได้', + }, + configModel: 'กำหนดโมเดล', + unAuthorized: 'ไม่ได้รับอนุญาต', + addCredential: 'เพิ่มข้อมูลรับรอง', + addNewModel: 'เพิ่มโมเดลใหม่', + authRemoved: 'ผู้แต่งถูกลบออก', + providerManaged: 'ผู้ให้บริการจัดการ', + addApiKey: 'เพิ่มคีย์ API', + apiKeys: 'คีย์ API', + modelCredentials: 'ข้อมูลรับรองโมเดล', + specifyModelCredential: 'ระบุข้อมูลประจำตัวของโมเดล', + configLoadBalancing: 'การตั้งค่าการโหลดสมดุล', + addModelCredential: 'เพิ่มข้อมูลรับรองโมเดล', + authorizationError: 'ข้อผิดพลาดในการอนุญาต', + specifyModelCredentialTip: 'ใช้ข้อมูลรับรองโมเดลที่กำหนดไว้', + providerManagedTip: 'การกำหนดค่าปัจจุบันถูกโฮสต์โดยผู้ให้บริการ.', + }, }, dataSource: { add: 'เพิ่มแหล่งข้อมูล', diff --git a/web/i18n/th-TH/plugin.ts b/web/i18n/th-TH/plugin.ts index a967280dbd..caf1ccb5e7 100644 --- a/web/i18n/th-TH/plugin.ts +++ b/web/i18n/th-TH/plugin.ts @@ -246,6 +246,9 @@ const translation = { custom: 'ที่กำหนดเอง', useApiAuthDesc: 'หลังจากตั้งค่าข้อมูลประจำตัวแล้ว สมาชิกทุกคนภายในพื้นที่ทำงานสามารถใช้เครื่องมือนี้เมื่อจัดการแอปพลิเคชันได้', clientInfo: 'เนื่องจากไม่พบความลับของลูกค้าสำหรับผู้ให้บริการเครื่องมือนี้ จำเป็นต้องตั้งค่าแบบแมนนวล สำหรับ redirect_uri กรุณาใช้', + unavailable: 'ไม่มีให้บริการ', + customCredentialUnavailable: 'ข้อมูลรับรองที่กำหนดเองขณะนี้ไม่สามารถใช้ได้', + credentialUnavailable: 'ข้อมูลรับรองไม่สามารถใช้งานได้ในขณะนี้ กรุณาติดต่อผู้ดูแลระบบ.', }, deprecated: 'เลิกใช้', autoUpdate: { diff --git a/web/i18n/tr-TR/common.ts b/web/i18n/tr-TR/common.ts index a7b0734799..7dcebecff2 100644 --- a/web/i18n/tr-TR/common.ts +++ b/web/i18n/tr-TR/common.ts @@ -60,6 +60,7 @@ const translation = { downloadFailed: 'İndirme başarısız oldu. Lütfen daha sonra tekrar deneyin.', selectAll: 'Hepsini Seç', deSelectAll: 'Hepsini Seçme', + config: 'Konfigürasyon', }, errorMsg: { fieldRequired: '{{field}} gereklidir', @@ -473,6 +474,28 @@ const translation = { emptyProviderTitle: 'Model sağlayıcı ayarlanmadı', discoverMore: 'Daha fazlasını keşfedin', configureTip: 'Api-key\'i ayarlayın veya kullanmak için model ekleyin', + auth: { + apiKeyModal: { + addModel: 'Model ekle', + title: 'API Anahtar Yetkilendirme Yapılandırması', + desc: 'Kimlik bilgileri yapılandırıldıktan sonra, çalışma alanındaki tüm üyeler bu modeli uygulamaları düzenlerken kullanabilir.', + }, + unAuthorized: 'Yetkisiz', + authRemoved: 'Yazar kaldırıldı', + providerManaged: 'Sağlayıcı yönetimi', + configModel: 'Modeli yapılandır', + apiKeys: 'API Anahtarları', + addApiKey: 'API Anahtarını Ekle', + addCredential: 'Kimlik bilgisi ekle', + addNewModel: 'Yeni model ekle', + providerManagedTip: 'Mevcut yapılandırma sağlayıcı tarafından barındırılmaktadır.', + modelCredentials: 'Model kimlik bilgileri', + specifyModelCredentialTip: 'Yapılandırılmış bir model kimliği kullanın.', + configLoadBalancing: 'Yük Dengeleme Yapılandırması', + addModelCredential: 'Model kimlik bilgisi ekle', + specifyModelCredential: 'Model kimlik bilgilerini belirtin', + authorizationError: 'Yetkilendirme hatası', + }, }, dataSource: { add: 'Bir veri kaynağı ekle', diff --git a/web/i18n/tr-TR/plugin.ts b/web/i18n/tr-TR/plugin.ts index 1856a34c7e..82ddf4bbc4 100644 --- a/web/i18n/tr-TR/plugin.ts +++ b/web/i18n/tr-TR/plugin.ts @@ -246,6 +246,9 @@ const translation = { addApi: 'API Anahtarını Ekle', saveAndAuth: 'Kaydet ve Yetkilendir', clientInfo: 'Bu araç sağlayıcı için sistem istemci gizlilikleri bulunmadığından, manuel olarak ayar yapılması gerekmektedir. redirect_uri için lütfen şu adresi kullanın', + unavailable: 'Kullanılamıyor', + customCredentialUnavailable: 'Özel kimlik bilgileri şu anda mevcut değil.', + credentialUnavailable: 'Kimlik bilgileri şu anda mevcut değil. Lütfen yönetici ile iletişime geçin.', }, deprecated: 'Kaldırılmış', autoUpdate: { diff --git a/web/i18n/uk-UA/common.ts b/web/i18n/uk-UA/common.ts index f8b6e7ba41..550148ad32 100644 --- a/web/i18n/uk-UA/common.ts +++ b/web/i18n/uk-UA/common.ts @@ -60,6 +60,7 @@ const translation = { downloadSuccess: 'Завантаження завершено.', deSelectAll: 'Вимкнути все', selectAll: 'Вибрати все', + config: 'Конфігурація', }, placeholder: { input: 'Будь ласка, введіть текст', @@ -470,6 +471,28 @@ const translation = { emptyProviderTitle: 'Постачальника моделі не налаштовано', configureTip: 'Налаштуйте api-ключ або додайте модель для використання', discoverMore: 'Відкрийте для себе більше в', + auth: { + apiKeyModal: { + addModel: 'Додати модель', + title: 'Конфігурація авторизації API-ключа', + desc: 'Після налаштування облікових даних усі учасники в робочій області можуть використовувати цю модель під час оркестрування програм.', + }, + addApiKey: 'Додайте ключ API', + apiKeys: 'API ключі', + authRemoved: 'Автор видалено', + configModel: 'Конфігураційна модель', + unAuthorized: 'Несанкціоновано', + authorizationError: 'Помилка авторизації', + modelCredentials: 'Модельні облікові дані', + providerManaged: 'Постачальник управляє', + addCredential: 'Додати облікові дані', + specifyModelCredentialTip: 'Використовуйте налаштовані облікові дані моделі.', + specifyModelCredential: 'Вкажіть облікові дані моделі', + addNewModel: 'Додати нову модель', + configLoadBalancing: 'Конфігурація балансування навантаження', + addModelCredential: 'Додати облікові дані моделі', + providerManagedTip: 'Поточна конфігурація розміщується провайдером.', + }, }, dataSource: { add: 'Додати джерело даних', diff --git a/web/i18n/uk-UA/plugin.ts b/web/i18n/uk-UA/plugin.ts index 22b98fbd41..30a0a0df36 100644 --- a/web/i18n/uk-UA/plugin.ts +++ b/web/i18n/uk-UA/plugin.ts @@ -246,6 +246,9 @@ const translation = { oauthClient: 'Клієнт OAuth', clientInfo: 'Оскільки не знайдено жодних секретів клієнта системи для цього постачальника інструментів, потрібно налаштувати його вручну; для redirect_uri, будь ласка, використовуйте', useApiAuthDesc: 'Після налаштування облікових даних усі учасники робочого простору можуть використовувати цей інструмент під час оркестрації додатків.', + unavailable: 'Недоступний', + customCredentialUnavailable: 'Індивідуальні облікові дані наразі недоступні', + credentialUnavailable: 'Облікові дані наразі недоступні. Будь ласка, зверніться до адміністратора.', }, deprecated: 'Застарілий', autoUpdate: { diff --git a/web/i18n/vi-VN/common.ts b/web/i18n/vi-VN/common.ts index 94ed4e9e78..384c4dbf61 100644 --- a/web/i18n/vi-VN/common.ts +++ b/web/i18n/vi-VN/common.ts @@ -60,6 +60,7 @@ const translation = { downloadSuccess: 'Tải xuống đã hoàn thành.', deSelectAll: 'Bỏ chọn tất cả', selectAll: 'Chọn Tất Cả', + config: 'Cấu hình', }, placeholder: { input: 'Vui lòng nhập', @@ -469,6 +470,28 @@ const translation = { emptyProviderTip: 'Vui lòng cài đặt nhà cung cấp mô hình trước.', installProvider: 'Cài đặt nhà cung cấp mô hình', configureTip: 'Thiết lập api-key hoặc thêm mô hình để sử dụng', + auth: { + apiKeyModal: { + addModel: 'Thêm mô hình', + title: 'Cấu hình ủy quyền khóa API', + desc: 'Sau khi cấu hình thông tin xác thực, tất cả các thành viên trong không gian làm việc có thể sử dụng mô hình này khi điều phối các ứng dụng.', + }, + addNewModel: 'Thêm mô hình mới', + addCredential: 'Thêm thông tin đăng nhập', + configLoadBalancing: 'Cấu hình cân bằng tải', + apiKeys: 'Chìa khóa API', + authorizationError: 'Lỗi xác thực', + configModel: 'Cấu hình mô hình', + modelCredentials: 'Chứng chỉ của mô hình', + unAuthorized: 'Không có quyền truy cập', + addApiKey: 'Thêm khóa API', + providerManagedTip: 'Cấu hình hiện tại được lưu trữ bởi nhà cung cấp.', + specifyModelCredential: 'Xác định thông tin xác thực của mô hình', + specifyModelCredentialTip: 'Sử dụng thông tin xác thực của mô hình đã cấu hình.', + addModelCredential: 'Thêm thông tin đăng nhập mô hình', + authRemoved: 'Chính quyền đã loại bỏ', + providerManaged: 'Nhà cung cấp đã được quản lý', + }, }, dataSource: { add: 'Thêm nguồn dữ liệu', diff --git a/web/i18n/vi-VN/plugin.ts b/web/i18n/vi-VN/plugin.ts index c0f3dfac5f..44989cd6aa 100644 --- a/web/i18n/vi-VN/plugin.ts +++ b/web/i18n/vi-VN/plugin.ts @@ -246,6 +246,9 @@ const translation = { setDefault: 'Đặt làm mặc định', useApiAuthDesc: 'Sau khi cấu hình thông tin xác thực, tất cả các thành viên trong không gian làm việc có thể sử dụng công cụ này khi điều phối các ứng dụng.', clientInfo: 'Vì không tìm thấy bí mật khách hàng hệ thống cho nhà cung cấp công cụ này, cần thiết lập thủ công, đối với redirect_uri, vui lòng sử dụng', + unavailable: 'Không có sẵn', + customCredentialUnavailable: 'Thông tin đăng nhập tùy chỉnh hiện không khả dụng', + credentialUnavailable: 'Thông tin đăng nhập hiện không khả dụng. Vui lòng liên hệ với quản trị viên.', }, deprecated: 'Đã bị ngưng sử dụng', autoUpdate: { diff --git a/web/i18n/zh-Hans/common.ts b/web/i18n/zh-Hans/common.ts index 6198363f2f..8d24aa264e 100644 --- a/web/i18n/zh-Hans/common.ts +++ b/web/i18n/zh-Hans/common.ts @@ -40,6 +40,7 @@ const translation = { deleteApp: '删除应用', settings: '设置', setup: '设置', + config: '配置', getForFree: '免费获取', reload: '刷新', ok: '好的', @@ -465,7 +466,7 @@ const translation = { loadPresets: '加载预设', parameters: '参数', loadBalancing: '负载均衡', - loadBalancingDescription: '为了减轻单组凭据的压力,您可以为模型调用配置多组凭据。', + loadBalancingDescription: '为模型配置多组凭据,并自动调用。', loadBalancingHeadline: '负载均衡', configLoadBalancing: '设置负载均衡', modelHasBeenDeprecated: '该模型已废弃', @@ -486,6 +487,28 @@ const translation = { discoverMore: '发现更多就在', emptyProviderTitle: '尚未安装模型供应商', emptyProviderTip: '请安装模型供应商。', + auth: { + unAuthorized: '未授权', + authRemoved: '授权已移除', + apiKeys: 'API 密钥', + addApiKey: '添加 API 密钥', + addNewModel: '添加新模型', + addCredential: '添加凭据', + addModelCredential: '添加模型凭据', + modelCredentials: '模型凭据', + configModel: '配置模型', + configLoadBalancing: '配置负载均衡', + authorizationError: '授权错误', + specifyModelCredential: '指定模型凭据', + specifyModelCredentialTip: '使用已配置的模型凭据。', + providerManaged: '由模型供应商管理', + providerManagedTip: '使用模型供应商提供的单组凭据。', + apiKeyModal: { + title: 'API 密钥授权配置', + desc: '配置凭据后,工作空间中的所有成员都可以在编排应用时使用此模型。', + addModel: '添加模型', + }, + }, }, dataSource: { add: '添加数据源', diff --git a/web/i18n/zh-Hans/plugin.ts b/web/i18n/zh-Hans/plugin.ts index a080a26a8c..e37de6d69f 100644 --- a/web/i18n/zh-Hans/plugin.ts +++ b/web/i18n/zh-Hans/plugin.ts @@ -297,6 +297,9 @@ const translation = { authRemoved: '凭据已移除', clientInfo: '由于未找到此工具提供者的系统客户端密钥,因此需要手动设置,对于 redirect_uri,请使用', oauthClient: 'OAuth 客户端', + credentialUnavailable: '自定义凭据当前不可用,请联系管理员。', + customCredentialUnavailable: '自定义凭据当前不可用', + unavailable: '不可用', }, } diff --git a/web/i18n/zh-Hant/common.ts b/web/i18n/zh-Hant/common.ts index 288cda2316..009bd5ad30 100644 --- a/web/i18n/zh-Hant/common.ts +++ b/web/i18n/zh-Hant/common.ts @@ -60,6 +60,7 @@ const translation = { format: '格式', deSelectAll: '全不選', selectAll: '全選', + config: '配置', }, placeholder: { input: '請輸入', @@ -468,6 +469,27 @@ const translation = { emptyProviderTitle: '未設置模型提供者', configureTip: '設置 api-key 或添加要使用的模型', emptyProviderTip: '請先安裝模型提供者。', + auth: { + apiKeyModal: { + addModel: '添加模型', + title: 'API 金鑰授權配置', + desc: '配置完憑證後,工作區內的所有成員在協調應用程式時都可以使用此模型。', + }, + authRemoved: '授權已被移除', + configModel: '配置模型', + addApiKey: '添加 API 金鑰', + addCredential: '添加憑證', + addModelCredential: '添加模型憑證', + modelCredentials: '模型憑證', + providerManaged: '供應商管理', + addNewModel: '新增模型', + specifyModelCredential: '指定模型憑證', + specifyModelCredentialTip: '使用配置的模型憑證。', + apiKeys: 'API 金鑰', + configLoadBalancing: '配置負載均衡', + unAuthorized: '未經授權', + authorizationError: '授權錯誤', + }, }, dataSource: { add: '新增資料來源', diff --git a/web/i18n/zh-Hant/plugin.ts b/web/i18n/zh-Hant/plugin.ts index 117491fe05..514d7fb4b4 100644 --- a/web/i18n/zh-Hant/plugin.ts +++ b/web/i18n/zh-Hant/plugin.ts @@ -246,6 +246,9 @@ const translation = { useApi: '使用 API 金鑰', clientInfo: '由於未找到此工具提供者的系統客戶端秘密,因此需要手動設置,對於 redirect_uri,請使用', useApiAuthDesc: '配置完憑證後,工作區內的所有成員在協調應用程式時都可以使用此工具。', + unavailable: '無法使用', + customCredentialUnavailable: '自訂憑證目前無法使用', + credentialUnavailable: '凭證目前無法使用。請聯繫管理員。', }, deprecated: '不推薦使用的', autoUpdate: { diff --git a/web/next.config.js b/web/next.config.js index 00793bf26a..6920a47fbf 100644 --- a/web/next.config.js +++ b/web/next.config.js @@ -27,7 +27,10 @@ const nextConfig = { basePath, assetPrefix, webpack: (config, { dev, isServer }) => { - config.plugins.push(codeInspectorPlugin({ bundler: 'webpack' })) + if (dev) { + config.plugins.push(codeInspectorPlugin({ bundler: 'webpack' })) + } + return config }, productionBrowserSourceMaps: false, // enable browser source map generation during the production build diff --git a/web/package.json b/web/package.json index 6623e31971..4d978c107d 100644 --- a/web/package.json +++ b/web/package.json @@ -21,6 +21,7 @@ "scripts": { "dev": "cross-env NODE_OPTIONS='--inspect' next dev", "build": "next build", + "build:docker": "next build && node scripts/optimize-standalone.js", "start": "cp -r .next/static .next/standalone/.next/static && cp -r public .next/standalone/public && cross-env PORT=$npm_config_port HOSTNAME=$npm_config_host node .next/standalone/server.js", "lint": "pnpx oxlint && pnpm eslint --cache --cache-location node_modules/.cache/eslint/.eslint-cache", "lint-only-show-error": "pnpx oxlint && pnpm eslint --cache --cache-location node_modules/.cache/eslint/.eslint-cache --quiet", diff --git a/web/scripts/README.md b/web/scripts/README.md new file mode 100644 index 0000000000..2c575a244c --- /dev/null +++ b/web/scripts/README.md @@ -0,0 +1,38 @@ +# Production Build Optimization Scripts + +## optimize-standalone.js + +This script removes unnecessary development dependencies from the Next.js standalone build output to reduce the production Docker image size. + +### What it does + +The script specifically targets and removes `jest-worker` packages that are bundled with Next.js but not needed in production. These packages are included because: + +1. Next.js includes jest-worker in its compiled dependencies +1. terser-webpack-plugin (used by Next.js for minification) depends on jest-worker +1. pnpm's dependency resolution creates symlinks to jest-worker in various locations + +### Usage + +The script is automatically run during Docker builds via the `build:docker` npm script: + +```bash +# Docker build (removes jest-worker after build) +pnpm build:docker +``` + +To run the optimization manually: + +```bash +node scripts/optimize-standalone.js +``` + +### What gets removed + +- `node_modules/.pnpm/next@*/node_modules/next/dist/compiled/jest-worker` +- `node_modules/.pnpm/terser-webpack-plugin@*/node_modules/jest-worker` (symlinks) +- `node_modules/.pnpm/jest-worker@*` (actual packages) + +### Impact + +Removing jest-worker saves approximately 36KB per instance from the production image. While this may seem small, it helps ensure production images only contain necessary runtime dependencies. diff --git a/web/scripts/optimize-standalone.js b/web/scripts/optimize-standalone.js new file mode 100644 index 0000000000..f434a5daea --- /dev/null +++ b/web/scripts/optimize-standalone.js @@ -0,0 +1,149 @@ +/** + * Script to optimize Next.js standalone output for production + * Removes unnecessary files like jest-worker that are bundled with Next.js + */ + +const fs = require('fs'); +const path = require('path'); + +console.log('🔧 Optimizing standalone output...'); + +const standaloneDir = path.join(__dirname, '..', '.next', 'standalone'); + +// Check if standalone directory exists +if (!fs.existsSync(standaloneDir)) { + console.error('❌ Standalone directory not found. Please run "next build" first.'); + process.exit(1); +} + +// List of paths to remove (relative to standalone directory) +const pathsToRemove = [ + // Remove jest-worker from Next.js compiled dependencies + 'node_modules/.pnpm/next@*/node_modules/next/dist/compiled/jest-worker', + // Remove jest-worker symlinks from terser-webpack-plugin + 'node_modules/.pnpm/terser-webpack-plugin@*/node_modules/jest-worker', + // Remove actual jest-worker packages (directories only, not symlinks) + 'node_modules/.pnpm/jest-worker@*', +]; + +// Function to safely remove a path +function removePath(basePath, relativePath) { + const fullPath = path.join(basePath, relativePath); + + // Handle wildcard patterns + if (relativePath.includes('*')) { + const parts = relativePath.split('/'); + let currentPath = basePath; + + for (let i = 0; i < parts.length; i++) { + const part = parts[i]; + if (part.includes('*')) { + // Find matching directories + if (fs.existsSync(currentPath)) { + const entries = fs.readdirSync(currentPath); + + // replace '*' with '.*' + const regexPattern = part.replace(/\*/g, '.*'); + + const regex = new RegExp(`^${regexPattern}$`); + + for (const entry of entries) { + if (regex.test(entry)) { + const remainingPath = parts.slice(i + 1).join('/'); + const matchedPath = path.join(currentPath, entry, remainingPath); + + try { + // Use lstatSync to check if path exists (works for both files and symlinks) + const stats = fs.lstatSync(matchedPath); + + if (stats.isSymbolicLink()) { + // Remove symlink + fs.unlinkSync(matchedPath); + console.log(`✅ Removed symlink: ${path.relative(basePath, matchedPath)}`); + } else { + // Remove directory/file + fs.rmSync(matchedPath, { recursive: true, force: true }); + console.log(`✅ Removed: ${path.relative(basePath, matchedPath)}`); + } + } catch (error) { + // Silently ignore ENOENT (path not found) errors + if (error.code !== 'ENOENT') { + console.error(`❌ Failed to remove ${matchedPath}: ${error.message}`); + } + } + } + } + } + return; + } else { + currentPath = path.join(currentPath, part); + } + } + } else { + // Direct path removal + if (fs.existsSync(fullPath)) { + try { + fs.rmSync(fullPath, { recursive: true, force: true }); + console.log(`✅ Removed: ${relativePath}`); + } catch (error) { + console.error(`❌ Failed to remove ${fullPath}: ${error.message}`); + } + } + } +} + +// Remove unnecessary paths +console.log('🗑️ Removing unnecessary files...'); +for (const pathToRemove of pathsToRemove) { + removePath(standaloneDir, pathToRemove); +} + +// Calculate size reduction +console.log('\n📊 Optimization complete!'); + +// Optional: Display the size of remaining jest-related files (if any) +const checkForJest = (dir) => { + const jestFiles = []; + + function walk(currentPath) { + if (!fs.existsSync(currentPath)) return; + + try { + const entries = fs.readdirSync(currentPath); + for (const entry of entries) { + const fullPath = path.join(currentPath, entry); + + try { + const stat = fs.lstatSync(fullPath); // Use lstatSync to handle symlinks + + if (stat.isDirectory() && !stat.isSymbolicLink()) { + // Skip node_modules subdirectories to avoid deep traversal + if (entry === 'node_modules' && currentPath !== standaloneDir) { + continue; + } + walk(fullPath); + } else if (stat.isFile() && entry.includes('jest')) { + jestFiles.push(path.relative(standaloneDir, fullPath)); + } + } catch (err) { + // Skip files that can't be accessed + continue; + } + } + } catch (err) { + // Skip directories that can't be read + return; + } + } + + walk(dir); + return jestFiles; +}; + +const remainingJestFiles = checkForJest(standaloneDir); +if (remainingJestFiles.length > 0) { + console.log('\n⚠️ Warning: Some jest-related files still remain:'); + remainingJestFiles.forEach(file => console.log(` - ${file}`)); +} else { + console.log('\n✨ No jest-related files found in standalone output!'); +} diff --git a/web/service/knowledge/use-document.ts b/web/service/knowledge/use-document.ts index 3d6e322552..e53a5ebde7 100644 --- a/web/service/knowledge/use-document.ts +++ b/web/service/knowledge/use-document.ts @@ -8,9 +8,7 @@ import type { MetadataType, SortType } from '../datasets' import { pauseDocIndexing, resumeDocIndexing } from '../datasets' import type { DocumentDetailResponse, DocumentListResponse, UpdateDocumentBatchParams } from '@/models/datasets' import { DocumentActionType } from '@/models/datasets' -import type { CommonResponse, FileDownloadResponse } from '@/models/common' -// Download document with authentication (sends Authorization header) -import Toast from '@/app/components/base/toast' +import type { CommonResponse } from '@/models/common' const NAME_SPACE = 'knowledge/document' @@ -97,21 +95,6 @@ export const useSyncDocument = () => { }) } -// Download document with authentication (sends Authorization header) -export const useDocumentDownload = () => { - return useMutation({ - mutationFn: async ({ datasetId, documentId }: { datasetId: string; documentId: string }) => { - // The get helper automatically adds the Authorization header from localStorage - return get(`/datasets/${datasetId}/documents/${documentId}/upload-file`) - }, - onError: (error: any) => { - // Show a toast notification if download fails - const message = error?.message || 'Download failed.' - Toast.notify({ type: 'error', message }) - }, - }) -} - export const useSyncWebsite = () => { return useMutation({ mutationFn: ({ datasetId, documentId }: UpdateDocumentBatchParams) => { diff --git a/web/service/use-models.ts b/web/service/use-models.ts index 84122cdd1f..f3336dd03b 100644 --- a/web/service/use-models.ts +++ b/web/service/use-models.ts @@ -1,8 +1,18 @@ -import { get } from './base' +import { + del, + get, + post, + put, +} from './base' import type { + ModelCredential, ModelItem, + ModelLoadBalancingConfig, + ModelTypeEnum, + ProviderCredential, } from '@/app/components/header/account-setting/model-provider-page/declarations' import { + useMutation, useQuery, // useQueryClient, } from '@tanstack/react-query' @@ -15,3 +25,131 @@ export const useModelProviderModelList = (provider: string) => { queryFn: () => get<{ data: ModelItem[] }>(`/workspaces/current/model-providers/${provider}/models`), }) } + +export const useGetProviderCredential = (enabled: boolean, provider: string, credentialId?: string) => { + return useQuery({ + enabled, + queryKey: [NAME_SPACE, 'model-list', provider, credentialId], + queryFn: () => get(`/workspaces/current/model-providers/${provider}/credentials${credentialId ? `?credential_id=${credentialId}` : ''}`), + }) +} + +export const useAddProviderCredential = (provider: string) => { + return useMutation({ + mutationFn: (data: ProviderCredential) => post<{ result: string }>(`/workspaces/current/model-providers/${provider}/credentials`, { + body: data, + }), + }) +} + +export const useEditProviderCredential = (provider: string) => { + return useMutation({ + mutationFn: (data: ProviderCredential) => put<{ result: string }>(`/workspaces/current/model-providers/${provider}/credentials`, { + body: data, + }), + }) +} + +export const useDeleteProviderCredential = (provider: string) => { + return useMutation({ + mutationFn: (data: { + credential_id: string + }) => del<{ result: string }>(`/workspaces/current/model-providers/${provider}/credentials`, { + body: data, + }), + }) +} + +export const useActiveProviderCredential = (provider: string) => { + return useMutation({ + mutationFn: (data: { + credential_id: string + model?: string + model_type?: ModelTypeEnum + }) => post<{ result: string }>(`/workspaces/current/model-providers/${provider}/credentials/switch`, { + body: data, + }), + }) +} + +export const useGetModelCredential = ( + enabled: boolean, + provider: string, + credentialId?: string, + model?: string, + modelType?: string, + configFrom?: string, +) => { + return useQuery({ + enabled, + queryKey: [NAME_SPACE, 'model-list', provider, model, modelType, credentialId], + queryFn: () => get(`/workspaces/current/model-providers/${provider}/models/credentials?model=${model}&model_type=${modelType}&config_from=${configFrom}${credentialId ? `&credential_id=${credentialId}` : ''}`), + staleTime: 0, + gcTime: 0, + }) +} + +export const useAddModelCredential = (provider: string) => { + return useMutation({ + mutationFn: (data: ModelCredential) => post<{ result: string }>(`/workspaces/current/model-providers/${provider}/models/credentials`, { + body: data, + }), + }) +} + +export const useEditModelCredential = (provider: string) => { + return useMutation({ + mutationFn: (data: ModelCredential) => put<{ result: string }>(`/workspaces/current/model-providers/${provider}/models/credentials`, { + body: data, + }), + }) +} + +export const useDeleteModelCredential = (provider: string) => { + return useMutation({ + mutationFn: (data: { + credential_id: string + model?: string + model_type?: ModelTypeEnum + }) => del<{ result: string }>(`/workspaces/current/model-providers/${provider}/models/credentials`, { + body: data, + }), + }) +} + +export const useDeleteModel = (provider: string) => { + return useMutation({ + mutationFn: (data: { + model: string + model_type: ModelTypeEnum + }) => del<{ result: string }>(`/workspaces/current/model-providers/${provider}/models/credentials`, { + body: data, + }), + }) +} + +export const useActiveModelCredential = (provider: string) => { + return useMutation({ + mutationFn: (data: { + credential_id: string + model?: string + model_type?: ModelTypeEnum + }) => post<{ result: string }>(`/workspaces/current/model-providers/${provider}/models/credentials/switch`, { + body: data, + }), + }) +} + +export const useUpdateModelLoadBalancingConfig = (provider: string) => { + return useMutation({ + mutationFn: (data: { + config_from: string + model: string + model_type: ModelTypeEnum + load_balancing: ModelLoadBalancingConfig + credential_id?: string + }) => post<{ result: string }>(`/workspaces/current/model-providers/${provider}/models`, { + body: data, + }), + }) +} diff --git a/web/service/use-plugins-auth.ts b/web/service/use-plugins-auth.ts index 2dc0260647..51992361eb 100644 --- a/web/service/use-plugins-auth.ts +++ b/web/service/use-plugins-auth.ts @@ -19,6 +19,7 @@ export const useGetPluginCredentialInfo = ( enabled: !!url, queryKey: [NAME_SPACE, 'credential-info', url], queryFn: () => get<{ + allow_custom_token?: boolean supported_credential_types: string[] credentials: Credential[] is_oauth_custom_client_enabled: boolean diff --git a/web/themes/dark.css b/web/themes/dark.css index 9b9d467b08..cd1a016f75 100644 --- a/web/themes/dark.css +++ b/web/themes/dark.css @@ -417,6 +417,7 @@ html[data-theme="dark"] { --color-background-overlay-destructive: rgb(240 68 56 / 0.3); --color-background-overlay-backdrop: rgb(24 24 27 / 0.95); --color-background-body-transparent: rgb(29 29 32 / 0); + --color-background-section-burn-inverted: #27272b; --color-shadow-shadow-1: rgb(0 0 0 / 0.05); --color-shadow-shadow-3: rgb(0 0 0 / 0.1); @@ -761,4 +762,4 @@ html[data-theme="dark"] { --color-dify-logo-dify-logo-blue: #e8e8e8; --color-dify-logo-dify-logo-black: #e8e8e8; -} +} \ No newline at end of file diff --git a/web/themes/light.css b/web/themes/light.css index 0a37dd2953..93b76cbfec 100644 --- a/web/themes/light.css +++ b/web/themes/light.css @@ -417,6 +417,7 @@ html[data-theme="light"] { --color-background-overlay-destructive: rgb(240 68 56 / 0.3); --color-background-overlay-backdrop: rgb(242 244 247 / 0.95); --color-background-body-transparent: rgb(242 244 247 / 0); + --color-background-section-burn-inverted: #f2f4f7; --color-shadow-shadow-1: rgb(9 9 11 / 0.03); --color-shadow-shadow-3: rgb(9 9 11 / 0.05); @@ -761,4 +762,4 @@ html[data-theme="light"] { --color-dify-logo-dify-logo-blue: #0033ff; --color-dify-logo-dify-logo-black: #000000; -} +} \ No newline at end of file diff --git a/web/themes/tailwind-theme-var-define.ts b/web/themes/tailwind-theme-var-define.ts index b7b9994262..23d65b4bab 100644 --- a/web/themes/tailwind-theme-var-define.ts +++ b/web/themes/tailwind-theme-var-define.ts @@ -417,6 +417,7 @@ const vars = { 'background-overlay-destructive': 'var(--color-background-overlay-destructive)', 'background-overlay-backdrop': 'var(--color-background-overlay-backdrop)', 'background-body-transparent': 'var(--color-background-body-transparent)', + 'background-section-burn-inverted': 'var(--color-background-section-burn-inverted)', 'shadow-shadow-1': 'var(--color-shadow-shadow-1)', 'shadow-shadow-3': 'var(--color-shadow-shadow-3)',