mirror of https://github.com/langgenius/dify.git
merge main
This commit is contained in:
commit
158da1ce6e
|
|
@ -1,34 +0,0 @@
|
|||
name: Setup UV and Python
|
||||
|
||||
inputs:
|
||||
python-version:
|
||||
description: Python version to use and the UV installed with
|
||||
required: true
|
||||
default: '3.12'
|
||||
uv-version:
|
||||
description: UV version to set up
|
||||
required: true
|
||||
default: '0.8.9'
|
||||
uv-lockfile:
|
||||
description: Path to the UV lockfile to restore cache from
|
||||
required: true
|
||||
default: ''
|
||||
enable-cache:
|
||||
required: true
|
||||
default: true
|
||||
|
||||
runs:
|
||||
using: composite
|
||||
steps:
|
||||
- name: Set up Python ${{ inputs.python-version }}
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ inputs.python-version }}
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v6
|
||||
with:
|
||||
version: ${{ inputs.uv-version }}
|
||||
python-version: ${{ inputs.python-version }}
|
||||
enable-cache: ${{ inputs.enable-cache }}
|
||||
cache-dependency-glob: ${{ inputs.uv-lockfile }}
|
||||
|
|
@ -33,10 +33,11 @@ jobs:
|
|||
persist-credentials: false
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: ./.github/actions/setup-uv
|
||||
uses: astral-sh/setup-uv@v6
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: ${{ matrix.python-version }}
|
||||
uv-lockfile: api/uv.lock
|
||||
cache-dependency-glob: api/uv.lock
|
||||
|
||||
- name: Check UV lockfile
|
||||
run: uv lock --project api --check
|
||||
|
|
|
|||
|
|
@ -15,7 +15,9 @@ jobs:
|
|||
- uses: actions/checkout@v4
|
||||
|
||||
# Use uv to ensure we have the same ruff version in CI and locally.
|
||||
- uses: astral-sh/setup-uv@7edac99f961f18b581bbd960d59d049f04c0002f
|
||||
- uses: astral-sh/setup-uv@v6
|
||||
with:
|
||||
python-version: "3.12"
|
||||
- run: |
|
||||
cd api
|
||||
uv sync --dev
|
||||
|
|
|
|||
|
|
@ -25,9 +25,11 @@ jobs:
|
|||
persist-credentials: false
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: ./.github/actions/setup-uv
|
||||
uses: astral-sh/setup-uv@v6
|
||||
with:
|
||||
uv-lockfile: api/uv.lock
|
||||
enable-cache: true
|
||||
python-version: "3.12"
|
||||
cache-dependency-glob: api/uv.lock
|
||||
|
||||
- name: Install dependencies
|
||||
run: uv sync --project api
|
||||
|
|
|
|||
|
|
@ -36,10 +36,11 @@ jobs:
|
|||
|
||||
- name: Setup UV and Python
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
uses: ./.github/actions/setup-uv
|
||||
uses: astral-sh/setup-uv@v6
|
||||
with:
|
||||
uv-lockfile: api/uv.lock
|
||||
enable-cache: false
|
||||
python-version: "3.12"
|
||||
cache-dependency-glob: api/uv.lock
|
||||
|
||||
- name: Install dependencies
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
|
|
|
|||
|
|
@ -39,10 +39,11 @@ jobs:
|
|||
remove_tool_cache: true
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: ./.github/actions/setup-uv
|
||||
uses: astral-sh/setup-uv@v6
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: ${{ matrix.python-version }}
|
||||
uv-lockfile: api/uv.lock
|
||||
cache-dependency-glob: api/uv.lock
|
||||
|
||||
- name: Check UV lockfile
|
||||
run: uv lock --project api --check
|
||||
|
|
|
|||
|
|
@ -180,7 +180,7 @@ docker compose up -d
|
|||
|
||||
## Contributing
|
||||
|
||||
对于那些想要贡献代码的人,请参阅我们的[贡献指南](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)。
|
||||
对于那些想要贡献代码的人,请参阅我们的[贡献指南](https://github.com/langgenius/dify/blob/main/CONTRIBUTING_CN.md)。
|
||||
同时,请考虑通过社交媒体、活动和会议来支持 Dify 的分享。
|
||||
|
||||
> 我们正在寻找贡献者来帮助将 Dify 翻译成除了中文和英文之外的其他语言。如果您有兴趣帮助,请参阅我们的[i18n README](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md)获取更多信息,并在我们的[Discord 社区服务器](https://discord.gg/8Tpq4AcN9c)的`global-users`频道中留言。
|
||||
|
|
|
|||
|
|
@ -173,7 +173,7 @@ Stellen Sie Dify mit einem Klick in AKS bereit, indem Sie [Azure Devops Pipeline
|
|||
|
||||
## Contributing
|
||||
|
||||
Falls Sie Code beitragen möchten, lesen Sie bitte unseren [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). Gleichzeitig bitten wir Sie, Dify zu unterstützen, indem Sie es in den sozialen Medien teilen und auf Veranstaltungen und Konferenzen präsentieren.
|
||||
Falls Sie Code beitragen möchten, lesen Sie bitte unseren [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING_DE.md). Gleichzeitig bitten wir Sie, Dify zu unterstützen, indem Sie es in den sozialen Medien teilen und auf Veranstaltungen und Konferenzen präsentieren.
|
||||
|
||||
> Wir suchen Mitwirkende, die dabei helfen, Dify in weitere Sprachen zu übersetzen – außer Mandarin oder Englisch. Wenn Sie Interesse an einer Mitarbeit haben, lesen Sie bitte die [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md) für weitere Informationen und hinterlassen Sie einen Kommentar im `global-users`-Kanal unseres [Discord Community Servers](https://discord.gg/8Tpq4AcN9c).
|
||||
|
||||
|
|
|
|||
|
|
@ -170,7 +170,7 @@ Implementa Dify en AKS con un clic usando [Azure Devops Pipeline Helm Chart by @
|
|||
|
||||
## Contribuir
|
||||
|
||||
Para aquellos que deseen contribuir con código, consulten nuestra [Guía de contribución](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).
|
||||
Para aquellos que deseen contribuir con código, consulten nuestra [Guía de contribución](https://github.com/langgenius/dify/blob/main/CONTRIBUTING_ES.md).
|
||||
Al mismo tiempo, considera apoyar a Dify compartiéndolo en redes sociales y en eventos y conferencias.
|
||||
|
||||
> Estamos buscando colaboradores para ayudar con la traducción de Dify a idiomas que no sean el mandarín o el inglés. Si estás interesado en ayudar, consulta el [README de i18n](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md) para obtener más información y déjanos un comentario en el canal `global-users` de nuestro [Servidor de Comunidad en Discord](https://discord.gg/8Tpq4AcN9c).
|
||||
|
|
|
|||
|
|
@ -168,7 +168,7 @@ Déployez Dify sur AKS en un clic en utilisant [Azure Devops Pipeline Helm Chart
|
|||
|
||||
## Contribuer
|
||||
|
||||
Pour ceux qui souhaitent contribuer du code, consultez notre [Guide de contribution](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).
|
||||
Pour ceux qui souhaitent contribuer du code, consultez notre [Guide de contribution](https://github.com/langgenius/dify/blob/main/CONTRIBUTING_FR.md).
|
||||
Dans le même temps, veuillez envisager de soutenir Dify en le partageant sur les réseaux sociaux et lors d'événements et de conférences.
|
||||
|
||||
> Nous recherchons des contributeurs pour aider à traduire Dify dans des langues autres que le mandarin ou l'anglais. Si vous êtes intéressé à aider, veuillez consulter le [README i18n](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md) pour plus d'informations, et laissez-nous un commentaire dans le canal `global-users` de notre [Serveur communautaire Discord](https://discord.gg/8Tpq4AcN9c).
|
||||
|
|
|
|||
|
|
@ -169,7 +169,7 @@ docker compose up -d
|
|||
|
||||
## 貢献
|
||||
|
||||
コードに貢献したい方は、[Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)を参照してください。
|
||||
コードに貢献したい方は、[Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING_JA.md)を参照してください。
|
||||
同時に、DifyをSNSやイベント、カンファレンスで共有してサポートしていただけると幸いです。
|
||||
|
||||
> Difyを英語または中国語以外の言語に翻訳してくれる貢献者を募集しています。興味がある場合は、詳細については[i18n README](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md)を参照してください。また、[Discordコミュニティサーバー](https://discord.gg/8Tpq4AcN9c)の`global-users`チャンネルにコメントを残してください。
|
||||
|
|
|
|||
|
|
@ -162,7 +162,7 @@ Dify를 Kubernetes에 배포하고 프리미엄 스케일링 설정을 구성했
|
|||
|
||||
## 기여
|
||||
|
||||
코드에 기여하고 싶은 분들은 [기여 가이드](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)를 참조하세요.
|
||||
코드에 기여하고 싶은 분들은 [기여 가이드](https://github.com/langgenius/dify/blob/main/CONTRIBUTING_KR.md)를 참조하세요.
|
||||
동시에 Dify를 소셜 미디어와 행사 및 컨퍼런스에 공유하여 지원하는 것을 고려해 주시기 바랍니다.
|
||||
|
||||
> 우리는 Dify를 중국어나 영어 이외의 언어로 번역하는 데 도움을 줄 수 있는 기여자를 찾고 있습니다. 도움을 주고 싶으시다면 [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md)에서 더 많은 정보를 확인하시고 [Discord 커뮤니티 서버](https://discord.gg/8Tpq4AcN9c)의 `global-users` 채널에 댓글을 남겨주세요.
|
||||
|
|
|
|||
|
|
@ -168,7 +168,7 @@ Implante o Dify no AKS com um clique usando [Azure Devops Pipeline Helm Chart by
|
|||
|
||||
## Contribuindo
|
||||
|
||||
Para aqueles que desejam contribuir com código, veja nosso [Guia de Contribuição](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).
|
||||
Para aqueles que desejam contribuir com código, veja nosso [Guia de Contribuição](https://github.com/langgenius/dify/blob/main/CONTRIBUTING_PT.md).
|
||||
Ao mesmo tempo, considere apoiar o Dify compartilhando-o nas redes sociais e em eventos e conferências.
|
||||
|
||||
> Estamos buscando contribuidores para ajudar na tradução do Dify para idiomas além de Mandarim e Inglês. Se você tiver interesse em ajudar, consulte o [README i18n](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md) para mais informações e deixe-nos um comentário no canal `global-users` em nosso [Servidor da Comunidade no Discord](https://discord.gg/8Tpq4AcN9c).
|
||||
|
|
|
|||
|
|
@ -161,7 +161,7 @@ Dify'ı bulut platformuna tek tıklamayla dağıtın [terraform](https://www.ter
|
|||
|
||||
## Katkıda Bulunma
|
||||
|
||||
Kod katkısında bulunmak isteyenler için [Katkı Kılavuzumuza](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) bakabilirsiniz.
|
||||
Kod katkısında bulunmak isteyenler için [Katkı Kılavuzumuza](https://github.com/langgenius/dify/blob/main/CONTRIBUTING_TR.md) bakabilirsiniz.
|
||||
Aynı zamanda, lütfen Dify'ı sosyal medyada, etkinliklerde ve konferanslarda paylaşarak desteklemeyi düşünün.
|
||||
|
||||
> Dify'ı Mandarin veya İngilizce dışındaki dillere çevirmemize yardımcı olacak katkıda bulunanlara ihtiyacımız var. Yardımcı olmakla ilgileniyorsanız, lütfen daha fazla bilgi için [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md) dosyasına bakın ve [Discord Topluluk Sunucumuzdaki](https://discord.gg/8Tpq4AcN9c) `global-users` kanalında bize bir yorum bırakın.
|
||||
|
|
|
|||
|
|
@ -173,7 +173,7 @@ Dify 的所有功能都提供相應的 API,因此您可以輕鬆地將 Dify
|
|||
|
||||
## 貢獻
|
||||
|
||||
對於想要貢獻程式碼的開發者,請參閱我們的[貢獻指南](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)。
|
||||
對於想要貢獻程式碼的開發者,請參閱我們的[貢獻指南](https://github.com/langgenius/dify/blob/main/CONTRIBUTING_TW.md)。
|
||||
同時,也請考慮透過在社群媒體和各種活動與會議上分享 Dify 來支持我們。
|
||||
|
||||
> 我們正在尋找貢獻者協助將 Dify 翻譯成中文和英文以外的語言。如果您有興趣幫忙,請查看 [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md) 獲取更多資訊,並在我們的 [Discord 社群伺服器](https://discord.gg/8Tpq4AcN9c) 的 `global-users` 頻道留言給我們。
|
||||
|
|
|
|||
|
|
@ -162,7 +162,7 @@ Triển khai Dify lên AKS chỉ với một cú nhấp chuột bằng [Azure De
|
|||
|
||||
## Đóng góp
|
||||
|
||||
Đối với những người muốn đóng góp mã, xem [Hướng dẫn Đóng góp](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) của chúng tôi.
|
||||
Đối với những người muốn đóng góp mã, xem [Hướng dẫn Đóng góp](https://github.com/langgenius/dify/blob/main/CONTRIBUTING_VI.md) của chúng tôi.
|
||||
Đồng thời, vui lòng xem xét hỗ trợ Dify bằng cách chia sẻ nó trên mạng xã hội và tại các sự kiện và hội nghị.
|
||||
|
||||
> Chúng tôi đang tìm kiếm người đóng góp để giúp dịch Dify sang các ngôn ngữ khác ngoài tiếng Trung hoặc tiếng Anh. Nếu bạn quan tâm đến việc giúp đỡ, vui lòng xem [README i18n](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md) để biết thêm thông tin và để lại bình luận cho chúng tôi trong kênh `global-users` của [Máy chủ Cộng đồng Discord](https://discord.gg/8Tpq4AcN9c) của chúng tôi.
|
||||
|
|
|
|||
|
|
@ -564,3 +564,7 @@ QUEUE_MONITOR_THRESHOLD=200
|
|||
QUEUE_MONITOR_ALERT_EMAILS=
|
||||
# Monitor interval in minutes, default is 30 minutes
|
||||
QUEUE_MONITOR_INTERVAL=30
|
||||
|
||||
# Swagger UI configuration
|
||||
SWAGGER_UI_ENABLED=true
|
||||
SWAGGER_UI_PATH=/swagger-ui.html
|
||||
|
|
|
|||
|
|
@ -99,14 +99,14 @@ uv run celery -A app.celery beat
|
|||
|
||||
1. Run the tests locally with mocked system environment variables in `tool.pytest_env` section in `pyproject.toml`, more can check [Claude.md](../CLAUDE.md)
|
||||
|
||||
```cli
|
||||
uv run --project api pytest # Run all tests
|
||||
uv run --project api pytest tests/unit_tests/ # Unit tests only
|
||||
uv run --project api pytest tests/integration_tests/ # Integration tests
|
||||
```bash
|
||||
uv run pytest # Run all tests
|
||||
uv run pytest tests/unit_tests/ # Unit tests only
|
||||
uv run pytest tests/integration_tests/ # Integration tests
|
||||
|
||||
# Code quality
|
||||
./dev/reformat # Run all formatters and linters
|
||||
uv run --project api ruff check --fix ./ # Fix linting issues
|
||||
uv run --project api ruff format ./ # Format code
|
||||
uv run --project api mypy . # Type checking
|
||||
../dev/reformat # Run all formatters and linters
|
||||
uv run ruff check --fix ./ # Fix linting issues
|
||||
uv run ruff format ./ # Format code
|
||||
uv run mypy . # Type checking
|
||||
```
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Annotated, Literal, Optional
|
||||
from typing import Literal, Optional
|
||||
|
||||
from pydantic import (
|
||||
AliasChoices,
|
||||
|
|
@ -976,6 +976,18 @@ class WorkflowLogConfig(BaseSettings):
|
|||
)
|
||||
|
||||
|
||||
class SwaggerUIConfig(BaseSettings):
|
||||
SWAGGER_UI_ENABLED: bool = Field(
|
||||
description="Whether to enable Swagger UI in api module",
|
||||
default=True,
|
||||
)
|
||||
|
||||
SWAGGER_UI_PATH: str = Field(
|
||||
description="Swagger UI page path in api module",
|
||||
default="/swagger-ui.html",
|
||||
)
|
||||
|
||||
|
||||
class FeatureConfig(
|
||||
# place the configs in alphabet order
|
||||
AppExecutionConfig,
|
||||
|
|
@ -1007,6 +1019,7 @@ class FeatureConfig(
|
|||
WorkspaceConfig,
|
||||
LoginConfig,
|
||||
AccountConfig,
|
||||
SwaggerUIConfig,
|
||||
# hosted services config
|
||||
HostedServiceConfig,
|
||||
CeleryBeatConfig,
|
||||
|
|
|
|||
|
|
@ -84,7 +84,6 @@ from .datasets import (
|
|||
external,
|
||||
hit_testing,
|
||||
metadata,
|
||||
upload_file,
|
||||
website,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -72,6 +72,7 @@ class DraftWorkflowApi(Resource):
|
|||
Get draft workflow
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
assert isinstance(current_user, Account)
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
|
|
@ -94,6 +95,7 @@ class DraftWorkflowApi(Resource):
|
|||
Sync draft workflow
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
assert isinstance(current_user, Account)
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
|
|
@ -171,6 +173,7 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
|
|||
Run draft workflow
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
assert isinstance(current_user, Account)
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
|
|
@ -218,13 +221,12 @@ class AdvancedChatDraftRunIterationNodeApi(Resource):
|
|||
"""
|
||||
Run draft workflow iteration node
|
||||
"""
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, location="json")
|
||||
args = parser.parse_args()
|
||||
|
|
@ -256,11 +258,10 @@ class WorkflowDraftRunIterationNodeApi(Resource):
|
|||
Run draft workflow iteration node
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, location="json")
|
||||
|
|
@ -292,12 +293,12 @@ class AdvancedChatDraftRunLoopNodeApi(Resource):
|
|||
"""
|
||||
Run draft workflow loop node
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, location="json")
|
||||
|
|
@ -329,12 +330,12 @@ class WorkflowDraftRunLoopNodeApi(Resource):
|
|||
"""
|
||||
Run draft workflow loop node
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, location="json")
|
||||
|
|
@ -366,12 +367,12 @@ class DraftWorkflowRunApi(Resource):
|
|||
"""
|
||||
Run draft workflow
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
|
|
@ -405,6 +406,9 @@ class WorkflowTaskStopApi(Resource):
|
|||
"""
|
||||
Stop workflow task
|
||||
"""
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
|
@ -424,12 +428,12 @@ class DraftWorkflowNodeRunApi(Resource):
|
|||
"""
|
||||
Run draft workflow node
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
|
|
@ -472,6 +476,9 @@ class PublishedWorkflowApi(Resource):
|
|||
"""
|
||||
Get published workflow
|
||||
"""
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
|
@ -491,13 +498,12 @@ class PublishedWorkflowApi(Resource):
|
|||
"""
|
||||
Publish workflow
|
||||
"""
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("marked_name", type=str, required=False, default="", location="json")
|
||||
parser.add_argument("marked_comment", type=str, required=False, default="", location="json")
|
||||
|
|
@ -541,6 +547,9 @@ class DefaultBlockConfigsApi(Resource):
|
|||
"""
|
||||
Get default block config
|
||||
"""
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
|
@ -559,13 +568,12 @@ class DefaultBlockConfigApi(Resource):
|
|||
"""
|
||||
Get default block config
|
||||
"""
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("q", type=str, location="args")
|
||||
args = parser.parse_args()
|
||||
|
|
@ -595,13 +603,12 @@ class ConvertToWorkflowApi(Resource):
|
|||
Convert expert mode of chatbot app to workflow mode
|
||||
Convert Completion App to Workflow App
|
||||
"""
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
if request.data:
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=str, required=False, nullable=True, location="json")
|
||||
|
|
@ -645,6 +652,9 @@ class PublishedAllWorkflowApi(Resource):
|
|||
"""
|
||||
Get published workflows
|
||||
"""
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
|
|
@ -693,13 +703,12 @@ class WorkflowByIdApi(Resource):
|
|||
"""
|
||||
Update workflow attributes
|
||||
"""
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
# Check permission
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("marked_name", type=str, required=False, location="json")
|
||||
parser.add_argument("marked_comment", type=str, required=False, location="json")
|
||||
|
|
@ -750,13 +759,12 @@ class WorkflowByIdApi(Resource):
|
|||
"""
|
||||
Delete workflow
|
||||
"""
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
# Check permission
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
|
||||
# Create a session and manage the transaction
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ from factories.file_factory import build_from_mapping, build_from_mappings
|
|||
from factories.variable_factory import build_segment_with_type
|
||||
from libs.login import current_user, login_required
|
||||
from models import App, AppMode, db
|
||||
from models.account import Account
|
||||
from models.workflow import WorkflowDraftVariable
|
||||
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
|
||||
from services.workflow_service import WorkflowService
|
||||
|
|
@ -135,6 +136,7 @@ def _api_prerequisite(f):
|
|||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def wrapper(*args, **kwargs):
|
||||
assert isinstance(current_user, Account)
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
return f(*args, **kwargs)
|
||||
|
|
|
|||
|
|
@ -6,9 +6,11 @@ from controllers.console.app.error import AppNotFoundError
|
|||
from extensions.ext_database import db
|
||||
from libs.login import current_user
|
||||
from models import App, AppMode
|
||||
from models.account import Account
|
||||
|
||||
|
||||
def _load_app_model(app_id: str) -> Optional[App]:
|
||||
assert isinstance(current_user, Account)
|
||||
app_model = (
|
||||
db.session.query(App)
|
||||
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
||||
|
|
|
|||
|
|
@ -55,6 +55,12 @@ class EmailOrPasswordMismatchError(BaseHTTPException):
|
|||
code = 400
|
||||
|
||||
|
||||
class AuthenticationFailedError(BaseHTTPException):
|
||||
error_code = "authentication_failed"
|
||||
description = "Invalid email or password."
|
||||
code = 401
|
||||
|
||||
|
||||
class EmailPasswordLoginLimitError(BaseHTTPException):
|
||||
error_code = "email_code_login_limit"
|
||||
description = "Too many incorrect password attempts. Please try again later."
|
||||
|
|
|
|||
|
|
@ -9,8 +9,8 @@ from configs import dify_config
|
|||
from constants.languages import languages
|
||||
from controllers.console import api
|
||||
from controllers.console.auth.error import (
|
||||
AuthenticationFailedError,
|
||||
EmailCodeError,
|
||||
EmailOrPasswordMismatchError,
|
||||
EmailPasswordLoginLimitError,
|
||||
InvalidEmailError,
|
||||
InvalidTokenError,
|
||||
|
|
@ -79,7 +79,7 @@ class LoginApi(Resource):
|
|||
raise AccountBannedError()
|
||||
except services.errors.account.AccountPasswordError:
|
||||
AccountService.add_login_error_rate_limit(args["email"])
|
||||
raise EmailOrPasswordMismatchError()
|
||||
raise AuthenticationFailedError()
|
||||
except services.errors.account.AccountNotFoundError:
|
||||
if FeatureService.get_system_features().is_allow_register:
|
||||
token = AccountService.send_reset_password_email(email=args["email"], language=language)
|
||||
|
|
@ -132,6 +132,7 @@ class ResetPasswordSendEmailApi(Resource):
|
|||
account = AccountService.get_user_through_email(args["email"])
|
||||
except AccountRegisterError as are:
|
||||
raise AccountInFreezeError()
|
||||
|
||||
if account is None:
|
||||
if FeatureService.get_system_features().is_allow_register:
|
||||
token = AccountService.send_reset_password_email(email=args["email"], language=language)
|
||||
|
|
|
|||
|
|
@ -553,7 +553,7 @@ class DatasetIndexingStatusApi(Resource):
|
|||
}
|
||||
documents_status.append(marshal(document_dict, document_status_fields))
|
||||
data = {"data": documents_status}
|
||||
return data
|
||||
return data, 200
|
||||
|
||||
|
||||
class DatasetApiKeyApi(Resource):
|
||||
|
|
|
|||
|
|
@ -1,62 +0,0 @@
|
|||
from flask_login import current_user
|
||||
from flask_restx import Resource
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
setup_required,
|
||||
)
|
||||
from core.file import helpers as file_helpers
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset
|
||||
from models.model import UploadFile
|
||||
from services.dataset_service import DocumentService
|
||||
|
||||
|
||||
class UploadFileApi(Resource):
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id, document_id):
|
||||
"""Get upload file."""
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = (
|
||||
db.session.query(Dataset)
|
||||
.filter(Dataset.tenant_id == current_user.current_tenant_id, Dataset.id == dataset_id)
|
||||
.first()
|
||||
)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
# check document
|
||||
document_id = str(document_id)
|
||||
document = DocumentService.get_document(dataset.id, document_id)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
# check upload file
|
||||
if document.data_source_type != "upload_file":
|
||||
raise ValueError(f"Document data source type ({document.data_source_type}) is not upload_file.")
|
||||
data_source_info = document.data_source_info_dict
|
||||
if data_source_info and "upload_file_id" in data_source_info:
|
||||
file_id = data_source_info["upload_file_id"]
|
||||
upload_file = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
|
||||
if not upload_file:
|
||||
raise NotFound("UploadFile not found.")
|
||||
else:
|
||||
raise ValueError("Upload file id not found in document data source info.")
|
||||
|
||||
url = file_helpers.get_signed_file_url(upload_file_id=upload_file.id)
|
||||
return {
|
||||
"id": upload_file.id,
|
||||
"name": upload_file.name,
|
||||
"size": upload_file.size,
|
||||
"extension": upload_file.extension,
|
||||
"url": url,
|
||||
"download_url": f"{url}&as_attachment=true",
|
||||
"mime_type": upload_file.mime_type,
|
||||
"created_by": upload_file.created_by,
|
||||
"created_at": upload_file.created_at.timestamp(),
|
||||
}, 200
|
||||
|
||||
|
||||
api.add_resource(UploadFileApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/upload-file")
|
||||
|
|
@ -43,7 +43,7 @@ class InstalledAppWorkflowRunApi(InstalledAppResource):
|
|||
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("files", type=list, required=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
assert current_user is not None
|
||||
try:
|
||||
response = AppGenerateService.generate(
|
||||
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True
|
||||
|
|
@ -76,6 +76,7 @@ class InstalledAppWorkflowTaskStopApi(InstalledAppResource):
|
|||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode != AppMode.WORKFLOW:
|
||||
raise NotWorkflowAppError()
|
||||
assert current_user is not None
|
||||
|
||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
|
||||
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from controllers.console.wraps import account_initialization_required, setup_req
|
|||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from libs.login import current_user, login_required
|
||||
from models.account import TenantAccountRole
|
||||
from models.account import Account, TenantAccountRole
|
||||
from services.model_load_balancing_service import ModelLoadBalancingService
|
||||
|
||||
|
||||
|
|
@ -15,10 +15,12 @@ class LoadBalancingCredentialsValidateApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
assert isinstance(current_user, Account)
|
||||
if not TenantAccountRole.is_privileged_role(current_user.current_role):
|
||||
raise Forbidden()
|
||||
|
||||
tenant_id = current_user.current_tenant_id
|
||||
assert tenant_id is not None
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||
|
|
@ -64,10 +66,12 @@ class LoadBalancingConfigCredentialsValidateApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str, config_id: str):
|
||||
assert isinstance(current_user, Account)
|
||||
if not TenantAccountRole.is_privileged_role(current_user.current_role):
|
||||
raise Forbidden()
|
||||
|
||||
tenant_id = current_user.current_tenant_id
|
||||
assert tenant_id is not None
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||
|
|
|
|||
|
|
@ -54,7 +54,7 @@ class MemberInviteEmailApi(Resource):
|
|||
@cloud_edition_billing_resource_check("members")
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("emails", type=str, required=True, location="json", action="append")
|
||||
parser.add_argument("emails", type=list, required=True, location="json")
|
||||
parser.add_argument("role", type=str, required=True, default="admin", location="json")
|
||||
parser.add_argument("language", type=str, required=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from controllers.console.wraps import account_initialization_required, setup_req
|
|||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.helper import StrLen, uuid_value
|
||||
from libs.login import login_required
|
||||
from services.billing_service import BillingService
|
||||
from services.model_provider_service import ModelProviderService
|
||||
|
|
@ -45,12 +46,109 @@ class ModelProviderCredentialApi(Resource):
|
|||
@account_initialization_required
|
||||
def get(self, provider: str):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
# if credential_id is not provided, return current used credential
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
credentials = model_provider_service.get_provider_credentials(tenant_id=tenant_id, provider=provider)
|
||||
credentials = model_provider_service.get_provider_credential(
|
||||
tenant_id=tenant_id, provider=provider, credential_id=args.get("credential_id")
|
||||
)
|
||||
|
||||
return {"credentials": credentials}
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
|
||||
try:
|
||||
model_provider_service.create_provider_credential(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider=provider,
|
||||
credentials=args["credentials"],
|
||||
credential_name=args["name"],
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ValueError(str(ex))
|
||||
|
||||
return {"result": "success"}, 201
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def put(self, provider: str):
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
|
||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
|
||||
try:
|
||||
model_provider_service.update_provider_credential(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider=provider,
|
||||
credentials=args["credentials"],
|
||||
credential_id=args["credential_id"],
|
||||
credential_name=args["name"],
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ValueError(str(ex))
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, provider: str):
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
model_provider_service.remove_provider_credential(
|
||||
tenant_id=current_user.current_tenant_id, provider=provider, credential_id=args["credential_id"]
|
||||
)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
class ModelProviderCredentialSwitchApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
service = ModelProviderService()
|
||||
service.switch_active_provider_credential(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider=provider,
|
||||
credential_id=args["credential_id"],
|
||||
)
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
class ModelProviderValidateApi(Resource):
|
||||
@setup_required
|
||||
|
|
@ -69,7 +167,7 @@ class ModelProviderValidateApi(Resource):
|
|||
error = ""
|
||||
|
||||
try:
|
||||
model_provider_service.provider_credentials_validate(
|
||||
model_provider_service.validate_provider_credentials(
|
||||
tenant_id=tenant_id, provider=provider, credentials=args["credentials"]
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
|
|
@ -84,42 +182,6 @@ class ModelProviderValidateApi(Resource):
|
|||
return response
|
||||
|
||||
|
||||
class ModelProviderApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
|
||||
try:
|
||||
model_provider_service.save_provider_credentials(
|
||||
tenant_id=current_user.current_tenant_id, provider=provider, credentials=args["credentials"]
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ValueError(str(ex))
|
||||
|
||||
return {"result": "success"}, 201
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, provider: str):
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
model_provider_service.remove_provider_credentials(tenant_id=current_user.current_tenant_id, provider=provider)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
class ModelProviderIconApi(Resource):
|
||||
"""
|
||||
Get model provider icon
|
||||
|
|
@ -187,8 +249,10 @@ class ModelProviderPaymentCheckoutUrlApi(Resource):
|
|||
api.add_resource(ModelProviderListApi, "/workspaces/current/model-providers")
|
||||
|
||||
api.add_resource(ModelProviderCredentialApi, "/workspaces/current/model-providers/<path:provider>/credentials")
|
||||
api.add_resource(
|
||||
ModelProviderCredentialSwitchApi, "/workspaces/current/model-providers/<path:provider>/credentials/switch"
|
||||
)
|
||||
api.add_resource(ModelProviderValidateApi, "/workspaces/current/model-providers/<path:provider>/credentials/validate")
|
||||
api.add_resource(ModelProviderApi, "/workspaces/current/model-providers/<path:provider>")
|
||||
|
||||
api.add_resource(
|
||||
PreferredProviderTypeUpdateApi, "/workspaces/current/model-providers/<path:provider>/preferred-provider-type"
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from controllers.console.wraps import account_initialization_required, setup_req
|
|||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.helper import StrLen, uuid_value
|
||||
from libs.login import login_required
|
||||
from services.model_load_balancing_service import ModelLoadBalancingService
|
||||
from services.model_provider_service import ModelProviderService
|
||||
|
|
@ -98,6 +99,7 @@ class ModelProviderModelApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
# To save the model's load balance configs
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
|
|
@ -113,22 +115,26 @@ class ModelProviderModelApi(Resource):
|
|||
choices=[mt.value for mt in ModelType],
|
||||
location="json",
|
||||
)
|
||||
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
|
||||
parser.add_argument("load_balancing", type=dict, required=False, nullable=True, location="json")
|
||||
parser.add_argument("config_from", type=str, required=False, nullable=True, location="json")
|
||||
parser.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.get("config_from", "") == "custom-model":
|
||||
if not args.get("credential_id"):
|
||||
raise ValueError("credential_id is required when configuring a custom-model")
|
||||
service = ModelProviderService()
|
||||
service.switch_active_custom_model_credential(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider=provider,
|
||||
model_type=args["model_type"],
|
||||
model=args["model"],
|
||||
credential_id=args["credential_id"],
|
||||
)
|
||||
|
||||
model_load_balancing_service = ModelLoadBalancingService()
|
||||
|
||||
if (
|
||||
"load_balancing" in args
|
||||
and args["load_balancing"]
|
||||
and "enabled" in args["load_balancing"]
|
||||
and args["load_balancing"]["enabled"]
|
||||
):
|
||||
if "configs" not in args["load_balancing"]:
|
||||
raise ValueError("invalid load balancing configs")
|
||||
|
||||
if "load_balancing" in args and args["load_balancing"] and "configs" in args["load_balancing"]:
|
||||
# save load balancing configs
|
||||
model_load_balancing_service.update_load_balancing_configs(
|
||||
tenant_id=tenant_id,
|
||||
|
|
@ -136,37 +142,17 @@ class ModelProviderModelApi(Resource):
|
|||
model=args["model"],
|
||||
model_type=args["model_type"],
|
||||
configs=args["load_balancing"]["configs"],
|
||||
config_from=args.get("config_from", ""),
|
||||
)
|
||||
|
||||
# enable load balancing
|
||||
model_load_balancing_service.enable_model_load_balancing(
|
||||
tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
|
||||
)
|
||||
else:
|
||||
# disable load balancing
|
||||
model_load_balancing_service.disable_model_load_balancing(
|
||||
tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
|
||||
)
|
||||
|
||||
if args.get("config_from", "") != "predefined-model":
|
||||
model_provider_service = ModelProviderService()
|
||||
|
||||
try:
|
||||
model_provider_service.save_model_credentials(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
model=args["model"],
|
||||
model_type=args["model_type"],
|
||||
credentials=args["credentials"],
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
logging.exception(
|
||||
"Failed to save model credentials, tenant_id: %s, model: %s, model_type: %s",
|
||||
tenant_id,
|
||||
args.get("model"),
|
||||
args.get("model_type"),
|
||||
)
|
||||
raise ValueError(str(ex))
|
||||
if args.get("load_balancing", {}).get("enabled"):
|
||||
model_load_balancing_service.enable_model_load_balancing(
|
||||
tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
|
||||
)
|
||||
else:
|
||||
model_load_balancing_service.disable_model_load_balancing(
|
||||
tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
|
||||
)
|
||||
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
|
@ -192,7 +178,7 @@ class ModelProviderModelApi(Resource):
|
|||
args = parser.parse_args()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
model_provider_service.remove_model_credentials(
|
||||
model_provider_service.remove_model(
|
||||
tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
|
||||
)
|
||||
|
||||
|
|
@ -216,11 +202,17 @@ class ModelProviderModelCredentialApi(Resource):
|
|||
choices=[mt.value for mt in ModelType],
|
||||
location="args",
|
||||
)
|
||||
parser.add_argument("config_from", type=str, required=False, nullable=True, location="args")
|
||||
parser.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
credentials = model_provider_service.get_model_credentials(
|
||||
tenant_id=tenant_id, provider=provider, model_type=args["model_type"], model=args["model"]
|
||||
current_credential = model_provider_service.get_model_credential(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
model_type=args["model_type"],
|
||||
model=args["model"],
|
||||
credential_id=args.get("credential_id"),
|
||||
)
|
||||
|
||||
model_load_balancing_service = ModelLoadBalancingService()
|
||||
|
|
@ -228,10 +220,173 @@ class ModelProviderModelCredentialApi(Resource):
|
|||
tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
|
||||
)
|
||||
|
||||
return {
|
||||
"credentials": credentials,
|
||||
"load_balancing": {"enabled": is_load_balancing_enabled, "configs": load_balancing_configs},
|
||||
}
|
||||
if args.get("config_from", "") == "predefined-model":
|
||||
available_credentials = model_provider_service.provider_manager.get_provider_available_credentials(
|
||||
tenant_id=tenant_id, provider_name=provider
|
||||
)
|
||||
else:
|
||||
model_type = ModelType.value_of(args["model_type"]).to_origin_model_type()
|
||||
available_credentials = model_provider_service.provider_manager.get_provider_model_available_credentials(
|
||||
tenant_id=tenant_id, provider_name=provider, model_type=model_type, model_name=args["model"]
|
||||
)
|
||||
|
||||
return jsonable_encoder(
|
||||
{
|
||||
"credentials": current_credential.get("credentials") if current_credential else {},
|
||||
"current_credential_id": current_credential.get("current_credential_id")
|
||||
if current_credential
|
||||
else None,
|
||||
"current_credential_name": current_credential.get("current_credential_name")
|
||||
if current_credential
|
||||
else None,
|
||||
"load_balancing": {"enabled": is_load_balancing_enabled, "configs": load_balancing_configs},
|
||||
"available_credentials": available_credentials,
|
||||
}
|
||||
)
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="json",
|
||||
)
|
||||
parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
|
||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
tenant_id = current_user.current_tenant_id
|
||||
model_provider_service = ModelProviderService()
|
||||
|
||||
try:
|
||||
model_provider_service.create_model_credential(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
model=args["model"],
|
||||
model_type=args["model_type"],
|
||||
credentials=args["credentials"],
|
||||
credential_name=args["name"],
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
logging.exception(
|
||||
"Failed to save model credentials, tenant_id: %s, model: %s, model_type: %s",
|
||||
tenant_id,
|
||||
args.get("model"),
|
||||
args.get("model_type"),
|
||||
)
|
||||
raise ValueError(str(ex))
|
||||
|
||||
return {"result": "success"}, 201
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def put(self, provider: str):
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="json",
|
||||
)
|
||||
parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
|
||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
|
||||
try:
|
||||
model_provider_service.update_model_credential(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider=provider,
|
||||
model_type=args["model_type"],
|
||||
model=args["model"],
|
||||
credentials=args["credentials"],
|
||||
credential_id=args["credential_id"],
|
||||
credential_name=args["name"],
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ValueError(str(ex))
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, provider: str):
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="json",
|
||||
)
|
||||
parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
model_provider_service.remove_model_credential(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider=provider,
|
||||
model_type=args["model_type"],
|
||||
model=args["model"],
|
||||
credential_id=args["credential_id"],
|
||||
)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
class ModelProviderModelCredentialSwitchApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="json",
|
||||
)
|
||||
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
service = ModelProviderService()
|
||||
service.add_model_credential_to_model_list(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider=provider,
|
||||
model_type=args["model_type"],
|
||||
model=args["model"],
|
||||
credential_id=args["credential_id"],
|
||||
)
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
class ModelProviderModelEnableApi(Resource):
|
||||
|
|
@ -314,7 +469,7 @@ class ModelProviderModelValidateApi(Resource):
|
|||
error = ""
|
||||
|
||||
try:
|
||||
model_provider_service.model_credentials_validate(
|
||||
model_provider_service.validate_model_credentials(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
model=args["model"],
|
||||
|
|
@ -379,6 +534,10 @@ api.add_resource(
|
|||
api.add_resource(
|
||||
ModelProviderModelCredentialApi, "/workspaces/current/model-providers/<path:provider>/models/credentials"
|
||||
)
|
||||
api.add_resource(
|
||||
ModelProviderModelCredentialSwitchApi,
|
||||
"/workspaces/current/model-providers/<path:provider>/models/credentials/switch",
|
||||
)
|
||||
api.add_resource(
|
||||
ModelProviderModelValidateApi, "/workspaces/current/model-providers/<path:provider>/models/credentials/validate"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ api = ExternalApi(
|
|||
doc="/docs", # Enable Swagger UI at /files/docs
|
||||
)
|
||||
|
||||
files_ns = Namespace("files", description="File operations")
|
||||
files_ns = Namespace("files", description="File operations", path="/")
|
||||
|
||||
from . import image_preview, tool_files, upload
|
||||
|
||||
|
|
|
|||
|
|
@ -1,10 +1,23 @@
|
|||
from flask import Blueprint
|
||||
from flask_restx import Namespace
|
||||
|
||||
from libs.external_api import ExternalApi
|
||||
|
||||
bp = Blueprint("inner_api", __name__, url_prefix="/inner/api")
|
||||
api = ExternalApi(bp)
|
||||
|
||||
api = ExternalApi(
|
||||
bp,
|
||||
version="1.0",
|
||||
title="Inner API",
|
||||
description="Internal APIs for enterprise features, billing, and plugin communication",
|
||||
doc="/docs", # Enable Swagger UI at /inner/api/docs
|
||||
)
|
||||
|
||||
# Create namespace
|
||||
inner_api_ns = Namespace("inner_api", description="Internal API operations", path="/")
|
||||
|
||||
from . import mail
|
||||
from .plugin import plugin
|
||||
from .workspace import workspace
|
||||
|
||||
api.add_namespace(inner_api_ns)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from flask_restx import Resource, reqparse
|
||||
|
||||
from controllers.console.wraps import setup_required
|
||||
from controllers.inner_api import api
|
||||
from controllers.inner_api import inner_api_ns
|
||||
from controllers.inner_api.wraps import billing_inner_api_only, enterprise_inner_api_only
|
||||
from tasks.mail_inner_task import send_inner_email_task
|
||||
|
||||
|
|
@ -26,13 +26,45 @@ class BaseMail(Resource):
|
|||
return {"message": "success"}, 200
|
||||
|
||||
|
||||
@inner_api_ns.route("/enterprise/mail")
|
||||
class EnterpriseMail(BaseMail):
|
||||
method_decorators = [setup_required, enterprise_inner_api_only]
|
||||
|
||||
@inner_api_ns.doc("send_enterprise_mail")
|
||||
@inner_api_ns.doc(description="Send internal email for enterprise features")
|
||||
@inner_api_ns.expect(_mail_parser)
|
||||
@inner_api_ns.doc(
|
||||
responses={200: "Email sent successfully", 401: "Unauthorized - invalid API key", 404: "Service not available"}
|
||||
)
|
||||
def post(self):
|
||||
"""Send internal email for enterprise features.
|
||||
|
||||
This endpoint allows sending internal emails for enterprise-specific
|
||||
notifications and communications.
|
||||
|
||||
Returns:
|
||||
dict: Success message with status code 200
|
||||
"""
|
||||
return super().post()
|
||||
|
||||
|
||||
@inner_api_ns.route("/billing/mail")
|
||||
class BillingMail(BaseMail):
|
||||
method_decorators = [setup_required, billing_inner_api_only]
|
||||
|
||||
@inner_api_ns.doc("send_billing_mail")
|
||||
@inner_api_ns.doc(description="Send internal email for billing notifications")
|
||||
@inner_api_ns.expect(_mail_parser)
|
||||
@inner_api_ns.doc(
|
||||
responses={200: "Email sent successfully", 401: "Unauthorized - invalid API key", 404: "Service not available"}
|
||||
)
|
||||
def post(self):
|
||||
"""Send internal email for billing notifications.
|
||||
|
||||
api.add_resource(EnterpriseMail, "/enterprise/mail")
|
||||
api.add_resource(BillingMail, "/billing/mail")
|
||||
This endpoint allows sending internal emails for billing-related
|
||||
notifications and alerts.
|
||||
|
||||
Returns:
|
||||
dict: Success message with status code 200
|
||||
"""
|
||||
return super().post()
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from flask_restx import Resource
|
||||
|
||||
from controllers.console.wraps import setup_required
|
||||
from controllers.inner_api import api
|
||||
from controllers.inner_api import inner_api_ns
|
||||
from controllers.inner_api.plugin.wraps import get_user_tenant, plugin_data
|
||||
from controllers.inner_api.wraps import plugin_inner_api_only
|
||||
from core.file.helpers import get_signed_file_url_for_plugin
|
||||
|
|
@ -35,11 +35,21 @@ from models.account import Account, Tenant
|
|||
from models.model import EndUser
|
||||
|
||||
|
||||
@inner_api_ns.route("/invoke/llm")
|
||||
class PluginInvokeLLMApi(Resource):
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
@get_user_tenant
|
||||
@plugin_data(payload_type=RequestInvokeLLM)
|
||||
@inner_api_ns.doc("plugin_invoke_llm")
|
||||
@inner_api_ns.doc(description="Invoke LLM models through plugin interface")
|
||||
@inner_api_ns.doc(
|
||||
responses={
|
||||
200: "LLM invocation successful (streaming response)",
|
||||
401: "Unauthorized - invalid API key",
|
||||
404: "Service not available",
|
||||
}
|
||||
)
|
||||
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeLLM):
|
||||
def generator():
|
||||
response = PluginModelBackwardsInvocation.invoke_llm(user_model.id, tenant_model, payload)
|
||||
|
|
@ -48,11 +58,21 @@ class PluginInvokeLLMApi(Resource):
|
|||
return length_prefixed_response(0xF, generator())
|
||||
|
||||
|
||||
@inner_api_ns.route("/invoke/llm/structured-output")
|
||||
class PluginInvokeLLMWithStructuredOutputApi(Resource):
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
@get_user_tenant
|
||||
@plugin_data(payload_type=RequestInvokeLLMWithStructuredOutput)
|
||||
@inner_api_ns.doc("plugin_invoke_llm_structured")
|
||||
@inner_api_ns.doc(description="Invoke LLM models with structured output through plugin interface")
|
||||
@inner_api_ns.doc(
|
||||
responses={
|
||||
200: "LLM structured output invocation successful (streaming response)",
|
||||
401: "Unauthorized - invalid API key",
|
||||
404: "Service not available",
|
||||
}
|
||||
)
|
||||
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeLLMWithStructuredOutput):
|
||||
def generator():
|
||||
response = PluginModelBackwardsInvocation.invoke_llm_with_structured_output(
|
||||
|
|
@ -63,11 +83,21 @@ class PluginInvokeLLMWithStructuredOutputApi(Resource):
|
|||
return length_prefixed_response(0xF, generator())
|
||||
|
||||
|
||||
@inner_api_ns.route("/invoke/text-embedding")
|
||||
class PluginInvokeTextEmbeddingApi(Resource):
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
@get_user_tenant
|
||||
@plugin_data(payload_type=RequestInvokeTextEmbedding)
|
||||
@inner_api_ns.doc("plugin_invoke_text_embedding")
|
||||
@inner_api_ns.doc(description="Invoke text embedding models through plugin interface")
|
||||
@inner_api_ns.doc(
|
||||
responses={
|
||||
200: "Text embedding successful",
|
||||
401: "Unauthorized - invalid API key",
|
||||
404: "Service not available",
|
||||
}
|
||||
)
|
||||
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeTextEmbedding):
|
||||
try:
|
||||
return jsonable_encoder(
|
||||
|
|
@ -83,11 +113,17 @@ class PluginInvokeTextEmbeddingApi(Resource):
|
|||
return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e)))
|
||||
|
||||
|
||||
@inner_api_ns.route("/invoke/rerank")
|
||||
class PluginInvokeRerankApi(Resource):
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
@get_user_tenant
|
||||
@plugin_data(payload_type=RequestInvokeRerank)
|
||||
@inner_api_ns.doc("plugin_invoke_rerank")
|
||||
@inner_api_ns.doc(description="Invoke rerank models through plugin interface")
|
||||
@inner_api_ns.doc(
|
||||
responses={200: "Rerank successful", 401: "Unauthorized - invalid API key", 404: "Service not available"}
|
||||
)
|
||||
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeRerank):
|
||||
try:
|
||||
return jsonable_encoder(
|
||||
|
|
@ -103,11 +139,21 @@ class PluginInvokeRerankApi(Resource):
|
|||
return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e)))
|
||||
|
||||
|
||||
@inner_api_ns.route("/invoke/tts")
|
||||
class PluginInvokeTTSApi(Resource):
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
@get_user_tenant
|
||||
@plugin_data(payload_type=RequestInvokeTTS)
|
||||
@inner_api_ns.doc("plugin_invoke_tts")
|
||||
@inner_api_ns.doc(description="Invoke text-to-speech models through plugin interface")
|
||||
@inner_api_ns.doc(
|
||||
responses={
|
||||
200: "TTS invocation successful (streaming response)",
|
||||
401: "Unauthorized - invalid API key",
|
||||
404: "Service not available",
|
||||
}
|
||||
)
|
||||
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeTTS):
|
||||
def generator():
|
||||
response = PluginModelBackwardsInvocation.invoke_tts(
|
||||
|
|
@ -120,11 +166,17 @@ class PluginInvokeTTSApi(Resource):
|
|||
return length_prefixed_response(0xF, generator())
|
||||
|
||||
|
||||
@inner_api_ns.route("/invoke/speech2text")
|
||||
class PluginInvokeSpeech2TextApi(Resource):
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
@get_user_tenant
|
||||
@plugin_data(payload_type=RequestInvokeSpeech2Text)
|
||||
@inner_api_ns.doc("plugin_invoke_speech2text")
|
||||
@inner_api_ns.doc(description="Invoke speech-to-text models through plugin interface")
|
||||
@inner_api_ns.doc(
|
||||
responses={200: "Speech2Text successful", 401: "Unauthorized - invalid API key", 404: "Service not available"}
|
||||
)
|
||||
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeSpeech2Text):
|
||||
try:
|
||||
return jsonable_encoder(
|
||||
|
|
@ -140,11 +192,17 @@ class PluginInvokeSpeech2TextApi(Resource):
|
|||
return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e)))
|
||||
|
||||
|
||||
@inner_api_ns.route("/invoke/moderation")
|
||||
class PluginInvokeModerationApi(Resource):
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
@get_user_tenant
|
||||
@plugin_data(payload_type=RequestInvokeModeration)
|
||||
@inner_api_ns.doc("plugin_invoke_moderation")
|
||||
@inner_api_ns.doc(description="Invoke moderation models through plugin interface")
|
||||
@inner_api_ns.doc(
|
||||
responses={200: "Moderation successful", 401: "Unauthorized - invalid API key", 404: "Service not available"}
|
||||
)
|
||||
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeModeration):
|
||||
try:
|
||||
return jsonable_encoder(
|
||||
|
|
@ -160,11 +218,21 @@ class PluginInvokeModerationApi(Resource):
|
|||
return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e)))
|
||||
|
||||
|
||||
@inner_api_ns.route("/invoke/tool")
|
||||
class PluginInvokeToolApi(Resource):
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
@get_user_tenant
|
||||
@plugin_data(payload_type=RequestInvokeTool)
|
||||
@inner_api_ns.doc("plugin_invoke_tool")
|
||||
@inner_api_ns.doc(description="Invoke tools through plugin interface")
|
||||
@inner_api_ns.doc(
|
||||
responses={
|
||||
200: "Tool invocation successful (streaming response)",
|
||||
401: "Unauthorized - invalid API key",
|
||||
404: "Service not available",
|
||||
}
|
||||
)
|
||||
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeTool):
|
||||
def generator():
|
||||
return PluginToolBackwardsInvocation.convert_to_event_stream(
|
||||
|
|
@ -182,11 +250,21 @@ class PluginInvokeToolApi(Resource):
|
|||
return length_prefixed_response(0xF, generator())
|
||||
|
||||
|
||||
@inner_api_ns.route("/invoke/parameter-extractor")
|
||||
class PluginInvokeParameterExtractorNodeApi(Resource):
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
@get_user_tenant
|
||||
@plugin_data(payload_type=RequestInvokeParameterExtractorNode)
|
||||
@inner_api_ns.doc("plugin_invoke_parameter_extractor")
|
||||
@inner_api_ns.doc(description="Invoke parameter extractor node through plugin interface")
|
||||
@inner_api_ns.doc(
|
||||
responses={
|
||||
200: "Parameter extraction successful",
|
||||
401: "Unauthorized - invalid API key",
|
||||
404: "Service not available",
|
||||
}
|
||||
)
|
||||
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeParameterExtractorNode):
|
||||
try:
|
||||
return jsonable_encoder(
|
||||
|
|
@ -205,11 +283,21 @@ class PluginInvokeParameterExtractorNodeApi(Resource):
|
|||
return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e)))
|
||||
|
||||
|
||||
@inner_api_ns.route("/invoke/question-classifier")
|
||||
class PluginInvokeQuestionClassifierNodeApi(Resource):
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
@get_user_tenant
|
||||
@plugin_data(payload_type=RequestInvokeQuestionClassifierNode)
|
||||
@inner_api_ns.doc("plugin_invoke_question_classifier")
|
||||
@inner_api_ns.doc(description="Invoke question classifier node through plugin interface")
|
||||
@inner_api_ns.doc(
|
||||
responses={
|
||||
200: "Question classification successful",
|
||||
401: "Unauthorized - invalid API key",
|
||||
404: "Service not available",
|
||||
}
|
||||
)
|
||||
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeQuestionClassifierNode):
|
||||
try:
|
||||
return jsonable_encoder(
|
||||
|
|
@ -228,11 +316,21 @@ class PluginInvokeQuestionClassifierNodeApi(Resource):
|
|||
return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e)))
|
||||
|
||||
|
||||
@inner_api_ns.route("/invoke/app")
|
||||
class PluginInvokeAppApi(Resource):
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
@get_user_tenant
|
||||
@plugin_data(payload_type=RequestInvokeApp)
|
||||
@inner_api_ns.doc("plugin_invoke_app")
|
||||
@inner_api_ns.doc(description="Invoke application through plugin interface")
|
||||
@inner_api_ns.doc(
|
||||
responses={
|
||||
200: "App invocation successful (streaming response)",
|
||||
401: "Unauthorized - invalid API key",
|
||||
404: "Service not available",
|
||||
}
|
||||
)
|
||||
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeApp):
|
||||
response = PluginAppBackwardsInvocation.invoke_app(
|
||||
app_id=payload.app_id,
|
||||
|
|
@ -248,11 +346,21 @@ class PluginInvokeAppApi(Resource):
|
|||
return length_prefixed_response(0xF, PluginAppBackwardsInvocation.convert_to_event_stream(response))
|
||||
|
||||
|
||||
@inner_api_ns.route("/invoke/encrypt")
|
||||
class PluginInvokeEncryptApi(Resource):
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
@get_user_tenant
|
||||
@plugin_data(payload_type=RequestInvokeEncrypt)
|
||||
@inner_api_ns.doc("plugin_invoke_encrypt")
|
||||
@inner_api_ns.doc(description="Encrypt or decrypt data through plugin interface")
|
||||
@inner_api_ns.doc(
|
||||
responses={
|
||||
200: "Encryption/decryption successful",
|
||||
401: "Unauthorized - invalid API key",
|
||||
404: "Service not available",
|
||||
}
|
||||
)
|
||||
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeEncrypt):
|
||||
"""
|
||||
encrypt or decrypt data
|
||||
|
|
@ -265,11 +373,21 @@ class PluginInvokeEncryptApi(Resource):
|
|||
return BaseBackwardsInvocationResponse(error=str(e)).model_dump()
|
||||
|
||||
|
||||
@inner_api_ns.route("/invoke/summary")
|
||||
class PluginInvokeSummaryApi(Resource):
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
@get_user_tenant
|
||||
@plugin_data(payload_type=RequestInvokeSummary)
|
||||
@inner_api_ns.doc("plugin_invoke_summary")
|
||||
@inner_api_ns.doc(description="Invoke summary functionality through plugin interface")
|
||||
@inner_api_ns.doc(
|
||||
responses={
|
||||
200: "Summary generation successful",
|
||||
401: "Unauthorized - invalid API key",
|
||||
404: "Service not available",
|
||||
}
|
||||
)
|
||||
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeSummary):
|
||||
try:
|
||||
return BaseBackwardsInvocationResponse(
|
||||
|
|
@ -285,40 +403,43 @@ class PluginInvokeSummaryApi(Resource):
|
|||
return BaseBackwardsInvocationResponse(error=str(e)).model_dump()
|
||||
|
||||
|
||||
@inner_api_ns.route("/upload/file/request")
|
||||
class PluginUploadFileRequestApi(Resource):
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
@get_user_tenant
|
||||
@plugin_data(payload_type=RequestRequestUploadFile)
|
||||
@inner_api_ns.doc("plugin_upload_file_request")
|
||||
@inner_api_ns.doc(description="Request signed URL for file upload through plugin interface")
|
||||
@inner_api_ns.doc(
|
||||
responses={
|
||||
200: "Signed URL generated successfully",
|
||||
401: "Unauthorized - invalid API key",
|
||||
404: "Service not available",
|
||||
}
|
||||
)
|
||||
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestRequestUploadFile):
|
||||
# generate signed url
|
||||
url = get_signed_file_url_for_plugin(payload.filename, payload.mimetype, tenant_model.id, user_model.id)
|
||||
return BaseBackwardsInvocationResponse(data={"url": url}).model_dump()
|
||||
|
||||
|
||||
@inner_api_ns.route("/fetch/app/info")
|
||||
class PluginFetchAppInfoApi(Resource):
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
@get_user_tenant
|
||||
@plugin_data(payload_type=RequestFetchAppInfo)
|
||||
@inner_api_ns.doc("plugin_fetch_app_info")
|
||||
@inner_api_ns.doc(description="Fetch application information through plugin interface")
|
||||
@inner_api_ns.doc(
|
||||
responses={
|
||||
200: "App information retrieved successfully",
|
||||
401: "Unauthorized - invalid API key",
|
||||
404: "Service not available",
|
||||
}
|
||||
)
|
||||
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestFetchAppInfo):
|
||||
return BaseBackwardsInvocationResponse(
|
||||
data=PluginAppBackwardsInvocation.fetch_app_info(payload.app_id, tenant_model.id)
|
||||
).model_dump()
|
||||
|
||||
|
||||
api.add_resource(PluginInvokeLLMApi, "/invoke/llm")
|
||||
api.add_resource(PluginInvokeLLMWithStructuredOutputApi, "/invoke/llm/structured-output")
|
||||
api.add_resource(PluginInvokeTextEmbeddingApi, "/invoke/text-embedding")
|
||||
api.add_resource(PluginInvokeRerankApi, "/invoke/rerank")
|
||||
api.add_resource(PluginInvokeTTSApi, "/invoke/tts")
|
||||
api.add_resource(PluginInvokeSpeech2TextApi, "/invoke/speech2text")
|
||||
api.add_resource(PluginInvokeModerationApi, "/invoke/moderation")
|
||||
api.add_resource(PluginInvokeToolApi, "/invoke/tool")
|
||||
api.add_resource(PluginInvokeParameterExtractorNodeApi, "/invoke/parameter-extractor")
|
||||
api.add_resource(PluginInvokeQuestionClassifierNodeApi, "/invoke/question-classifier")
|
||||
api.add_resource(PluginInvokeAppApi, "/invoke/app")
|
||||
api.add_resource(PluginInvokeEncryptApi, "/invoke/encrypt")
|
||||
api.add_resource(PluginInvokeSummaryApi, "/invoke/summary")
|
||||
api.add_resource(PluginUploadFileRequestApi, "/upload/file/request")
|
||||
api.add_resource(PluginFetchAppInfoApi, "/fetch/app/info")
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import json
|
|||
from flask_restx import Resource, reqparse
|
||||
|
||||
from controllers.console.wraps import setup_required
|
||||
from controllers.inner_api import api
|
||||
from controllers.inner_api import inner_api_ns
|
||||
from controllers.inner_api.wraps import enterprise_inner_api_only
|
||||
from events.tenant_event import tenant_was_created
|
||||
from extensions.ext_database import db
|
||||
|
|
@ -11,9 +11,19 @@ from models.account import Account
|
|||
from services.account_service import TenantService
|
||||
|
||||
|
||||
@inner_api_ns.route("/enterprise/workspace")
|
||||
class EnterpriseWorkspace(Resource):
|
||||
@setup_required
|
||||
@enterprise_inner_api_only
|
||||
@inner_api_ns.doc("create_enterprise_workspace")
|
||||
@inner_api_ns.doc(description="Create a new enterprise workspace with owner assignment")
|
||||
@inner_api_ns.doc(
|
||||
responses={
|
||||
200: "Workspace created successfully",
|
||||
401: "Unauthorized - invalid API key",
|
||||
404: "Owner account not found or service not available",
|
||||
}
|
||||
)
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=str, required=True, location="json")
|
||||
|
|
@ -44,9 +54,19 @@ class EnterpriseWorkspace(Resource):
|
|||
}
|
||||
|
||||
|
||||
@inner_api_ns.route("/enterprise/workspace/ownerless")
|
||||
class EnterpriseWorkspaceNoOwnerEmail(Resource):
|
||||
@setup_required
|
||||
@enterprise_inner_api_only
|
||||
@inner_api_ns.doc("create_enterprise_workspace_ownerless")
|
||||
@inner_api_ns.doc(description="Create a new enterprise workspace without initial owner assignment")
|
||||
@inner_api_ns.doc(
|
||||
responses={
|
||||
200: "Workspace created successfully",
|
||||
401: "Unauthorized - invalid API key",
|
||||
404: "Service not available",
|
||||
}
|
||||
)
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=str, required=True, location="json")
|
||||
|
|
@ -71,7 +91,3 @@ class EnterpriseWorkspaceNoOwnerEmail(Resource):
|
|||
"message": "enterprise workspace created.",
|
||||
"tenant": resp,
|
||||
}
|
||||
|
||||
|
||||
api.add_resource(EnterpriseWorkspace, "/enterprise/workspace")
|
||||
api.add_resource(EnterpriseWorkspaceNoOwnerEmail, "/enterprise/workspace/ownerless")
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ api = ExternalApi(
|
|||
doc="/docs", # Enable Swagger UI at /mcp/docs
|
||||
)
|
||||
|
||||
mcp_ns = Namespace("mcp", description="MCP operations")
|
||||
mcp_ns = Namespace("mcp", description="MCP operations", path="/")
|
||||
|
||||
from . import mcp
|
||||
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ api = ExternalApi(
|
|||
doc="/docs", # Enable Swagger UI at /v1/docs
|
||||
)
|
||||
|
||||
service_api_ns = Namespace("service_api", description="Service operations")
|
||||
service_api_ns = Namespace("service_api", description="Service operations", path="/")
|
||||
|
||||
from . import index
|
||||
from .app import annotation, app, audio, completion, conversation, file, file_preview, message, site, workflow
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from controllers.service_api.wraps import validate_app_token
|
|||
from extensions.ext_redis import redis_client
|
||||
from fields.annotation_fields import annotation_fields, build_annotation_model
|
||||
from libs.login import current_user
|
||||
from models.account import Account
|
||||
from models.model import App
|
||||
from services.annotation_service import AppAnnotationService
|
||||
|
||||
|
|
@ -163,6 +164,7 @@ class AnnotationUpdateDeleteApi(Resource):
|
|||
@service_api_ns.marshal_with(build_annotation_model(service_api_ns))
|
||||
def put(self, app_model: App, annotation_id):
|
||||
"""Update an existing annotation."""
|
||||
assert isinstance(current_user, Account)
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
|
|
@ -185,6 +187,8 @@ class AnnotationUpdateDeleteApi(Resource):
|
|||
@validate_app_token
|
||||
def delete(self, app_model: App, annotation_id):
|
||||
"""Delete an annotation."""
|
||||
assert isinstance(current_user, Account)
|
||||
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ from core.provider_manager import ProviderManager
|
|||
from fields.dataset_fields import dataset_detail_fields
|
||||
from fields.tag_fields import build_dataset_tag_fields
|
||||
from libs.login import current_user
|
||||
from models.account import Account
|
||||
from models.dataset import Dataset, DatasetPermissionEnum
|
||||
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
|
||||
from services.entities.knowledge_entities.knowledge_entities import RetrievalModel
|
||||
|
|
@ -213,7 +214,10 @@ class DatasetListApi(DatasetApiResource):
|
|||
)
|
||||
# check embedding setting
|
||||
provider_manager = ProviderManager()
|
||||
configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id)
|
||||
assert isinstance(current_user, Account)
|
||||
cid = current_user.current_tenant_id
|
||||
assert cid is not None
|
||||
configurations = provider_manager.get_configurations(tenant_id=cid)
|
||||
|
||||
embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
|
||||
|
||||
|
|
@ -266,6 +270,7 @@ class DatasetListApi(DatasetApiResource):
|
|||
)
|
||||
|
||||
try:
|
||||
assert isinstance(current_user, Account)
|
||||
dataset = DatasetService.create_empty_dataset(
|
||||
tenant_id=tenant_id,
|
||||
name=args["name"],
|
||||
|
|
@ -319,7 +324,10 @@ class DatasetApi(DatasetApiResource):
|
|||
|
||||
# check embedding setting
|
||||
provider_manager = ProviderManager()
|
||||
configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id)
|
||||
assert isinstance(current_user, Account)
|
||||
cid = current_user.current_tenant_id
|
||||
assert cid is not None
|
||||
configurations = provider_manager.get_configurations(tenant_id=cid)
|
||||
|
||||
embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
|
||||
|
||||
|
|
@ -391,6 +399,7 @@ class DatasetApi(DatasetApiResource):
|
|||
raise NotFound("Dataset not found.")
|
||||
|
||||
result_data = marshal(dataset, dataset_detail_fields)
|
||||
assert isinstance(current_user, Account)
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
if data.get("partial_member_list") and data.get("permission") == "partial_members":
|
||||
|
|
@ -532,7 +541,10 @@ class DatasetTagsApi(DatasetApiResource):
|
|||
@service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns))
|
||||
def get(self, _, dataset_id):
|
||||
"""Get all knowledge type tags."""
|
||||
tags = TagService.get_tags("knowledge", current_user.current_tenant_id)
|
||||
assert isinstance(current_user, Account)
|
||||
cid = current_user.current_tenant_id
|
||||
assert cid is not None
|
||||
tags = TagService.get_tags("knowledge", cid)
|
||||
|
||||
return tags, 200
|
||||
|
||||
|
|
@ -550,6 +562,7 @@ class DatasetTagsApi(DatasetApiResource):
|
|||
@validate_dataset_token
|
||||
def post(self, _, dataset_id):
|
||||
"""Add a knowledge type tag."""
|
||||
assert isinstance(current_user, Account)
|
||||
if not (current_user.is_editor or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
|
|
@ -573,6 +586,7 @@ class DatasetTagsApi(DatasetApiResource):
|
|||
@service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns))
|
||||
@validate_dataset_token
|
||||
def patch(self, _, dataset_id):
|
||||
assert isinstance(current_user, Account)
|
||||
if not (current_user.is_editor or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
|
|
@ -599,6 +613,7 @@ class DatasetTagsApi(DatasetApiResource):
|
|||
@validate_dataset_token
|
||||
def delete(self, _, dataset_id):
|
||||
"""Delete a knowledge type tag."""
|
||||
assert isinstance(current_user, Account)
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
args = tag_delete_parser.parse_args()
|
||||
|
|
@ -622,6 +637,7 @@ class DatasetTagBindingApi(DatasetApiResource):
|
|||
@validate_dataset_token
|
||||
def post(self, _, dataset_id):
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
assert isinstance(current_user, Account)
|
||||
if not (current_user.is_editor or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
|
|
@ -647,6 +663,7 @@ class DatasetTagUnbindingApi(DatasetApiResource):
|
|||
@validate_dataset_token
|
||||
def post(self, _, dataset_id):
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
assert isinstance(current_user, Account)
|
||||
if not (current_user.is_editor or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
|
|
@ -672,6 +689,8 @@ class DatasetTagsBindingStatusApi(DatasetApiResource):
|
|||
def get(self, _, *args, **kwargs):
|
||||
"""Get all knowledge type tags."""
|
||||
dataset_id = kwargs.get("dataset_id")
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
tags = TagService.get_tags_by_target_id("knowledge", current_user.current_tenant_id, str(dataset_id))
|
||||
tags_list = [{"id": tag.id, "name": tag.name} for tag in tags]
|
||||
response = {"data": tags_list, "total": len(tags)}
|
||||
|
|
|
|||
|
|
@ -7,13 +7,14 @@ from sqlalchemy import select
|
|||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.console.auth.error import (
|
||||
AuthenticationFailedError,
|
||||
EmailCodeError,
|
||||
EmailPasswordResetLimitError,
|
||||
InvalidEmailError,
|
||||
InvalidTokenError,
|
||||
PasswordMismatchError,
|
||||
)
|
||||
from controllers.console.error import AccountNotFound, EmailSendIpLimitError
|
||||
from controllers.console.error import EmailSendIpLimitError
|
||||
from controllers.console.wraps import email_password_login_enabled, only_edition_enterprise, setup_required
|
||||
from controllers.web import api
|
||||
from extensions.ext_database import db
|
||||
|
|
@ -46,7 +47,7 @@ class ForgotPasswordSendEmailApi(Resource):
|
|||
account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none()
|
||||
token = None
|
||||
if account is None:
|
||||
raise AccountNotFound()
|
||||
raise AuthenticationFailedError()
|
||||
else:
|
||||
token = AccountService.send_reset_password_email(account=account, email=args["email"], language=language)
|
||||
|
||||
|
|
@ -131,7 +132,7 @@ class ForgotPasswordResetApi(Resource):
|
|||
if account:
|
||||
self._update_existing_account(account, password_hashed, salt, session)
|
||||
else:
|
||||
raise AccountNotFound()
|
||||
raise AuthenticationFailedError()
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
|
|
|||
|
|
@ -2,8 +2,12 @@ from flask_restx import Resource, reqparse
|
|||
from jwt import InvalidTokenError # type: ignore
|
||||
|
||||
import services
|
||||
from controllers.console.auth.error import EmailCodeError, EmailOrPasswordMismatchError, InvalidEmailError
|
||||
from controllers.console.error import AccountBannedError, AccountNotFound
|
||||
from controllers.console.auth.error import (
|
||||
AuthenticationFailedError,
|
||||
EmailCodeError,
|
||||
InvalidEmailError,
|
||||
)
|
||||
from controllers.console.error import AccountBannedError
|
||||
from controllers.console.wraps import only_edition_enterprise, setup_required
|
||||
from controllers.web import api
|
||||
from libs.helper import email
|
||||
|
|
@ -29,9 +33,9 @@ class LoginApi(Resource):
|
|||
except services.errors.account.AccountLoginError:
|
||||
raise AccountBannedError()
|
||||
except services.errors.account.AccountPasswordError:
|
||||
raise EmailOrPasswordMismatchError()
|
||||
raise AuthenticationFailedError()
|
||||
except services.errors.account.AccountNotFoundError:
|
||||
raise AccountNotFound()
|
||||
raise AuthenticationFailedError()
|
||||
|
||||
token = WebAppAuthService.login(account=account)
|
||||
return {"result": "success", "data": {"access_token": token}}
|
||||
|
|
@ -63,7 +67,7 @@ class EmailCodeLoginSendEmailApi(Resource):
|
|||
|
||||
account = WebAppAuthService.get_user_through_email(args["email"])
|
||||
if account is None:
|
||||
raise AccountNotFound()
|
||||
raise AuthenticationFailedError()
|
||||
else:
|
||||
token = WebAppAuthService.send_email_code_login_email(account=account, language=language)
|
||||
|
||||
|
|
@ -95,7 +99,7 @@ class EmailCodeLoginApi(Resource):
|
|||
WebAppAuthService.revoke_email_code_login_token(args["token"])
|
||||
account = WebAppAuthService.get_user_through_email(user_email)
|
||||
if not account:
|
||||
raise AccountNotFound()
|
||||
raise AuthenticationFailedError()
|
||||
|
||||
token = WebAppAuthService.login(account=account)
|
||||
AccountService.reset_login_error_rate_limit(args["email"])
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ class ModelStatus(Enum):
|
|||
QUOTA_EXCEEDED = "quota-exceeded"
|
||||
NO_PERMISSION = "no-permission"
|
||||
DISABLED = "disabled"
|
||||
CREDENTIAL_REMOVED = "credential-removed"
|
||||
|
||||
|
||||
class SimpleModelProviderEntity(BaseModel):
|
||||
|
|
@ -54,6 +55,7 @@ class ProviderModelWithStatusEntity(ProviderModel):
|
|||
|
||||
status: ModelStatus
|
||||
load_balancing_enabled: bool = False
|
||||
has_invalid_load_balancing_configs: bool = False
|
||||
|
||||
def raise_for_status(self) -> None:
|
||||
"""
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
|
|
@ -69,6 +69,15 @@ class QuotaConfiguration(BaseModel):
|
|||
restrict_models: list[RestrictModel] = []
|
||||
|
||||
|
||||
class CredentialConfiguration(BaseModel):
|
||||
"""
|
||||
Model class for credential configuration.
|
||||
"""
|
||||
|
||||
credential_id: str
|
||||
credential_name: str
|
||||
|
||||
|
||||
class SystemConfiguration(BaseModel):
|
||||
"""
|
||||
Model class for provider system configuration.
|
||||
|
|
@ -86,6 +95,9 @@ class CustomProviderConfiguration(BaseModel):
|
|||
"""
|
||||
|
||||
credentials: dict
|
||||
current_credential_id: Optional[str] = None
|
||||
current_credential_name: Optional[str] = None
|
||||
available_credentials: list[CredentialConfiguration] = []
|
||||
|
||||
|
||||
class CustomModelConfiguration(BaseModel):
|
||||
|
|
@ -95,7 +107,10 @@ class CustomModelConfiguration(BaseModel):
|
|||
|
||||
model: str
|
||||
model_type: ModelType
|
||||
credentials: dict
|
||||
credentials: dict | None
|
||||
current_credential_id: Optional[str] = None
|
||||
current_credential_name: Optional[str] = None
|
||||
available_model_credentials: list[CredentialConfiguration] = []
|
||||
|
||||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
|
@ -118,6 +133,7 @@ class ModelLoadBalancingConfiguration(BaseModel):
|
|||
id: str
|
||||
name: str
|
||||
credentials: dict
|
||||
credential_source_type: str | None = None
|
||||
|
||||
|
||||
class ModelSettings(BaseModel):
|
||||
|
|
|
|||
|
|
@ -570,5 +570,5 @@ class LLMGenerator:
|
|||
error = str(e)
|
||||
return {"error": f"Failed to generate code. Error: {error}"}
|
||||
except Exception as e:
|
||||
logging.exception("Failed to invoke LLM model, model: " + json.dumps(model_config.get("name")), exc_info=e)
|
||||
logging.exception("Failed to invoke LLM model, model: %s", model_config.get("name"), exc_info=e)
|
||||
return {"error": f"An unexpected error occurred: {str(e)}"}
|
||||
|
|
|
|||
|
|
@ -201,7 +201,7 @@ class ModelProviderFactory:
|
|||
return filtered_credentials
|
||||
|
||||
def get_model_schema(
|
||||
self, *, provider: str, model_type: ModelType, model: str, credentials: dict
|
||||
self, *, provider: str, model_type: ModelType, model: str, credentials: dict | None
|
||||
) -> AIModelEntity | None:
|
||||
"""
|
||||
Get model schema
|
||||
|
|
|
|||
|
|
@ -100,14 +100,14 @@ class Moderation(Extensible, ABC):
|
|||
if not inputs_config.get("preset_response"):
|
||||
raise ValueError("inputs_config.preset_response is required")
|
||||
|
||||
if len(inputs_config.get("preset_response", 0)) > 100:
|
||||
if len(inputs_config.get("preset_response", "0")) > 100:
|
||||
raise ValueError("inputs_config.preset_response must be less than 100 characters")
|
||||
|
||||
if outputs_config_enabled:
|
||||
if not outputs_config.get("preset_response"):
|
||||
raise ValueError("outputs_config.preset_response is required")
|
||||
|
||||
if len(outputs_config.get("preset_response", 0)) > 100:
|
||||
if len(outputs_config.get("preset_response", "0")) > 100:
|
||||
raise ValueError("outputs_config.preset_response must be less than 100 characters")
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -154,7 +154,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
|||
"""
|
||||
workflow = app.workflow
|
||||
if not workflow:
|
||||
raise ValueError("")
|
||||
raise ValueError("unexpected app type")
|
||||
|
||||
return WorkflowAppGenerator().generate(
|
||||
app_model=app,
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from configs import dify_config
|
|||
from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity
|
||||
from core.entities.provider_configuration import ProviderConfiguration, ProviderConfigurations, ProviderModelBundle
|
||||
from core.entities.provider_entities import (
|
||||
CredentialConfiguration,
|
||||
CustomConfiguration,
|
||||
CustomModelConfiguration,
|
||||
CustomProviderConfiguration,
|
||||
|
|
@ -40,7 +41,9 @@ from extensions.ext_redis import redis_client
|
|||
from models.provider import (
|
||||
LoadBalancingModelConfig,
|
||||
Provider,
|
||||
ProviderCredential,
|
||||
ProviderModel,
|
||||
ProviderModelCredential,
|
||||
ProviderModelSetting,
|
||||
ProviderType,
|
||||
TenantDefaultModel,
|
||||
|
|
@ -488,6 +491,61 @@ class ProviderManager:
|
|||
|
||||
return provider_name_to_provider_load_balancing_model_configs_dict
|
||||
|
||||
@staticmethod
|
||||
def get_provider_available_credentials(tenant_id: str, provider_name: str) -> list[CredentialConfiguration]:
|
||||
"""
|
||||
Get provider all credentials.
|
||||
|
||||
:param tenant_id: workspace id
|
||||
:param provider_name: provider name
|
||||
:return:
|
||||
"""
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = (
|
||||
select(ProviderCredential)
|
||||
.where(ProviderCredential.tenant_id == tenant_id, ProviderCredential.provider_name == provider_name)
|
||||
.order_by(ProviderCredential.created_at.desc())
|
||||
)
|
||||
|
||||
available_credentials = session.scalars(stmt).all()
|
||||
|
||||
return [
|
||||
CredentialConfiguration(credential_id=credential.id, credential_name=credential.credential_name)
|
||||
for credential in available_credentials
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def get_provider_model_available_credentials(
|
||||
tenant_id: str, provider_name: str, model_name: str, model_type: str
|
||||
) -> list[CredentialConfiguration]:
|
||||
"""
|
||||
Get provider custom model all credentials.
|
||||
|
||||
:param tenant_id: workspace id
|
||||
:param provider_name: provider name
|
||||
:param model_name: model name
|
||||
:param model_type: model type
|
||||
:return:
|
||||
"""
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = (
|
||||
select(ProviderModelCredential)
|
||||
.where(
|
||||
ProviderModelCredential.tenant_id == tenant_id,
|
||||
ProviderModelCredential.provider_name == provider_name,
|
||||
ProviderModelCredential.model_name == model_name,
|
||||
ProviderModelCredential.model_type == model_type,
|
||||
)
|
||||
.order_by(ProviderModelCredential.created_at.desc())
|
||||
)
|
||||
|
||||
available_credentials = session.scalars(stmt).all()
|
||||
|
||||
return [
|
||||
CredentialConfiguration(credential_id=credential.id, credential_name=credential.credential_name)
|
||||
for credential in available_credentials
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _init_trial_provider_records(
|
||||
tenant_id: str, provider_name_to_provider_records_dict: dict[str, list[Provider]]
|
||||
|
|
@ -590,9 +648,6 @@ class ProviderManager:
|
|||
if provider_record.provider_type == ProviderType.SYSTEM.value:
|
||||
continue
|
||||
|
||||
if not provider_record.encrypted_config:
|
||||
continue
|
||||
|
||||
custom_provider_record = provider_record
|
||||
|
||||
# Get custom provider credentials
|
||||
|
|
@ -611,8 +666,8 @@ class ProviderManager:
|
|||
try:
|
||||
# fix origin data
|
||||
if custom_provider_record.encrypted_config is None:
|
||||
raise ValueError("No credentials found")
|
||||
if not custom_provider_record.encrypted_config.startswith("{"):
|
||||
provider_credentials = {}
|
||||
elif not custom_provider_record.encrypted_config.startswith("{"):
|
||||
provider_credentials = {"openai_api_key": custom_provider_record.encrypted_config}
|
||||
else:
|
||||
provider_credentials = json.loads(custom_provider_record.encrypted_config)
|
||||
|
|
@ -637,7 +692,14 @@ class ProviderManager:
|
|||
else:
|
||||
provider_credentials = cached_provider_credentials
|
||||
|
||||
custom_provider_configuration = CustomProviderConfiguration(credentials=provider_credentials)
|
||||
custom_provider_configuration = CustomProviderConfiguration(
|
||||
credentials=provider_credentials,
|
||||
current_credential_name=custom_provider_record.credential_name,
|
||||
current_credential_id=custom_provider_record.credential_id,
|
||||
available_credentials=self.get_provider_available_credentials(
|
||||
tenant_id, custom_provider_record.provider_name
|
||||
),
|
||||
)
|
||||
|
||||
# Get provider model credential secret variables
|
||||
model_credential_secret_variables = self._extract_secret_variables(
|
||||
|
|
@ -649,8 +711,12 @@ class ProviderManager:
|
|||
# Get custom provider model credentials
|
||||
custom_model_configurations = []
|
||||
for provider_model_record in provider_model_records:
|
||||
if not provider_model_record.encrypted_config:
|
||||
continue
|
||||
available_model_credentials = self.get_provider_model_available_credentials(
|
||||
tenant_id,
|
||||
provider_model_record.provider_name,
|
||||
provider_model_record.model_name,
|
||||
provider_model_record.model_type,
|
||||
)
|
||||
|
||||
provider_model_credentials_cache = ProviderCredentialsCache(
|
||||
tenant_id=tenant_id, identity_id=provider_model_record.id, cache_type=ProviderCredentialsCacheType.MODEL
|
||||
|
|
@ -659,7 +725,7 @@ class ProviderManager:
|
|||
# Get cached provider model credentials
|
||||
cached_provider_model_credentials = provider_model_credentials_cache.get()
|
||||
|
||||
if not cached_provider_model_credentials:
|
||||
if not cached_provider_model_credentials and provider_model_record.encrypted_config:
|
||||
try:
|
||||
provider_model_credentials = json.loads(provider_model_record.encrypted_config)
|
||||
except JSONDecodeError:
|
||||
|
|
@ -688,6 +754,9 @@ class ProviderManager:
|
|||
model=provider_model_record.model_name,
|
||||
model_type=ModelType.value_of(provider_model_record.model_type),
|
||||
credentials=provider_model_credentials,
|
||||
current_credential_id=provider_model_record.credential_id,
|
||||
current_credential_name=provider_model_record.credential_name,
|
||||
available_model_credentials=available_model_credentials,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -899,6 +968,18 @@ class ProviderManager:
|
|||
load_balancing_model_config.model_name == provider_model_setting.model_name
|
||||
and load_balancing_model_config.model_type == provider_model_setting.model_type
|
||||
):
|
||||
if load_balancing_model_config.name == "__delete__":
|
||||
# to calculate current model whether has invalidate lb configs
|
||||
load_balancing_configs.append(
|
||||
ModelLoadBalancingConfiguration(
|
||||
id=load_balancing_model_config.id,
|
||||
name=load_balancing_model_config.name,
|
||||
credentials={},
|
||||
credential_source_type=load_balancing_model_config.credential_source_type,
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
if not load_balancing_model_config.enabled:
|
||||
continue
|
||||
|
||||
|
|
@ -955,6 +1036,7 @@ class ProviderManager:
|
|||
id=load_balancing_model_config.id,
|
||||
name=load_balancing_model_config.name,
|
||||
credentials=provider_model_credentials,
|
||||
credential_source_type=load_balancing_model_config.credential_source_type,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -188,14 +188,17 @@ class OracleVector(BaseVector):
|
|||
def text_exists(self, id: str) -> bool:
|
||||
with self._get_connection() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(f"SELECT id FROM {self.table_name} WHERE id = '%s'" % (id,))
|
||||
cur.execute(f"SELECT id FROM {self.table_name} WHERE id = :1", (id,))
|
||||
return cur.fetchone() is not None
|
||||
conn.close()
|
||||
|
||||
def get_by_ids(self, ids: list[str]) -> list[Document]:
|
||||
if not ids:
|
||||
return []
|
||||
with self._get_connection() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(f"SELECT meta, text FROM {self.table_name} WHERE id IN %s", (tuple(ids),))
|
||||
placeholders = ", ".join(f":{i + 1}" for i in range(len(ids)))
|
||||
cur.execute(f"SELECT meta, text FROM {self.table_name} WHERE id IN ({placeholders})", ids)
|
||||
docs = []
|
||||
for record in cur:
|
||||
docs.append(Document(page_content=record[1], metadata=record[0]))
|
||||
|
|
@ -208,14 +211,15 @@ class OracleVector(BaseVector):
|
|||
return
|
||||
with self._get_connection() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s" % (tuple(ids),))
|
||||
placeholders = ", ".join(f":{i + 1}" for i in range(len(ids)))
|
||||
cur.execute(f"DELETE FROM {self.table_name} WHERE id IN ({placeholders})", ids)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str) -> None:
|
||||
with self._get_connection() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(f"DELETE FROM {self.table_name} WHERE meta->>%s = %s", (key, value))
|
||||
cur.execute(f"DELETE FROM {self.table_name} WHERE JSON_VALUE(meta, '$." + key + "') = :1", (value,))
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
|
|
@ -227,12 +231,20 @@ class OracleVector(BaseVector):
|
|||
:param top_k: The number of nearest neighbors to return, default is 5.
|
||||
:return: List of Documents that are nearest to the query vector.
|
||||
"""
|
||||
# Validate and sanitize top_k to prevent SQL injection
|
||||
top_k = kwargs.get("top_k", 4)
|
||||
if not isinstance(top_k, int) or top_k <= 0 or top_k > 10000:
|
||||
top_k = 4 # Use default if invalid
|
||||
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
where_clause = ""
|
||||
params = [numpy.array(query_vector)]
|
||||
|
||||
if document_ids_filter:
|
||||
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||
where_clause = f"WHERE metadata->>'document_id' in ({document_ids})"
|
||||
placeholders = ", ".join(f":{i + 2}" for i in range(len(document_ids_filter)))
|
||||
where_clause = f"WHERE JSON_VALUE(meta, '$.document_id') IN ({placeholders})"
|
||||
params.extend(document_ids_filter)
|
||||
|
||||
with self._get_connection() as conn:
|
||||
conn.inputtypehandler = self.input_type_handler
|
||||
conn.outputtypehandler = self.output_type_handler
|
||||
|
|
@ -241,7 +253,7 @@ class OracleVector(BaseVector):
|
|||
f"""SELECT meta, text, vector_distance(embedding,(select to_vector(:1) from dual),cosine)
|
||||
AS distance FROM {self.table_name}
|
||||
{where_clause} ORDER BY distance fetch first {top_k} rows only""",
|
||||
[numpy.array(query_vector)],
|
||||
params,
|
||||
)
|
||||
docs = []
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
|
|
@ -259,7 +271,10 @@ class OracleVector(BaseVector):
|
|||
import nltk # type: ignore
|
||||
from nltk.corpus import stopwords # type: ignore
|
||||
|
||||
# Validate and sanitize top_k to prevent SQL injection
|
||||
top_k = kwargs.get("top_k", 5)
|
||||
if not isinstance(top_k, int) or top_k <= 0 or top_k > 10000:
|
||||
top_k = 5 # Use default if invalid
|
||||
# just not implement fetch by score_threshold now, may be later
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
if len(query) > 0:
|
||||
|
|
@ -297,14 +312,21 @@ class OracleVector(BaseVector):
|
|||
with conn.cursor() as cur:
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
where_clause = ""
|
||||
params: dict[str, Any] = {"kk": " ACCUM ".join(entities)}
|
||||
|
||||
if document_ids_filter:
|
||||
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||
where_clause = f" AND metadata->>'document_id' in ({document_ids}) "
|
||||
placeholders = []
|
||||
for i, doc_id in enumerate(document_ids_filter):
|
||||
param_name = f"doc_id_{i}"
|
||||
placeholders.append(f":{param_name}")
|
||||
params[param_name] = doc_id
|
||||
where_clause = f" AND JSON_VALUE(meta, '$.document_id') IN ({', '.join(placeholders)}) "
|
||||
|
||||
cur.execute(
|
||||
f"""select meta, text, embedding FROM {self.table_name}
|
||||
WHERE CONTAINS(text, :kk, 1) > 0 {where_clause}
|
||||
order by score(1) desc fetch first {top_k} rows only""",
|
||||
kk=" ACCUM ".join(entities),
|
||||
params,
|
||||
)
|
||||
docs = []
|
||||
for record in cur:
|
||||
|
|
|
|||
|
|
@ -83,14 +83,14 @@ class TiDBVector(BaseVector):
|
|||
self._dimension = 1536
|
||||
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
logger.info("create collection and add texts, collection_name: " + self._collection_name)
|
||||
logger.info("create collection and add texts, collection_name: %s", self._collection_name)
|
||||
self._create_collection(len(embeddings[0]))
|
||||
self.add_texts(texts, embeddings)
|
||||
self._dimension = len(embeddings[0])
|
||||
pass
|
||||
|
||||
def _create_collection(self, dimension: int):
|
||||
logger.info("_create_collection, collection_name " + self._collection_name)
|
||||
logger.info("_create_collection, collection_name %s", self._collection_name)
|
||||
lock_name = f"vector_indexing_lock_{self._collection_name}"
|
||||
with redis_client.lock(lock_name, timeout=20):
|
||||
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
|
||||
|
|
|
|||
|
|
@ -95,7 +95,7 @@ class CacheEmbedding(Embeddings):
|
|||
db.session.rollback()
|
||||
except Exception as ex:
|
||||
db.session.rollback()
|
||||
logger.exception("Failed to embed documents: %s")
|
||||
logger.exception("Failed to embed documents")
|
||||
raise ex
|
||||
|
||||
return text_embeddings
|
||||
|
|
|
|||
|
|
@ -39,9 +39,16 @@ class WeightRerankRunner(BaseRerankRunner):
|
|||
unique_documents = []
|
||||
doc_ids = set()
|
||||
for document in documents:
|
||||
if document.metadata is not None and document.metadata["doc_id"] not in doc_ids:
|
||||
if (
|
||||
document.provider == "dify"
|
||||
and document.metadata is not None
|
||||
and document.metadata["doc_id"] not in doc_ids
|
||||
):
|
||||
doc_ids.add(document.metadata["doc_id"])
|
||||
unique_documents.append(document)
|
||||
else:
|
||||
if document not in unique_documents:
|
||||
unique_documents.append(document)
|
||||
|
||||
documents = unique_documents
|
||||
|
||||
|
|
|
|||
|
|
@ -275,35 +275,30 @@ class ApiTool(Tool):
|
|||
if files:
|
||||
headers.pop("Content-Type", None)
|
||||
|
||||
if method in {
|
||||
"get",
|
||||
"head",
|
||||
"post",
|
||||
"put",
|
||||
"delete",
|
||||
"patch",
|
||||
"options",
|
||||
"GET",
|
||||
"POST",
|
||||
"PUT",
|
||||
"PATCH",
|
||||
"DELETE",
|
||||
"HEAD",
|
||||
"OPTIONS",
|
||||
}:
|
||||
response: httpx.Response = getattr(ssrf_proxy, method.lower())(
|
||||
url,
|
||||
params=params,
|
||||
headers=headers,
|
||||
cookies=cookies,
|
||||
data=body,
|
||||
files=files,
|
||||
timeout=API_TOOL_DEFAULT_TIMEOUT,
|
||||
follow_redirects=True,
|
||||
)
|
||||
return response
|
||||
else:
|
||||
_METHOD_MAP = {
|
||||
"get": ssrf_proxy.get,
|
||||
"head": ssrf_proxy.head,
|
||||
"post": ssrf_proxy.post,
|
||||
"put": ssrf_proxy.put,
|
||||
"delete": ssrf_proxy.delete,
|
||||
"patch": ssrf_proxy.patch,
|
||||
}
|
||||
method_lc = method.lower()
|
||||
if method_lc not in _METHOD_MAP:
|
||||
raise ValueError(f"Invalid http method {method}")
|
||||
response: httpx.Response = _METHOD_MAP[
|
||||
method_lc
|
||||
]( # https://discuss.python.org/t/type-inference-for-function-return-types/42926
|
||||
url,
|
||||
params=params,
|
||||
headers=headers,
|
||||
cookies=cookies,
|
||||
data=body,
|
||||
files=files,
|
||||
timeout=API_TOOL_DEFAULT_TIMEOUT,
|
||||
follow_redirects=True,
|
||||
)
|
||||
return response
|
||||
|
||||
def _convert_body_property_any_of(
|
||||
self, property: dict[str, Any], value: Any, any_of: list[dict[str, Any]], max_recursive=10
|
||||
|
|
|
|||
|
|
@ -737,7 +737,7 @@ class LLMNode(BaseNode):
|
|||
and isinstance(prompt_messages[-1], UserPromptMessage)
|
||||
and isinstance(prompt_messages[-1].content, list)
|
||||
):
|
||||
prompt_messages[-1] = UserPromptMessage(content=prompt_messages[-1].content + file_prompts)
|
||||
prompt_messages[-1] = UserPromptMessage(content=file_prompts + prompt_messages[-1].content)
|
||||
else:
|
||||
prompt_messages.append(UserPromptMessage(content=file_prompts))
|
||||
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ if [[ "${MODE}" == "worker" ]]; then
|
|||
fi
|
||||
|
||||
exec celery -A app.celery worker -P ${CELERY_WORKER_CLASS:-gevent} $CONCURRENCY_OPTION \
|
||||
--max-tasks-per-child ${MAX_TASK_PRE_CHILD:-50} --loglevel ${LOG_LEVEL:-INFO} \
|
||||
--max-tasks-per-child ${MAX_TASKS_PER_CHILD:-50} --loglevel ${LOG_LEVEL:-INFO} \
|
||||
-Q ${CELERY_QUEUES:-dataset,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation}
|
||||
|
||||
elif [[ "${MODE}" == "beat" ]]; then
|
||||
|
|
|
|||
|
|
@ -85,6 +85,7 @@ def handle(sender: Message, **kwargs):
|
|||
values=_ProviderUpdateValues(last_used=current_time),
|
||||
description="basic_last_used_update",
|
||||
)
|
||||
logging.info("provider used, tenant_id=%s, provider_name=%s", tenant_id, provider_name)
|
||||
updates_to_perform.append(basic_update)
|
||||
|
||||
# 2. Check if we need to deduct quota (system provider only)
|
||||
|
|
@ -186,6 +187,8 @@ def _execute_provider_updates(updates_to_perform: list[_ProviderUpdateOperation]
|
|||
if not updates_to_perform:
|
||||
return
|
||||
|
||||
updates_to_perform = sorted(updates_to_perform, key=lambda i: (i.filters.tenant_id, i.filters.provider_name))
|
||||
|
||||
# Use SQLAlchemy's context manager for transaction management
|
||||
# This automatically handles commit/rollback
|
||||
with Session(db.engine) as session, session.begin():
|
||||
|
|
@ -212,10 +215,13 @@ def _execute_provider_updates(updates_to_perform: list[_ProviderUpdateOperation]
|
|||
|
||||
# Prepare values dict for SQLAlchemy update
|
||||
update_values = {}
|
||||
if values.last_used is not None:
|
||||
update_values["last_used"] = values.last_used
|
||||
# updateing to `last_used` is removed due to performance reason.
|
||||
# ref: https://github.com/langgenius/dify/issues/24526
|
||||
if values.quota_used is not None:
|
||||
update_values["quota_used"] = values.quota_used
|
||||
# Skip the current update operation if no updates are required.
|
||||
if not update_values:
|
||||
continue
|
||||
|
||||
# Build and execute the update statement
|
||||
stmt = update(Provider).where(*where_conditions).values(**update_values)
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ login_manager = flask_login.LoginManager()
|
|||
def load_user_from_request(request_from_flask_login):
|
||||
"""Load user based on the request."""
|
||||
# Skip authentication for documentation endpoints
|
||||
if request.path.endswith("/docs") or request.path.endswith("/swagger.json"):
|
||||
if dify_config.SWAGGER_UI_ENABLED and request.path.endswith((dify_config.SWAGGER_UI_PATH, "/swagger.json")):
|
||||
return None
|
||||
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
|
|
|
|||
|
|
@ -24,8 +24,6 @@ integrate_notion_info_list_fields = {
|
|||
"notion_info": fields.List(fields.Nested(integrate_workspace_fields)),
|
||||
}
|
||||
|
||||
integrate_icon_fields = {"type": fields.String, "url": fields.String, "emoji": fields.String}
|
||||
|
||||
integrate_page_fields = {
|
||||
"page_name": fields.String,
|
||||
"page_id": fields.String,
|
||||
|
|
|
|||
|
|
@ -3,11 +3,12 @@ import sys
|
|||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from flask import current_app, got_request_exception
|
||||
from flask import Blueprint, Flask, current_app, got_request_exception
|
||||
from flask_restx import Api
|
||||
from werkzeug.exceptions import HTTPException
|
||||
from werkzeug.http import HTTP_STATUS_CODES
|
||||
|
||||
from configs import dify_config
|
||||
from core.errors.error import AppInvokeQuotaExceededError
|
||||
|
||||
|
||||
|
|
@ -106,6 +107,22 @@ def register_external_error_handlers(api: Api) -> None:
|
|||
|
||||
|
||||
class ExternalApi(Api):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
_authorizations = {
|
||||
"Bearer": {
|
||||
"type": "apiKey",
|
||||
"in": "header",
|
||||
"name": "Authorization",
|
||||
"description": "Type: Bearer {your-api-key}",
|
||||
}
|
||||
}
|
||||
|
||||
def __init__(self, app: Blueprint | Flask, *args, **kwargs):
|
||||
kwargs.setdefault("authorizations", self._authorizations)
|
||||
kwargs.setdefault("security", "Bearer")
|
||||
kwargs["add_specs"] = dify_config.SWAGGER_UI_ENABLED
|
||||
kwargs["doc"] = dify_config.SWAGGER_UI_PATH if dify_config.SWAGGER_UI_ENABLED else False
|
||||
|
||||
# manual separate call on construction and init_app to ensure configs in kwargs effective
|
||||
super().__init__(app=None, *args, **kwargs) # type: ignore
|
||||
self.init_app(app, **kwargs)
|
||||
register_external_error_handlers(self)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from functools import wraps
|
||||
from typing import Any
|
||||
from typing import Union, cast
|
||||
|
||||
from flask import current_app, g, has_request_context, request
|
||||
from flask_login.config import EXEMPT_METHODS # type: ignore
|
||||
|
|
@ -11,7 +11,7 @@ from models.model import EndUser
|
|||
|
||||
#: A proxy for the current user. If no user is logged in, this will be an
|
||||
#: anonymous user
|
||||
current_user: Any = LocalProxy(lambda: _get_user())
|
||||
current_user = cast(Union[Account, EndUser, None], LocalProxy(lambda: _get_user()))
|
||||
|
||||
|
||||
def login_required(func):
|
||||
|
|
@ -52,7 +52,7 @@ def login_required(func):
|
|||
def decorated_view(*args, **kwargs):
|
||||
if request.method in EXEMPT_METHODS or dify_config.LOGIN_DISABLED:
|
||||
pass
|
||||
elif not current_user.is_authenticated:
|
||||
elif current_user is not None and not current_user.is_authenticated:
|
||||
return current_app.login_manager.unauthorized() # type: ignore
|
||||
|
||||
# flask 1.x compatibility
|
||||
|
|
|
|||
|
|
@ -0,0 +1,177 @@
|
|||
"""Add provider multi credential support
|
||||
|
||||
Revision ID: e8446f481c1e
|
||||
Revises: 8bcc02c9bd07
|
||||
Create Date: 2025-08-09 15:53:54.341341
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import models as models
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.sql import table, column
|
||||
import uuid
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'e8446f481c1e'
|
||||
down_revision = 'fa8b0fa6f407'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# Create provider_credentials table
|
||||
op.create_table('provider_credentials',
|
||||
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
|
||||
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('provider_name', sa.String(length=255), nullable=False),
|
||||
sa.Column('credential_name', sa.String(length=255), nullable=False),
|
||||
sa.Column('encrypted_config', sa.Text(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id', name='provider_credential_pkey')
|
||||
)
|
||||
|
||||
# Create index for provider_credentials
|
||||
with op.batch_alter_table('provider_credentials', schema=None) as batch_op:
|
||||
batch_op.create_index('provider_credential_tenant_provider_idx', ['tenant_id', 'provider_name'], unique=False)
|
||||
|
||||
# Add credential_id to providers table
|
||||
with op.batch_alter_table('providers', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('credential_id', models.types.StringUUID(), nullable=True))
|
||||
|
||||
# Add credential_id to load_balancing_model_configs table
|
||||
with op.batch_alter_table('load_balancing_model_configs', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('credential_id', models.types.StringUUID(), nullable=True))
|
||||
|
||||
migrate_existing_providers_data()
|
||||
|
||||
# Remove encrypted_config column from providers table after migration
|
||||
with op.batch_alter_table('providers', schema=None) as batch_op:
|
||||
batch_op.drop_column('encrypted_config')
|
||||
|
||||
|
||||
def migrate_existing_providers_data():
|
||||
"""migrate providers table data to provider_credentials"""
|
||||
|
||||
# Define table structure for data manipulation
|
||||
providers_table = table('providers',
|
||||
column('id', models.types.StringUUID()),
|
||||
column('tenant_id', models.types.StringUUID()),
|
||||
column('provider_name', sa.String()),
|
||||
column('encrypted_config', sa.Text()),
|
||||
column('created_at', sa.DateTime()),
|
||||
column('updated_at', sa.DateTime()),
|
||||
column('credential_id', models.types.StringUUID()),
|
||||
)
|
||||
|
||||
provider_credential_table = table('provider_credentials',
|
||||
column('id', models.types.StringUUID()),
|
||||
column('tenant_id', models.types.StringUUID()),
|
||||
column('provider_name', sa.String()),
|
||||
column('credential_name', sa.String()),
|
||||
column('encrypted_config', sa.Text()),
|
||||
column('created_at', sa.DateTime()),
|
||||
column('updated_at', sa.DateTime())
|
||||
)
|
||||
|
||||
# Get database connection
|
||||
conn = op.get_bind()
|
||||
|
||||
# Query all existing providers data
|
||||
existing_providers = conn.execute(
|
||||
sa.select(providers_table.c.id, providers_table.c.tenant_id,
|
||||
providers_table.c.provider_name, providers_table.c.encrypted_config,
|
||||
providers_table.c.created_at, providers_table.c.updated_at)
|
||||
.where(providers_table.c.encrypted_config.isnot(None))
|
||||
).fetchall()
|
||||
|
||||
# Iterate through each provider and insert into provider_credentials
|
||||
for provider in existing_providers:
|
||||
credential_id = str(uuid.uuid4())
|
||||
if not provider.encrypted_config or provider.encrypted_config.strip() == '':
|
||||
continue
|
||||
|
||||
# Insert into provider_credentials table
|
||||
conn.execute(
|
||||
provider_credential_table.insert().values(
|
||||
id=credential_id,
|
||||
tenant_id=provider.tenant_id,
|
||||
provider_name=provider.provider_name,
|
||||
credential_name='API_KEY1', # Use a default name
|
||||
encrypted_config=provider.encrypted_config,
|
||||
created_at=provider.created_at,
|
||||
updated_at=provider.updated_at
|
||||
)
|
||||
)
|
||||
|
||||
# Update original providers table, set credential_id
|
||||
conn.execute(
|
||||
providers_table.update()
|
||||
.where(providers_table.c.id == provider.id)
|
||||
.values(
|
||||
credential_id=credential_id,
|
||||
)
|
||||
)
|
||||
|
||||
def downgrade():
|
||||
# Re-add encrypted_config column to providers table
|
||||
with op.batch_alter_table('providers', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('encrypted_config', sa.Text(), nullable=True))
|
||||
|
||||
# Migrate data back from provider_credentials to providers
|
||||
migrate_data_back_to_providers()
|
||||
|
||||
# Remove credential_id columns
|
||||
with op.batch_alter_table('load_balancing_model_configs', schema=None) as batch_op:
|
||||
batch_op.drop_column('credential_id')
|
||||
|
||||
with op.batch_alter_table('providers', schema=None) as batch_op:
|
||||
batch_op.drop_column('credential_id')
|
||||
|
||||
# Drop provider_credentials table
|
||||
op.drop_table('provider_credentials')
|
||||
|
||||
|
||||
def migrate_data_back_to_providers():
|
||||
"""Migrate data back from provider_credentials to providers table for downgrade"""
|
||||
|
||||
# Define table structure for data manipulation
|
||||
providers_table = table('providers',
|
||||
column('id', models.types.StringUUID()),
|
||||
column('tenant_id', models.types.StringUUID()),
|
||||
column('provider_name', sa.String()),
|
||||
column('encrypted_config', sa.Text()),
|
||||
column('credential_id', models.types.StringUUID()),
|
||||
)
|
||||
|
||||
provider_credential_table = table('provider_credentials',
|
||||
column('id', models.types.StringUUID()),
|
||||
column('tenant_id', models.types.StringUUID()),
|
||||
column('provider_name', sa.String()),
|
||||
column('credential_name', sa.String()),
|
||||
column('encrypted_config', sa.Text()),
|
||||
)
|
||||
|
||||
# Get database connection
|
||||
conn = op.get_bind()
|
||||
|
||||
# Query providers that have credential_id
|
||||
providers_with_credentials = conn.execute(
|
||||
sa.select(providers_table.c.id, providers_table.c.credential_id)
|
||||
.where(providers_table.c.credential_id.isnot(None))
|
||||
).fetchall()
|
||||
|
||||
# For each provider, get the credential data and update providers table
|
||||
for provider in providers_with_credentials:
|
||||
credential = conn.execute(
|
||||
sa.select(provider_credential_table.c.encrypted_config)
|
||||
.where(provider_credential_table.c.id == provider.credential_id)
|
||||
).fetchone()
|
||||
|
||||
if credential:
|
||||
# Update providers table with encrypted_config from credential
|
||||
conn.execute(
|
||||
providers_table.update()
|
||||
.where(providers_table.c.id == provider.id)
|
||||
.values(encrypted_config=credential.encrypted_config)
|
||||
)
|
||||
|
|
@ -0,0 +1,186 @@
|
|||
"""Add provider model multi credential support
|
||||
|
||||
Revision ID: 0e154742a5fa
|
||||
Revises: e8446f481c1e
|
||||
Create Date: 2025-08-13 16:05:42.657730
|
||||
|
||||
"""
|
||||
import uuid
|
||||
|
||||
from alembic import op
|
||||
import models as models
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.sql import table, column
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '0e154742a5fa'
|
||||
down_revision = 'e8446f481c1e'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# Create provider_model_credentials table
|
||||
op.create_table('provider_model_credentials',
|
||||
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
|
||||
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('provider_name', sa.String(length=255), nullable=False),
|
||||
sa.Column('model_name', sa.String(length=255), nullable=False),
|
||||
sa.Column('model_type', sa.String(length=40), nullable=False),
|
||||
sa.Column('credential_name', sa.String(length=255), nullable=False),
|
||||
sa.Column('encrypted_config', sa.Text(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id', name='provider_model_credential_pkey')
|
||||
)
|
||||
|
||||
# Create index for provider_model_credentials
|
||||
with op.batch_alter_table('provider_model_credentials', schema=None) as batch_op:
|
||||
batch_op.create_index('provider_model_credential_tenant_provider_model_idx', ['tenant_id', 'provider_name', 'model_name', 'model_type'], unique=False)
|
||||
|
||||
# Add credential_id to provider_models table
|
||||
with op.batch_alter_table('provider_models', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('credential_id', models.types.StringUUID(), nullable=True))
|
||||
|
||||
|
||||
# Add credential_source_type to load_balancing_model_configs table
|
||||
with op.batch_alter_table('load_balancing_model_configs', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('credential_source_type', sa.String(length=40), nullable=True))
|
||||
|
||||
# Migrate existing provider_models data
|
||||
migrate_existing_provider_models_data()
|
||||
|
||||
# Remove encrypted_config column from provider_models table after migration
|
||||
with op.batch_alter_table('provider_models', schema=None) as batch_op:
|
||||
batch_op.drop_column('encrypted_config')
|
||||
|
||||
|
||||
def migrate_existing_provider_models_data():
|
||||
"""migrate provider_models table data to provider_model_credentials"""
|
||||
|
||||
# Define table structure for data manipulation
|
||||
provider_models_table = table('provider_models',
|
||||
column('id', models.types.StringUUID()),
|
||||
column('tenant_id', models.types.StringUUID()),
|
||||
column('provider_name', sa.String()),
|
||||
column('model_name', sa.String()),
|
||||
column('model_type', sa.String()),
|
||||
column('encrypted_config', sa.Text()),
|
||||
column('created_at', sa.DateTime()),
|
||||
column('updated_at', sa.DateTime()),
|
||||
column('credential_id', models.types.StringUUID()),
|
||||
)
|
||||
|
||||
provider_model_credentials_table = table('provider_model_credentials',
|
||||
column('id', models.types.StringUUID()),
|
||||
column('tenant_id', models.types.StringUUID()),
|
||||
column('provider_name', sa.String()),
|
||||
column('model_name', sa.String()),
|
||||
column('model_type', sa.String()),
|
||||
column('credential_name', sa.String()),
|
||||
column('encrypted_config', sa.Text()),
|
||||
column('created_at', sa.DateTime()),
|
||||
column('updated_at', sa.DateTime())
|
||||
)
|
||||
|
||||
|
||||
# Get database connection
|
||||
conn = op.get_bind()
|
||||
|
||||
# Query all existing provider_models data with encrypted_config
|
||||
existing_provider_models = conn.execute(
|
||||
sa.select(provider_models_table.c.id, provider_models_table.c.tenant_id,
|
||||
provider_models_table.c.provider_name, provider_models_table.c.model_name,
|
||||
provider_models_table.c.model_type, provider_models_table.c.encrypted_config,
|
||||
provider_models_table.c.created_at, provider_models_table.c.updated_at)
|
||||
.where(provider_models_table.c.encrypted_config.isnot(None))
|
||||
).fetchall()
|
||||
|
||||
# Iterate through each provider_model and insert into provider_model_credentials
|
||||
for provider_model in existing_provider_models:
|
||||
if not provider_model.encrypted_config or provider_model.encrypted_config.strip() == '':
|
||||
continue
|
||||
|
||||
credential_id = str(uuid.uuid4())
|
||||
|
||||
# Insert into provider_model_credentials table
|
||||
conn.execute(
|
||||
provider_model_credentials_table.insert().values(
|
||||
id=credential_id,
|
||||
tenant_id=provider_model.tenant_id,
|
||||
provider_name=provider_model.provider_name,
|
||||
model_name=provider_model.model_name,
|
||||
model_type=provider_model.model_type,
|
||||
credential_name='API_KEY1', # Use a default name
|
||||
encrypted_config=provider_model.encrypted_config,
|
||||
created_at=provider_model.created_at,
|
||||
updated_at=provider_model.updated_at
|
||||
)
|
||||
)
|
||||
|
||||
# Update original provider_models table, set credential_id
|
||||
conn.execute(
|
||||
provider_models_table.update()
|
||||
.where(provider_models_table.c.id == provider_model.id)
|
||||
.values(credential_id=credential_id)
|
||||
)
|
||||
|
||||
|
||||
def downgrade():
|
||||
# Re-add encrypted_config column to provider_models table
|
||||
with op.batch_alter_table('provider_models', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('encrypted_config', sa.Text(), nullable=True))
|
||||
|
||||
# Migrate data back from provider_model_credentials to provider_models
|
||||
migrate_data_back_to_provider_models()
|
||||
|
||||
with op.batch_alter_table('provider_models', schema=None) as batch_op:
|
||||
batch_op.drop_column('credential_id')
|
||||
|
||||
# Remove credential_source_type column from load_balancing_model_configs
|
||||
with op.batch_alter_table('load_balancing_model_configs', schema=None) as batch_op:
|
||||
batch_op.drop_column('credential_source_type')
|
||||
|
||||
# Drop provider_model_credentials table
|
||||
op.drop_table('provider_model_credentials')
|
||||
|
||||
|
||||
def migrate_data_back_to_provider_models():
|
||||
"""Migrate data back from provider_model_credentials to provider_models table for downgrade"""
|
||||
|
||||
# Define table structure for data manipulation
|
||||
provider_models_table = table('provider_models',
|
||||
column('id', models.types.StringUUID()),
|
||||
column('encrypted_config', sa.Text()),
|
||||
column('credential_id', models.types.StringUUID()),
|
||||
)
|
||||
|
||||
provider_model_credentials_table = table('provider_model_credentials',
|
||||
column('id', models.types.StringUUID()),
|
||||
column('encrypted_config', sa.Text()),
|
||||
)
|
||||
|
||||
# Get database connection
|
||||
conn = op.get_bind()
|
||||
|
||||
# Query provider_models that have credential_id
|
||||
provider_models_with_credentials = conn.execute(
|
||||
sa.select(provider_models_table.c.id, provider_models_table.c.credential_id)
|
||||
.where(provider_models_table.c.credential_id.isnot(None))
|
||||
).fetchall()
|
||||
|
||||
# For each provider_model, get the credential data and update provider_models table
|
||||
for provider_model in provider_models_with_credentials:
|
||||
credential = conn.execute(
|
||||
sa.select(provider_model_credentials_table.c.encrypted_config)
|
||||
.where(provider_model_credentials_table.c.id == provider_model.credential_id)
|
||||
).fetchone()
|
||||
|
||||
if credential:
|
||||
# Update provider_models table with encrypted_config from credential
|
||||
conn.execute(
|
||||
provider_models_table.update()
|
||||
.where(provider_models_table.c.id == provider_model.id)
|
||||
.values(encrypted_config=credential.encrypted_config)
|
||||
)
|
||||
|
|
@ -1,5 +1,6 @@
|
|||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from functools import cached_property
|
||||
from typing import Optional
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
|
@ -7,6 +8,7 @@ from sqlalchemy import DateTime, String, func, text
|
|||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from .base import Base
|
||||
from .engine import db
|
||||
from .types import StringUUID
|
||||
|
||||
|
||||
|
|
@ -60,9 +62,9 @@ class Provider(Base):
|
|||
provider_type: Mapped[str] = mapped_column(
|
||||
String(40), nullable=False, server_default=text("'custom'::character varying")
|
||||
)
|
||||
encrypted_config: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True)
|
||||
is_valid: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("false"))
|
||||
last_used: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
|
||||
credential_id: Mapped[Optional[str]] = mapped_column(StringUUID, nullable=True)
|
||||
|
||||
quota_type: Mapped[Optional[str]] = mapped_column(
|
||||
String(40), nullable=True, server_default=text("''::character varying")
|
||||
|
|
@ -79,6 +81,21 @@ class Provider(Base):
|
|||
f" provider_type='{self.provider_type}')>"
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def credential(self):
|
||||
if self.credential_id:
|
||||
return db.session.query(ProviderCredential).where(ProviderCredential.id == self.credential_id).first()
|
||||
|
||||
@property
|
||||
def credential_name(self):
|
||||
credential = self.credential
|
||||
return credential.credential_name if credential else None
|
||||
|
||||
@property
|
||||
def encrypted_config(self):
|
||||
credential = self.credential
|
||||
return credential.encrypted_config if credential else None
|
||||
|
||||
@property
|
||||
def token_is_set(self):
|
||||
"""
|
||||
|
|
@ -116,11 +133,30 @@ class ProviderModel(Base):
|
|||
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
model_type: Mapped[str] = mapped_column(String(40), nullable=False)
|
||||
encrypted_config: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True)
|
||||
credential_id: Mapped[Optional[str]] = mapped_column(StringUUID, nullable=True)
|
||||
is_valid: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("false"))
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
@cached_property
|
||||
def credential(self):
|
||||
if self.credential_id:
|
||||
return (
|
||||
db.session.query(ProviderModelCredential)
|
||||
.where(ProviderModelCredential.id == self.credential_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
@property
|
||||
def credential_name(self):
|
||||
credential = self.credential
|
||||
return credential.credential_name if credential else None
|
||||
|
||||
@property
|
||||
def encrypted_config(self):
|
||||
credential = self.credential
|
||||
return credential.encrypted_config if credential else None
|
||||
|
||||
|
||||
class TenantDefaultModel(Base):
|
||||
__tablename__ = "tenant_default_models"
|
||||
|
|
@ -220,6 +256,56 @@ class LoadBalancingModelConfig(Base):
|
|||
model_type: Mapped[str] = mapped_column(String(40), nullable=False)
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
encrypted_config: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True)
|
||||
credential_id: Mapped[Optional[str]] = mapped_column(StringUUID, nullable=True)
|
||||
credential_source_type: Mapped[Optional[str]] = mapped_column(String(40), nullable=True)
|
||||
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("true"))
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class ProviderCredential(Base):
|
||||
"""
|
||||
Provider credential - stores multiple named credentials for each provider
|
||||
"""
|
||||
|
||||
__tablename__ = "provider_credentials"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="provider_credential_pkey"),
|
||||
sa.Index("provider_credential_tenant_provider_idx", "tenant_id", "provider_name"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
credential_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
encrypted_config: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class ProviderModelCredential(Base):
|
||||
"""
|
||||
Provider model credential - stores multiple named credentials for each provider model
|
||||
"""
|
||||
|
||||
__tablename__ = "provider_model_credentials"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="provider_model_credential_pkey"),
|
||||
sa.Index(
|
||||
"provider_model_credential_tenant_provider_model_idx",
|
||||
"tenant_id",
|
||||
"provider_name",
|
||||
"model_name",
|
||||
"model_type",
|
||||
),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
model_type: Mapped[str] = mapped_column(String(40), nullable=False)
|
||||
credential_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
encrypted_config: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
|
|
|||
|
|
@ -45,6 +45,7 @@ def clean_unused_datasets_task():
|
|||
plan_filter = config["plan_filter"]
|
||||
add_logs = config["add_logs"]
|
||||
|
||||
page = 1
|
||||
while True:
|
||||
try:
|
||||
# Subquery for counting new documents
|
||||
|
|
@ -86,12 +87,12 @@ def clean_unused_datasets_task():
|
|||
.order_by(Dataset.created_at.desc())
|
||||
)
|
||||
|
||||
datasets = db.paginate(stmt, page=1, per_page=50)
|
||||
datasets = db.paginate(stmt, page=page, per_page=50, error_out=False)
|
||||
|
||||
except SQLAlchemyError:
|
||||
raise
|
||||
|
||||
if datasets.items is None or len(datasets.items) == 0:
|
||||
if datasets is None or datasets.items is None or len(datasets.items) == 0:
|
||||
break
|
||||
|
||||
for dataset in datasets:
|
||||
|
|
@ -150,5 +151,7 @@ def clean_unused_datasets_task():
|
|||
except Exception as e:
|
||||
click.echo(click.style(f"clean dataset index error: {e.__class__.__name__} {str(e)}", fg="red"))
|
||||
|
||||
page += 1
|
||||
|
||||
end_at = time.perf_counter()
|
||||
click.echo(click.style(f"Cleaned unused dataset from db success latency: {end_at - start_at}", fg="green"))
|
||||
|
|
|
|||
|
|
@ -8,7 +8,12 @@ from core.entities.model_entities import (
|
|||
ModelWithProviderEntity,
|
||||
ProviderModelWithStatusEntity,
|
||||
)
|
||||
from core.entities.provider_entities import ProviderQuotaType, QuotaConfiguration
|
||||
from core.entities.provider_entities import (
|
||||
CredentialConfiguration,
|
||||
CustomModelConfiguration,
|
||||
ProviderQuotaType,
|
||||
QuotaConfiguration,
|
||||
)
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.entities.provider_entities import (
|
||||
|
|
@ -36,6 +41,10 @@ class CustomConfigurationResponse(BaseModel):
|
|||
"""
|
||||
|
||||
status: CustomConfigurationStatus
|
||||
current_credential_id: Optional[str] = None
|
||||
current_credential_name: Optional[str] = None
|
||||
available_credentials: Optional[list[CredentialConfiguration]] = None
|
||||
custom_models: Optional[list[CustomModelConfiguration]] = None
|
||||
|
||||
|
||||
class SystemConfigurationResponse(BaseModel):
|
||||
|
|
|
|||
|
|
@ -3,3 +3,7 @@ from services.errors.base import BaseServiceError
|
|||
|
||||
class AppModelConfigBrokenError(BaseServiceError):
|
||||
pass
|
||||
|
||||
|
||||
class ProviderNotFoundError(BaseServiceError):
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ from core.model_runtime.model_providers.model_provider_factory import ModelProvi
|
|||
from core.provider_manager import ProviderManager
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.provider import LoadBalancingModelConfig
|
||||
from models.provider import LoadBalancingModelConfig, ProviderCredential, ProviderModelCredential
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -185,6 +185,7 @@ class ModelLoadBalancingService:
|
|||
"id": load_balancing_config.id,
|
||||
"name": load_balancing_config.name,
|
||||
"credentials": credentials,
|
||||
"credential_id": load_balancing_config.credential_id,
|
||||
"enabled": load_balancing_config.enabled,
|
||||
"in_cooldown": in_cooldown,
|
||||
"ttl": ttl,
|
||||
|
|
@ -280,7 +281,7 @@ class ModelLoadBalancingService:
|
|||
return inherit_config
|
||||
|
||||
def update_load_balancing_configs(
|
||||
self, tenant_id: str, provider: str, model: str, model_type: str, configs: list[dict]
|
||||
self, tenant_id: str, provider: str, model: str, model_type: str, configs: list[dict], config_from: str
|
||||
) -> None:
|
||||
"""
|
||||
Update load balancing configurations.
|
||||
|
|
@ -289,6 +290,7 @@ class ModelLoadBalancingService:
|
|||
:param model: model name
|
||||
:param model_type: model type
|
||||
:param configs: load balancing configs
|
||||
:param config_from: predefined-model or custom-model
|
||||
:return:
|
||||
"""
|
||||
# Get all provider configurations of the current workspace
|
||||
|
|
@ -327,8 +329,37 @@ class ModelLoadBalancingService:
|
|||
config_id = config.get("id")
|
||||
name = config.get("name")
|
||||
credentials = config.get("credentials")
|
||||
credential_id = config.get("credential_id")
|
||||
enabled = config.get("enabled")
|
||||
|
||||
if credential_id:
|
||||
credential_record: ProviderCredential | ProviderModelCredential | None = None
|
||||
if config_from == "predefined-model":
|
||||
credential_record = (
|
||||
db.session.query(ProviderCredential)
|
||||
.filter_by(
|
||||
id=credential_id,
|
||||
tenant_id=tenant_id,
|
||||
provider_name=provider_configuration.provider.provider,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
else:
|
||||
credential_record = (
|
||||
db.session.query(ProviderModelCredential)
|
||||
.filter_by(
|
||||
id=credential_id,
|
||||
tenant_id=tenant_id,
|
||||
provider_name=provider_configuration.provider.provider,
|
||||
model_name=model,
|
||||
model_type=model_type_enum.to_origin_model_type(),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if not credential_record:
|
||||
raise ValueError(f"Provider credential with id {credential_id} not found")
|
||||
name = credential_record.credential_name
|
||||
|
||||
if not name:
|
||||
raise ValueError("Invalid load balancing config name")
|
||||
|
||||
|
|
@ -346,11 +377,6 @@ class ModelLoadBalancingService:
|
|||
|
||||
load_balancing_config = current_load_balancing_configs_dict[config_id]
|
||||
|
||||
# check duplicate name
|
||||
for current_load_balancing_config in current_load_balancing_configs:
|
||||
if current_load_balancing_config.id != config_id and current_load_balancing_config.name == name:
|
||||
raise ValueError(f"Load balancing config name {name} already exists")
|
||||
|
||||
if credentials:
|
||||
if not isinstance(credentials, dict):
|
||||
raise ValueError("Invalid load balancing config credentials")
|
||||
|
|
@ -377,39 +403,48 @@ class ModelLoadBalancingService:
|
|||
self._clear_credentials_cache(tenant_id, config_id)
|
||||
else:
|
||||
# create load balancing config
|
||||
if name == "__inherit__":
|
||||
if name in {"__inherit__", "__delete__"}:
|
||||
raise ValueError("Invalid load balancing config name")
|
||||
|
||||
# check duplicate name
|
||||
for current_load_balancing_config in current_load_balancing_configs:
|
||||
if current_load_balancing_config.name == name:
|
||||
raise ValueError(f"Load balancing config name {name} already exists")
|
||||
if credential_id:
|
||||
credential_source = "provider" if config_from == "predefined-model" else "custom_model"
|
||||
assert credential_record is not None
|
||||
load_balancing_model_config = LoadBalancingModelConfig(
|
||||
tenant_id=tenant_id,
|
||||
provider_name=provider_configuration.provider.provider,
|
||||
model_type=model_type_enum.to_origin_model_type(),
|
||||
model_name=model,
|
||||
name=credential_record.credential_name,
|
||||
encrypted_config=credential_record.encrypted_config,
|
||||
credential_id=credential_id,
|
||||
credential_source_type=credential_source,
|
||||
)
|
||||
else:
|
||||
if not credentials:
|
||||
raise ValueError("Invalid load balancing config credentials")
|
||||
|
||||
if not credentials:
|
||||
raise ValueError("Invalid load balancing config credentials")
|
||||
if not isinstance(credentials, dict):
|
||||
raise ValueError("Invalid load balancing config credentials")
|
||||
|
||||
if not isinstance(credentials, dict):
|
||||
raise ValueError("Invalid load balancing config credentials")
|
||||
# validate custom provider config
|
||||
credentials = self._custom_credentials_validate(
|
||||
tenant_id=tenant_id,
|
||||
provider_configuration=provider_configuration,
|
||||
model_type=model_type_enum,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
validate=False,
|
||||
)
|
||||
|
||||
# validate custom provider config
|
||||
credentials = self._custom_credentials_validate(
|
||||
tenant_id=tenant_id,
|
||||
provider_configuration=provider_configuration,
|
||||
model_type=model_type_enum,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
validate=False,
|
||||
)
|
||||
|
||||
# create load balancing config
|
||||
load_balancing_model_config = LoadBalancingModelConfig(
|
||||
tenant_id=tenant_id,
|
||||
provider_name=provider_configuration.provider.provider,
|
||||
model_type=model_type_enum.to_origin_model_type(),
|
||||
model_name=model,
|
||||
name=name,
|
||||
encrypted_config=json.dumps(credentials),
|
||||
)
|
||||
# create load balancing config
|
||||
load_balancing_model_config = LoadBalancingModelConfig(
|
||||
tenant_id=tenant_id,
|
||||
provider_name=provider_configuration.provider.provider,
|
||||
model_type=model_type_enum.to_origin_model_type(),
|
||||
model_name=model,
|
||||
name=name,
|
||||
encrypted_config=json.dumps(credentials),
|
||||
)
|
||||
|
||||
db.session.add(load_balancing_model_config)
|
||||
db.session.commit()
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ from services.entities.model_provider_entities import (
|
|||
SimpleProviderEntityResponse,
|
||||
SystemConfigurationResponse,
|
||||
)
|
||||
from services.errors.app_model_config import ProviderNotFoundError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -28,6 +29,29 @@ class ModelProviderService:
|
|||
def __init__(self) -> None:
|
||||
self.provider_manager = ProviderManager()
|
||||
|
||||
def _get_provider_configuration(self, tenant_id: str, provider: str):
|
||||
"""
|
||||
Get provider configuration or raise exception if not found.
|
||||
|
||||
Args:
|
||||
tenant_id: Workspace identifier
|
||||
provider: Provider name
|
||||
|
||||
Returns:
|
||||
Provider configuration instance
|
||||
|
||||
Raises:
|
||||
ProviderNotFoundError: If provider doesn't exist
|
||||
"""
|
||||
# Get all provider configurations of the current workspace
|
||||
provider_configurations = self.provider_manager.get_configurations(tenant_id)
|
||||
provider_configuration = provider_configurations.get(provider)
|
||||
|
||||
if not provider_configuration:
|
||||
raise ProviderNotFoundError(f"Provider {provider} does not exist.")
|
||||
|
||||
return provider_configuration
|
||||
|
||||
def get_provider_list(self, tenant_id: str, model_type: Optional[str] = None) -> list[ProviderResponse]:
|
||||
"""
|
||||
get provider list.
|
||||
|
|
@ -46,6 +70,9 @@ class ModelProviderService:
|
|||
if model_type_entity not in provider_configuration.provider.supported_model_types:
|
||||
continue
|
||||
|
||||
provider_config = provider_configuration.custom_configuration.provider
|
||||
model_config = provider_configuration.custom_configuration.models
|
||||
|
||||
provider_response = ProviderResponse(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider_configuration.provider.provider,
|
||||
|
|
@ -63,7 +90,11 @@ class ModelProviderService:
|
|||
custom_configuration=CustomConfigurationResponse(
|
||||
status=CustomConfigurationStatus.ACTIVE
|
||||
if provider_configuration.is_custom_configuration_available()
|
||||
else CustomConfigurationStatus.NO_CONFIGURE
|
||||
else CustomConfigurationStatus.NO_CONFIGURE,
|
||||
current_credential_id=getattr(provider_config, "current_credential_id", None),
|
||||
current_credential_name=getattr(provider_config, "current_credential_name", None),
|
||||
available_credentials=getattr(provider_config, "available_credentials", []),
|
||||
custom_models=model_config,
|
||||
),
|
||||
system_configuration=SystemConfigurationResponse(
|
||||
enabled=provider_configuration.system_configuration.enabled,
|
||||
|
|
@ -82,8 +113,8 @@ class ModelProviderService:
|
|||
For the model provider page,
|
||||
only supports passing in a single provider to query the list of supported models.
|
||||
|
||||
:param tenant_id:
|
||||
:param provider:
|
||||
:param tenant_id: workspace id
|
||||
:param provider: provider name
|
||||
:return:
|
||||
"""
|
||||
# Get all provider configurations of the current workspace
|
||||
|
|
@ -95,98 +126,111 @@ class ModelProviderService:
|
|||
for model in provider_configurations.get_models(provider=provider)
|
||||
]
|
||||
|
||||
def get_provider_credentials(self, tenant_id: str, provider: str) -> Optional[dict]:
|
||||
def get_provider_credential(
|
||||
self, tenant_id: str, provider: str, credential_id: Optional[str] = None
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
get provider credentials.
|
||||
"""
|
||||
provider_configurations = self.provider_manager.get_configurations(tenant_id)
|
||||
provider_configuration = provider_configurations.get(provider)
|
||||
if not provider_configuration:
|
||||
raise ValueError(f"Provider {provider} does not exist.")
|
||||
|
||||
return provider_configuration.get_custom_credentials(obfuscated=True)
|
||||
|
||||
def provider_credentials_validate(self, tenant_id: str, provider: str, credentials: dict) -> None:
|
||||
"""
|
||||
validate provider credentials.
|
||||
|
||||
:param tenant_id:
|
||||
:param provider:
|
||||
:param credentials:
|
||||
"""
|
||||
# Get all provider configurations of the current workspace
|
||||
provider_configurations = self.provider_manager.get_configurations(tenant_id)
|
||||
|
||||
# Get provider configuration
|
||||
provider_configuration = provider_configurations.get(provider)
|
||||
if not provider_configuration:
|
||||
raise ValueError(f"Provider {provider} does not exist.")
|
||||
|
||||
provider_configuration.custom_credentials_validate(credentials)
|
||||
|
||||
def save_provider_credentials(self, tenant_id: str, provider: str, credentials: dict) -> None:
|
||||
"""
|
||||
save custom provider config.
|
||||
|
||||
:param tenant_id: workspace id
|
||||
:param provider: provider name
|
||||
:param credentials: provider credentials
|
||||
:param credential_id: credential id, if not provided, return current used credentials
|
||||
:return:
|
||||
"""
|
||||
# Get all provider configurations of the current workspace
|
||||
provider_configurations = self.provider_manager.get_configurations(tenant_id)
|
||||
provider_configuration = self._get_provider_configuration(tenant_id, provider)
|
||||
return provider_configuration.get_provider_credential(credential_id=credential_id) # type: ignore
|
||||
|
||||
# Get provider configuration
|
||||
provider_configuration = provider_configurations.get(provider)
|
||||
if not provider_configuration:
|
||||
raise ValueError(f"Provider {provider} does not exist.")
|
||||
|
||||
# Add or update custom provider credentials.
|
||||
provider_configuration.add_or_update_custom_credentials(credentials)
|
||||
|
||||
def remove_provider_credentials(self, tenant_id: str, provider: str) -> None:
|
||||
def validate_provider_credentials(self, tenant_id: str, provider: str, credentials: dict) -> None:
|
||||
"""
|
||||
remove custom provider config.
|
||||
validate provider credentials before saving.
|
||||
|
||||
:param tenant_id: workspace id
|
||||
:param provider: provider name
|
||||
:param credentials: provider credentials dict
|
||||
"""
|
||||
provider_configuration = self._get_provider_configuration(tenant_id, provider)
|
||||
provider_configuration.validate_provider_credentials(credentials)
|
||||
|
||||
def create_provider_credential(
|
||||
self, tenant_id: str, provider: str, credentials: dict, credential_name: str
|
||||
) -> None:
|
||||
"""
|
||||
Create and save new provider credentials.
|
||||
|
||||
:param tenant_id: workspace id
|
||||
:param provider: provider name
|
||||
:param credentials: provider credentials dict
|
||||
:param credential_name: credential name
|
||||
:return:
|
||||
"""
|
||||
# Get all provider configurations of the current workspace
|
||||
provider_configurations = self.provider_manager.get_configurations(tenant_id)
|
||||
provider_configuration = self._get_provider_configuration(tenant_id, provider)
|
||||
provider_configuration.create_provider_credential(credentials, credential_name)
|
||||
|
||||
# Get provider configuration
|
||||
provider_configuration = provider_configurations.get(provider)
|
||||
if not provider_configuration:
|
||||
raise ValueError(f"Provider {provider} does not exist.")
|
||||
|
||||
# Remove custom provider credentials.
|
||||
provider_configuration.delete_custom_credentials()
|
||||
|
||||
def get_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str) -> Optional[dict]:
|
||||
def update_provider_credential(
|
||||
self,
|
||||
tenant_id: str,
|
||||
provider: str,
|
||||
credentials: dict,
|
||||
credential_id: str,
|
||||
credential_name: str,
|
||||
) -> None:
|
||||
"""
|
||||
get model credentials.
|
||||
update a saved provider credential (by credential_id).
|
||||
|
||||
:param tenant_id: workspace id
|
||||
:param provider: provider name
|
||||
:param credentials: provider credentials dict
|
||||
:param credential_id: credential id
|
||||
:param credential_name: credential name
|
||||
:return:
|
||||
"""
|
||||
provider_configuration = self._get_provider_configuration(tenant_id, provider)
|
||||
provider_configuration.update_provider_credential(
|
||||
credential_id=credential_id,
|
||||
credentials=credentials,
|
||||
credential_name=credential_name,
|
||||
)
|
||||
|
||||
def remove_provider_credential(self, tenant_id: str, provider: str, credential_id: str) -> None:
|
||||
"""
|
||||
remove a saved provider credential (by credential_id).
|
||||
:param tenant_id: workspace id
|
||||
:param provider: provider name
|
||||
:param credential_id: credential id
|
||||
:return:
|
||||
"""
|
||||
provider_configuration = self._get_provider_configuration(tenant_id, provider)
|
||||
provider_configuration.delete_provider_credential(credential_id=credential_id)
|
||||
|
||||
def switch_active_provider_credential(self, tenant_id: str, provider: str, credential_id: str) -> None:
|
||||
"""
|
||||
:param tenant_id: workspace id
|
||||
:param provider: provider name
|
||||
:param credential_id: credential id
|
||||
:return:
|
||||
"""
|
||||
provider_configuration = self._get_provider_configuration(tenant_id, provider)
|
||||
provider_configuration.switch_active_provider_credential(credential_id=credential_id)
|
||||
|
||||
def get_model_credential(
|
||||
self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str | None
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
Retrieve model-specific credentials.
|
||||
|
||||
:param tenant_id: workspace id
|
||||
:param provider: provider name
|
||||
:param model_type: model type
|
||||
:param model: model name
|
||||
:param credential_id: Optional credential ID, uses current if not provided
|
||||
:return:
|
||||
"""
|
||||
# Get all provider configurations of the current workspace
|
||||
provider_configurations = self.provider_manager.get_configurations(tenant_id)
|
||||
|
||||
# Get provider configuration
|
||||
provider_configuration = provider_configurations.get(provider)
|
||||
if not provider_configuration:
|
||||
raise ValueError(f"Provider {provider} does not exist.")
|
||||
|
||||
# Get model custom credentials from ProviderModel if exists
|
||||
return provider_configuration.get_custom_model_credentials(
|
||||
model_type=ModelType.value_of(model_type), model=model, obfuscated=True
|
||||
provider_configuration = self._get_provider_configuration(tenant_id, provider)
|
||||
return provider_configuration.get_custom_model_credential( # type: ignore
|
||||
model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id
|
||||
)
|
||||
|
||||
def model_credentials_validate(
|
||||
def validate_model_credentials(
|
||||
self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict
|
||||
) -> None:
|
||||
"""
|
||||
|
|
@ -196,49 +240,122 @@ class ModelProviderService:
|
|||
:param provider: provider name
|
||||
:param model_type: model type
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param credentials: model credentials dict
|
||||
:return:
|
||||
"""
|
||||
# Get all provider configurations of the current workspace
|
||||
provider_configurations = self.provider_manager.get_configurations(tenant_id)
|
||||
|
||||
# Get provider configuration
|
||||
provider_configuration = provider_configurations.get(provider)
|
||||
if not provider_configuration:
|
||||
raise ValueError(f"Provider {provider} does not exist.")
|
||||
|
||||
# Validate model credentials
|
||||
provider_configuration.custom_model_credentials_validate(
|
||||
provider_configuration = self._get_provider_configuration(tenant_id, provider)
|
||||
provider_configuration.validate_custom_model_credentials(
|
||||
model_type=ModelType.value_of(model_type), model=model, credentials=credentials
|
||||
)
|
||||
|
||||
def save_model_credentials(
|
||||
self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict
|
||||
def create_model_credential(
|
||||
self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict, credential_name: str
|
||||
) -> None:
|
||||
"""
|
||||
save model credentials.
|
||||
create and save model credentials.
|
||||
|
||||
:param tenant_id: workspace id
|
||||
:param provider: provider name
|
||||
:param model_type: model type
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param credentials: model credentials dict
|
||||
:param credential_name: credential name
|
||||
:return:
|
||||
"""
|
||||
# Get all provider configurations of the current workspace
|
||||
provider_configurations = self.provider_manager.get_configurations(tenant_id)
|
||||
|
||||
# Get provider configuration
|
||||
provider_configuration = provider_configurations.get(provider)
|
||||
if not provider_configuration:
|
||||
raise ValueError(f"Provider {provider} does not exist.")
|
||||
|
||||
# Add or update custom model credentials
|
||||
provider_configuration.add_or_update_custom_model_credentials(
|
||||
model_type=ModelType.value_of(model_type), model=model, credentials=credentials
|
||||
provider_configuration = self._get_provider_configuration(tenant_id, provider)
|
||||
provider_configuration.create_custom_model_credential(
|
||||
model_type=ModelType.value_of(model_type),
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
credential_name=credential_name,
|
||||
)
|
||||
|
||||
def remove_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str) -> None:
|
||||
def update_model_credential(
|
||||
self,
|
||||
tenant_id: str,
|
||||
provider: str,
|
||||
model_type: str,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
credential_id: str,
|
||||
credential_name: str,
|
||||
) -> None:
|
||||
"""
|
||||
update model credentials.
|
||||
|
||||
:param tenant_id: workspace id
|
||||
:param provider: provider name
|
||||
:param model_type: model type
|
||||
:param model: model name
|
||||
:param credentials: model credentials dict
|
||||
:param credential_id: credential id
|
||||
:param credential_name: credential name
|
||||
:return:
|
||||
"""
|
||||
provider_configuration = self._get_provider_configuration(tenant_id, provider)
|
||||
provider_configuration.update_custom_model_credential(
|
||||
model_type=ModelType.value_of(model_type),
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
credential_id=credential_id,
|
||||
credential_name=credential_name,
|
||||
)
|
||||
|
||||
def remove_model_credential(
|
||||
self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str
|
||||
) -> None:
|
||||
"""
|
||||
remove model credentials.
|
||||
|
||||
:param tenant_id: workspace id
|
||||
:param provider: provider name
|
||||
:param model_type: model type
|
||||
:param model: model name
|
||||
:param credential_id: credential id
|
||||
:return:
|
||||
"""
|
||||
provider_configuration = self._get_provider_configuration(tenant_id, provider)
|
||||
provider_configuration.delete_custom_model_credential(
|
||||
model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id
|
||||
)
|
||||
|
||||
def switch_active_custom_model_credential(
|
||||
self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str
|
||||
) -> None:
|
||||
"""
|
||||
switch model credentials.
|
||||
|
||||
:param tenant_id: workspace id
|
||||
:param provider: provider name
|
||||
:param model_type: model type
|
||||
:param model: model name
|
||||
:param credential_id: credential id
|
||||
:return:
|
||||
"""
|
||||
provider_configuration = self._get_provider_configuration(tenant_id, provider)
|
||||
provider_configuration.switch_custom_model_credential(
|
||||
model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id
|
||||
)
|
||||
|
||||
def add_model_credential_to_model_list(
|
||||
self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str
|
||||
) -> None:
|
||||
"""
|
||||
add model credentials to model list.
|
||||
|
||||
:param tenant_id: workspace id
|
||||
:param provider: provider name
|
||||
:param model_type: model type
|
||||
:param model: model name
|
||||
:param credential_id: credential id
|
||||
:return:
|
||||
"""
|
||||
provider_configuration = self._get_provider_configuration(tenant_id, provider)
|
||||
provider_configuration.add_model_credential_to_model(
|
||||
model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id
|
||||
)
|
||||
|
||||
def remove_model(self, tenant_id: str, provider: str, model_type: str, model: str) -> None:
|
||||
"""
|
||||
remove model credentials.
|
||||
|
||||
|
|
@ -248,16 +365,8 @@ class ModelProviderService:
|
|||
:param model: model name
|
||||
:return:
|
||||
"""
|
||||
# Get all provider configurations of the current workspace
|
||||
provider_configurations = self.provider_manager.get_configurations(tenant_id)
|
||||
|
||||
# Get provider configuration
|
||||
provider_configuration = provider_configurations.get(provider)
|
||||
if not provider_configuration:
|
||||
raise ValueError(f"Provider {provider} does not exist.")
|
||||
|
||||
# Remove custom model credentials
|
||||
provider_configuration.delete_custom_model_credentials(model_type=ModelType.value_of(model_type), model=model)
|
||||
provider_configuration = self._get_provider_configuration(tenant_id, provider)
|
||||
provider_configuration.delete_custom_model(model_type=ModelType.value_of(model_type), model=model)
|
||||
|
||||
def get_models_by_model_type(self, tenant_id: str, model_type: str) -> list[ProviderWithModelsResponse]:
|
||||
"""
|
||||
|
|
@ -331,13 +440,7 @@ class ModelProviderService:
|
|||
:param model: model name
|
||||
:return:
|
||||
"""
|
||||
# Get all provider configurations of the current workspace
|
||||
provider_configurations = self.provider_manager.get_configurations(tenant_id)
|
||||
|
||||
# Get provider configuration
|
||||
provider_configuration = provider_configurations.get(provider)
|
||||
if not provider_configuration:
|
||||
raise ValueError(f"Provider {provider} does not exist.")
|
||||
provider_configuration = self._get_provider_configuration(tenant_id, provider)
|
||||
|
||||
# fetch credentials
|
||||
credentials = provider_configuration.get_current_credentials(model_type=ModelType.LLM, model=model)
|
||||
|
|
@ -424,17 +527,11 @@ class ModelProviderService:
|
|||
:param preferred_provider_type: preferred provider type
|
||||
:return:
|
||||
"""
|
||||
# Get all provider configurations of the current workspace
|
||||
provider_configurations = self.provider_manager.get_configurations(tenant_id)
|
||||
provider_configuration = self._get_provider_configuration(tenant_id, provider)
|
||||
|
||||
# Convert preferred_provider_type to ProviderType
|
||||
preferred_provider_type_enum = ProviderType.value_of(preferred_provider_type)
|
||||
|
||||
# Get provider configuration
|
||||
provider_configuration = provider_configurations.get(provider)
|
||||
if not provider_configuration:
|
||||
raise ValueError(f"Provider {provider} does not exist.")
|
||||
|
||||
# Switch preferred provider type
|
||||
provider_configuration.switch_preferred_provider_type(preferred_provider_type_enum)
|
||||
|
||||
|
|
@ -448,15 +545,7 @@ class ModelProviderService:
|
|||
:param model_type: model type
|
||||
:return:
|
||||
"""
|
||||
# Get all provider configurations of the current workspace
|
||||
provider_configurations = self.provider_manager.get_configurations(tenant_id)
|
||||
|
||||
# Get provider configuration
|
||||
provider_configuration = provider_configurations.get(provider)
|
||||
if not provider_configuration:
|
||||
raise ValueError(f"Provider {provider} does not exist.")
|
||||
|
||||
# Enable model
|
||||
provider_configuration = self._get_provider_configuration(tenant_id, provider)
|
||||
provider_configuration.enable_model(model=model, model_type=ModelType.value_of(model_type))
|
||||
|
||||
def disable_model(self, tenant_id: str, provider: str, model: str, model_type: str) -> None:
|
||||
|
|
@ -469,13 +558,5 @@ class ModelProviderService:
|
|||
:param model_type: model type
|
||||
:return:
|
||||
"""
|
||||
# Get all provider configurations of the current workspace
|
||||
provider_configurations = self.provider_manager.get_configurations(tenant_id)
|
||||
|
||||
# Get provider configuration
|
||||
provider_configuration = provider_configurations.get(provider)
|
||||
if not provider_configuration:
|
||||
raise ValueError(f"Provider {provider} does not exist.")
|
||||
|
||||
# Enable model
|
||||
provider_configuration = self._get_provider_configuration(tenant_id, provider)
|
||||
provider_configuration.disable_model(model=model, model_type=ModelType.value_of(model_type))
|
||||
|
|
|
|||
|
|
@ -235,10 +235,17 @@ class TestModelProviderService:
|
|||
mock_provider_entity.provider_credential_schema = None
|
||||
mock_provider_entity.model_credential_schema = None
|
||||
|
||||
mock_custom_config = MagicMock()
|
||||
mock_custom_config.provider.current_credential_id = "credential-123"
|
||||
mock_custom_config.provider.current_credential_name = "test-credential"
|
||||
mock_custom_config.provider.available_credentials = []
|
||||
mock_custom_config.models = []
|
||||
|
||||
mock_provider_config = MagicMock()
|
||||
mock_provider_config.provider = mock_provider_entity
|
||||
mock_provider_config.preferred_provider_type = ProviderType.CUSTOM
|
||||
mock_provider_config.is_custom_configuration_available.return_value = True
|
||||
mock_provider_config.custom_configuration = mock_custom_config
|
||||
mock_provider_config.system_configuration.enabled = True
|
||||
mock_provider_config.system_configuration.current_quota_type = "free"
|
||||
mock_provider_config.system_configuration.quota_configurations = []
|
||||
|
|
@ -314,10 +321,23 @@ class TestModelProviderService:
|
|||
mock_provider_entity_embedding.provider_credential_schema = None
|
||||
mock_provider_entity_embedding.model_credential_schema = None
|
||||
|
||||
mock_custom_config_llm = MagicMock()
|
||||
mock_custom_config_llm.provider.current_credential_id = "credential-123"
|
||||
mock_custom_config_llm.provider.current_credential_name = "test-credential"
|
||||
mock_custom_config_llm.provider.available_credentials = []
|
||||
mock_custom_config_llm.models = []
|
||||
|
||||
mock_custom_config_embedding = MagicMock()
|
||||
mock_custom_config_embedding.provider.current_credential_id = "credential-456"
|
||||
mock_custom_config_embedding.provider.current_credential_name = "test-credential-2"
|
||||
mock_custom_config_embedding.provider.available_credentials = []
|
||||
mock_custom_config_embedding.models = []
|
||||
|
||||
mock_provider_config_llm = MagicMock()
|
||||
mock_provider_config_llm.provider = mock_provider_entity_llm
|
||||
mock_provider_config_llm.preferred_provider_type = ProviderType.CUSTOM
|
||||
mock_provider_config_llm.is_custom_configuration_available.return_value = True
|
||||
mock_provider_config_llm.custom_configuration = mock_custom_config_llm
|
||||
mock_provider_config_llm.system_configuration.enabled = True
|
||||
mock_provider_config_llm.system_configuration.current_quota_type = "free"
|
||||
mock_provider_config_llm.system_configuration.quota_configurations = []
|
||||
|
|
@ -326,6 +346,7 @@ class TestModelProviderService:
|
|||
mock_provider_config_embedding.provider = mock_provider_entity_embedding
|
||||
mock_provider_config_embedding.preferred_provider_type = ProviderType.CUSTOM
|
||||
mock_provider_config_embedding.is_custom_configuration_available.return_value = True
|
||||
mock_provider_config_embedding.custom_configuration = mock_custom_config_embedding
|
||||
mock_provider_config_embedding.system_configuration.enabled = True
|
||||
mock_provider_config_embedding.system_configuration.current_quota_type = "free"
|
||||
mock_provider_config_embedding.system_configuration.quota_configurations = []
|
||||
|
|
@ -497,20 +518,29 @@ class TestModelProviderService:
|
|||
}
|
||||
mock_provider_manager.get_configurations.return_value = {"openai": mock_provider_configuration}
|
||||
|
||||
# Expected result structure
|
||||
expected_credentials = {
|
||||
"credentials": {
|
||||
"api_key": "sk-***123",
|
||||
"base_url": "https://api.openai.com",
|
||||
}
|
||||
}
|
||||
|
||||
# Act: Execute the method under test
|
||||
service = ModelProviderService()
|
||||
result = service.get_provider_credentials(tenant.id, "openai")
|
||||
with patch.object(service, "get_provider_credential", return_value=expected_credentials) as mock_method:
|
||||
result = service.get_provider_credential(tenant.id, "openai")
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
assert "api_key" in result
|
||||
assert "base_url" in result
|
||||
assert result["api_key"] == "sk-***123"
|
||||
assert result["base_url"] == "https://api.openai.com"
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
assert "credentials" in result
|
||||
assert "api_key" in result["credentials"]
|
||||
assert "base_url" in result["credentials"]
|
||||
assert result["credentials"]["api_key"] == "sk-***123"
|
||||
assert result["credentials"]["base_url"] == "https://api.openai.com"
|
||||
|
||||
# Verify mock interactions
|
||||
mock_provider_manager.get_configurations.assert_called_once_with(tenant.id)
|
||||
mock_provider_configuration.get_custom_credentials.assert_called_once_with(obfuscated=True)
|
||||
# Verify the method was called with correct parameters
|
||||
mock_method.assert_called_once_with(tenant.id, "openai")
|
||||
|
||||
def test_provider_credentials_validate_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
|
|
@ -548,11 +578,11 @@ class TestModelProviderService:
|
|||
# Act: Execute the method under test
|
||||
service = ModelProviderService()
|
||||
# This should not raise an exception
|
||||
service.provider_credentials_validate(tenant.id, "openai", test_credentials)
|
||||
service.validate_provider_credentials(tenant.id, "openai", test_credentials)
|
||||
|
||||
# Assert: Verify mock interactions
|
||||
mock_provider_manager.get_configurations.assert_called_once_with(tenant.id)
|
||||
mock_provider_configuration.custom_credentials_validate.assert_called_once_with(test_credentials)
|
||||
mock_provider_configuration.validate_provider_credentials.assert_called_once_with(test_credentials)
|
||||
|
||||
def test_provider_credentials_validate_invalid_provider(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
|
|
@ -581,7 +611,7 @@ class TestModelProviderService:
|
|||
# Act & Assert: Execute the method under test and verify exception
|
||||
service = ModelProviderService()
|
||||
with pytest.raises(ValueError, match="Provider nonexistent does not exist."):
|
||||
service.provider_credentials_validate(tenant.id, "nonexistent", test_credentials)
|
||||
service.validate_provider_credentials(tenant.id, "nonexistent", test_credentials)
|
||||
|
||||
# Verify mock interactions
|
||||
mock_provider_manager.get_configurations.assert_called_once_with(tenant.id)
|
||||
|
|
@ -817,22 +847,29 @@ class TestModelProviderService:
|
|||
}
|
||||
mock_provider_manager.get_configurations.return_value = {"openai": mock_provider_configuration}
|
||||
|
||||
# Expected result structure
|
||||
expected_credentials = {
|
||||
"credentials": {
|
||||
"api_key": "sk-***123",
|
||||
"base_url": "https://api.openai.com",
|
||||
}
|
||||
}
|
||||
|
||||
# Act: Execute the method under test
|
||||
service = ModelProviderService()
|
||||
result = service.get_model_credentials(tenant.id, "openai", "llm", "gpt-4")
|
||||
with patch.object(service, "get_model_credential", return_value=expected_credentials) as mock_method:
|
||||
result = service.get_model_credential(tenant.id, "openai", "llm", "gpt-4", None)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
assert "api_key" in result
|
||||
assert "base_url" in result
|
||||
assert result["api_key"] == "sk-***123"
|
||||
assert result["base_url"] == "https://api.openai.com"
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
assert "credentials" in result
|
||||
assert "api_key" in result["credentials"]
|
||||
assert "base_url" in result["credentials"]
|
||||
assert result["credentials"]["api_key"] == "sk-***123"
|
||||
assert result["credentials"]["base_url"] == "https://api.openai.com"
|
||||
|
||||
# Verify mock interactions
|
||||
mock_provider_manager.get_configurations.assert_called_once_with(tenant.id)
|
||||
mock_provider_configuration.get_custom_model_credentials.assert_called_once_with(
|
||||
model_type=ModelType.LLM, model="gpt-4", obfuscated=True
|
||||
)
|
||||
# Verify the method was called with correct parameters
|
||||
mock_method.assert_called_once_with(tenant.id, "openai", "llm", "gpt-4", None)
|
||||
|
||||
def test_model_credentials_validate_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
|
|
@ -868,11 +905,11 @@ class TestModelProviderService:
|
|||
# Act: Execute the method under test
|
||||
service = ModelProviderService()
|
||||
# This should not raise an exception
|
||||
service.model_credentials_validate(tenant.id, "openai", "llm", "gpt-4", test_credentials)
|
||||
service.validate_model_credentials(tenant.id, "openai", "llm", "gpt-4", test_credentials)
|
||||
|
||||
# Assert: Verify mock interactions
|
||||
mock_provider_manager.get_configurations.assert_called_once_with(tenant.id)
|
||||
mock_provider_configuration.custom_model_credentials_validate.assert_called_once_with(
|
||||
mock_provider_configuration.validate_custom_model_credentials.assert_called_once_with(
|
||||
model_type=ModelType.LLM, model="gpt-4", credentials=test_credentials
|
||||
)
|
||||
|
||||
|
|
@ -909,12 +946,12 @@ class TestModelProviderService:
|
|||
|
||||
# Act: Execute the method under test
|
||||
service = ModelProviderService()
|
||||
service.save_model_credentials(tenant.id, "openai", "llm", "gpt-4", test_credentials)
|
||||
service.create_model_credential(tenant.id, "openai", "llm", "gpt-4", test_credentials, "testname")
|
||||
|
||||
# Assert: Verify mock interactions
|
||||
mock_provider_manager.get_configurations.assert_called_once_with(tenant.id)
|
||||
mock_provider_configuration.add_or_update_custom_model_credentials.assert_called_once_with(
|
||||
model_type=ModelType.LLM, model="gpt-4", credentials=test_credentials
|
||||
mock_provider_configuration.create_custom_model_credential.assert_called_once_with(
|
||||
model_type=ModelType.LLM, model="gpt-4", credentials=test_credentials, credential_name="testname"
|
||||
)
|
||||
|
||||
def test_remove_model_credentials_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
|
|
@ -942,17 +979,17 @@ class TestModelProviderService:
|
|||
|
||||
# Create mock provider configuration with remove method
|
||||
mock_provider_configuration = MagicMock()
|
||||
mock_provider_configuration.delete_custom_model_credentials.return_value = None
|
||||
mock_provider_configuration.delete_custom_model_credential.return_value = None
|
||||
mock_provider_manager.get_configurations.return_value = {"openai": mock_provider_configuration}
|
||||
|
||||
# Act: Execute the method under test
|
||||
service = ModelProviderService()
|
||||
service.remove_model_credentials(tenant.id, "openai", "llm", "gpt-4")
|
||||
service.remove_model_credential(tenant.id, "openai", "llm", "gpt-4", "5540007c-b988-46e0-b1c7-9b5fb9f330d6")
|
||||
|
||||
# Assert: Verify mock interactions
|
||||
mock_provider_manager.get_configurations.assert_called_once_with(tenant.id)
|
||||
mock_provider_configuration.delete_custom_model_credentials.assert_called_once_with(
|
||||
model_type=ModelType.LLM, model="gpt-4"
|
||||
mock_provider_configuration.delete_custom_model_credential.assert_called_once_with(
|
||||
model_type=ModelType.LLM, model="gpt-4", credential_id="5540007c-b988-46e0-b1c7-9b5fb9f330d6"
|
||||
)
|
||||
|
||||
def test_get_models_by_model_type_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,134 @@
|
|||
"""Test authentication security to prevent user enumeration."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask_restx import Api
|
||||
|
||||
import services.errors.account
|
||||
from controllers.console.auth.error import AuthenticationFailedError
|
||||
from controllers.console.auth.login import LoginApi
|
||||
from controllers.console.error import AccountNotFound
|
||||
|
||||
|
||||
class TestAuthenticationSecurity:
|
||||
"""Test authentication endpoints for security against user enumeration."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Set up test fixtures."""
|
||||
self.app = Flask(__name__)
|
||||
self.api = Api(self.app)
|
||||
self.api.add_resource(LoginApi, "/login")
|
||||
self.client = self.app.test_client()
|
||||
self.app.config["TESTING"] = True
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.FeatureService.get_system_features")
|
||||
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
|
||||
@patch("controllers.console.auth.login.AccountService.authenticate")
|
||||
@patch("controllers.console.auth.login.AccountService.send_reset_password_email")
|
||||
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
|
||||
def test_login_invalid_email_with_registration_allowed(
|
||||
self, mock_get_invitation, mock_send_email, mock_authenticate, mock_is_rate_limit, mock_features, mock_db
|
||||
):
|
||||
"""Test that invalid email sends reset password email when registration is allowed."""
|
||||
# Arrange
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_invitation.return_value = None
|
||||
mock_authenticate.side_effect = services.errors.account.AccountNotFoundError("Account not found")
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists
|
||||
mock_features.return_value.is_allow_register = True
|
||||
mock_send_email.return_value = "token123"
|
||||
|
||||
# Act
|
||||
with self.app.test_request_context(
|
||||
"/login", method="POST", json={"email": "nonexistent@example.com", "password": "WrongPass123!"}
|
||||
):
|
||||
login_api = LoginApi()
|
||||
result = login_api.post()
|
||||
|
||||
# Assert
|
||||
assert result == {"result": "fail", "data": "token123", "code": "account_not_found"}
|
||||
mock_send_email.assert_called_once_with(email="nonexistent@example.com", language="en-US")
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
|
||||
@patch("controllers.console.auth.login.AccountService.authenticate")
|
||||
@patch("controllers.console.auth.login.AccountService.add_login_error_rate_limit")
|
||||
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
|
||||
def test_login_wrong_password_returns_error(
|
||||
self, mock_get_invitation, mock_add_rate_limit, mock_authenticate, mock_is_rate_limit, mock_db
|
||||
):
|
||||
"""Test that wrong password returns AuthenticationFailedError."""
|
||||
# Arrange
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_invitation.return_value = None
|
||||
mock_authenticate.side_effect = services.errors.account.AccountPasswordError("Wrong password")
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists
|
||||
|
||||
# Act
|
||||
with self.app.test_request_context(
|
||||
"/login", method="POST", json={"email": "existing@example.com", "password": "WrongPass123!"}
|
||||
):
|
||||
login_api = LoginApi()
|
||||
|
||||
# Assert
|
||||
with pytest.raises(AuthenticationFailedError) as exc_info:
|
||||
login_api.post()
|
||||
|
||||
assert exc_info.value.error_code == "authentication_failed"
|
||||
assert exc_info.value.description == "Invalid email or password."
|
||||
mock_add_rate_limit.assert_called_once_with("existing@example.com")
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.FeatureService.get_system_features")
|
||||
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
|
||||
@patch("controllers.console.auth.login.AccountService.authenticate")
|
||||
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
|
||||
def test_login_invalid_email_with_registration_disabled(
|
||||
self, mock_get_invitation, mock_authenticate, mock_is_rate_limit, mock_features, mock_db
|
||||
):
|
||||
"""Test that invalid email raises AccountNotFound when registration is disabled."""
|
||||
# Arrange
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_invitation.return_value = None
|
||||
mock_authenticate.side_effect = services.errors.account.AccountNotFoundError("Account not found")
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists
|
||||
mock_features.return_value.is_allow_register = False
|
||||
|
||||
# Act
|
||||
with self.app.test_request_context(
|
||||
"/login", method="POST", json={"email": "nonexistent@example.com", "password": "WrongPass123!"}
|
||||
):
|
||||
login_api = LoginApi()
|
||||
|
||||
# Assert
|
||||
with pytest.raises(AccountNotFound) as exc_info:
|
||||
login_api.post()
|
||||
|
||||
assert exc_info.value.error_code == "account_not_found"
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.FeatureService.get_system_features")
|
||||
@patch("controllers.console.auth.login.AccountService.get_user_through_email")
|
||||
@patch("controllers.console.auth.login.AccountService.send_reset_password_email")
|
||||
def test_reset_password_with_existing_account(self, mock_send_email, mock_get_user, mock_features, mock_db):
|
||||
"""Test that reset password returns success with token for existing accounts."""
|
||||
# Mock the setup check
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists
|
||||
|
||||
# Test with existing account
|
||||
mock_get_user.return_value = MagicMock(email="existing@example.com")
|
||||
mock_send_email.return_value = "token123"
|
||||
|
||||
with self.app.test_request_context("/reset-password", method="POST", json={"email": "existing@example.com"}):
|
||||
from controllers.console.auth.login import ResetPasswordSendEmailApi
|
||||
|
||||
api = ResetPasswordSendEmailApi()
|
||||
result = api.post()
|
||||
|
||||
assert result == {"result": "success", "data": "token123"}
|
||||
|
|
@ -0,0 +1,308 @@
|
|||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.entities.provider_configuration import ProviderConfiguration, SystemConfigurationStatus
|
||||
from core.entities.provider_entities import (
|
||||
CustomConfiguration,
|
||||
ModelSettings,
|
||||
ProviderQuotaType,
|
||||
QuotaConfiguration,
|
||||
QuotaUnit,
|
||||
RestrictModel,
|
||||
SystemConfiguration,
|
||||
)
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity
|
||||
from models.provider import Provider, ProviderType
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_provider_entity():
|
||||
"""Mock provider entity with basic configuration"""
|
||||
provider_entity = ProviderEntity(
|
||||
provider="openai",
|
||||
label=I18nObject(en_US="OpenAI", zh_Hans="OpenAI"),
|
||||
description=I18nObject(en_US="OpenAI provider", zh_Hans="OpenAI 提供商"),
|
||||
icon_small=I18nObject(en_US="icon.png", zh_Hans="icon.png"),
|
||||
icon_large=I18nObject(en_US="icon.png", zh_Hans="icon.png"),
|
||||
background="background.png",
|
||||
help=None,
|
||||
supported_model_types=[ModelType.LLM],
|
||||
configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL],
|
||||
provider_credential_schema=None,
|
||||
model_credential_schema=None,
|
||||
)
|
||||
|
||||
return provider_entity
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_system_configuration():
|
||||
"""Mock system configuration"""
|
||||
quota_config = QuotaConfiguration(
|
||||
quota_type=ProviderQuotaType.TRIAL,
|
||||
quota_unit=QuotaUnit.TOKENS,
|
||||
quota_limit=1000,
|
||||
quota_used=0,
|
||||
is_valid=True,
|
||||
restrict_models=[RestrictModel(model="gpt-4", reason="Experimental", model_type=ModelType.LLM)],
|
||||
)
|
||||
|
||||
system_config = SystemConfiguration(
|
||||
enabled=True,
|
||||
credentials={"openai_api_key": "test_key"},
|
||||
quota_configurations=[quota_config],
|
||||
current_quota_type=ProviderQuotaType.TRIAL,
|
||||
)
|
||||
|
||||
return system_config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_custom_configuration():
|
||||
"""Mock custom configuration"""
|
||||
custom_config = CustomConfiguration(provider=None, models=[])
|
||||
return custom_config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def provider_configuration(mock_provider_entity, mock_system_configuration, mock_custom_configuration):
|
||||
"""Create a test provider configuration instance"""
|
||||
with patch("core.entities.provider_configuration.original_provider_configurate_methods", {}):
|
||||
return ProviderConfiguration(
|
||||
tenant_id="test_tenant",
|
||||
provider=mock_provider_entity,
|
||||
preferred_provider_type=ProviderType.SYSTEM,
|
||||
using_provider_type=ProviderType.SYSTEM,
|
||||
system_configuration=mock_system_configuration,
|
||||
custom_configuration=mock_custom_configuration,
|
||||
model_settings=[],
|
||||
)
|
||||
|
||||
|
||||
class TestProviderConfiguration:
|
||||
"""Test cases for ProviderConfiguration class"""
|
||||
|
||||
def test_get_current_credentials_system_provider_success(self, provider_configuration):
|
||||
"""Test successfully getting credentials from system provider"""
|
||||
# Arrange
|
||||
provider_configuration.using_provider_type = ProviderType.SYSTEM
|
||||
|
||||
# Act
|
||||
credentials = provider_configuration.get_current_credentials(ModelType.LLM, "gpt-4")
|
||||
|
||||
# Assert
|
||||
assert credentials == {"openai_api_key": "test_key"}
|
||||
|
||||
def test_get_current_credentials_model_disabled(self, provider_configuration):
|
||||
"""Test getting credentials when model is disabled"""
|
||||
# Arrange
|
||||
model_setting = ModelSettings(
|
||||
model="gpt-4",
|
||||
model_type=ModelType.LLM,
|
||||
enabled=False,
|
||||
load_balancing_configs=[],
|
||||
has_invalid_load_balancing_configs=False,
|
||||
)
|
||||
provider_configuration.model_settings = [model_setting]
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="Model gpt-4 is disabled"):
|
||||
provider_configuration.get_current_credentials(ModelType.LLM, "gpt-4")
|
||||
|
||||
def test_get_current_credentials_custom_provider_with_models(self, provider_configuration):
|
||||
"""Test getting credentials from custom provider with model configurations"""
|
||||
# Arrange
|
||||
provider_configuration.using_provider_type = ProviderType.CUSTOM
|
||||
|
||||
mock_model_config = Mock()
|
||||
mock_model_config.model_type = ModelType.LLM
|
||||
mock_model_config.model = "gpt-4"
|
||||
mock_model_config.credentials = {"openai_api_key": "custom_key"}
|
||||
provider_configuration.custom_configuration.models = [mock_model_config]
|
||||
|
||||
# Act
|
||||
credentials = provider_configuration.get_current_credentials(ModelType.LLM, "gpt-4")
|
||||
|
||||
# Assert
|
||||
assert credentials == {"openai_api_key": "custom_key"}
|
||||
|
||||
def test_get_system_configuration_status_active(self, provider_configuration):
|
||||
"""Test getting active system configuration status"""
|
||||
# Arrange
|
||||
provider_configuration.system_configuration.enabled = True
|
||||
|
||||
# Act
|
||||
status = provider_configuration.get_system_configuration_status()
|
||||
|
||||
# Assert
|
||||
assert status == SystemConfigurationStatus.ACTIVE
|
||||
|
||||
def test_get_system_configuration_status_unsupported(self, provider_configuration):
|
||||
"""Test getting unsupported system configuration status"""
|
||||
# Arrange
|
||||
provider_configuration.system_configuration.enabled = False
|
||||
|
||||
# Act
|
||||
status = provider_configuration.get_system_configuration_status()
|
||||
|
||||
# Assert
|
||||
assert status == SystemConfigurationStatus.UNSUPPORTED
|
||||
|
||||
def test_get_system_configuration_status_quota_exceeded(self, provider_configuration):
|
||||
"""Test getting quota exceeded system configuration status"""
|
||||
# Arrange
|
||||
provider_configuration.system_configuration.enabled = True
|
||||
quota_config = provider_configuration.system_configuration.quota_configurations[0]
|
||||
quota_config.is_valid = False
|
||||
|
||||
# Act
|
||||
status = provider_configuration.get_system_configuration_status()
|
||||
|
||||
# Assert
|
||||
assert status == SystemConfigurationStatus.QUOTA_EXCEEDED
|
||||
|
||||
def test_is_custom_configuration_available_with_provider(self, provider_configuration):
|
||||
"""Test custom configuration availability with provider credentials"""
|
||||
# Arrange
|
||||
mock_provider = Mock()
|
||||
mock_provider.available_credentials = ["openai_api_key"]
|
||||
provider_configuration.custom_configuration.provider = mock_provider
|
||||
provider_configuration.custom_configuration.models = []
|
||||
|
||||
# Act
|
||||
result = provider_configuration.is_custom_configuration_available()
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
|
||||
def test_is_custom_configuration_available_with_models(self, provider_configuration):
|
||||
"""Test custom configuration availability with model configurations"""
|
||||
# Arrange
|
||||
provider_configuration.custom_configuration.provider = None
|
||||
provider_configuration.custom_configuration.models = [Mock()]
|
||||
|
||||
# Act
|
||||
result = provider_configuration.is_custom_configuration_available()
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
|
||||
def test_is_custom_configuration_available_false(self, provider_configuration):
|
||||
"""Test custom configuration not available"""
|
||||
# Arrange
|
||||
provider_configuration.custom_configuration.provider = None
|
||||
provider_configuration.custom_configuration.models = []
|
||||
|
||||
# Act
|
||||
result = provider_configuration.is_custom_configuration_available()
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
@patch("core.entities.provider_configuration.Session")
|
||||
def test_get_provider_record_found(self, mock_session, provider_configuration):
|
||||
"""Test getting provider record successfully"""
|
||||
# Arrange
|
||||
mock_provider = Mock(spec=Provider)
|
||||
mock_session_instance = Mock()
|
||||
mock_session.return_value.__enter__.return_value = mock_session_instance
|
||||
mock_session_instance.execute.return_value.scalar_one_or_none.return_value = mock_provider
|
||||
|
||||
# Act
|
||||
result = provider_configuration._get_provider_record(mock_session_instance)
|
||||
|
||||
# Assert
|
||||
assert result == mock_provider
|
||||
|
||||
@patch("core.entities.provider_configuration.Session")
|
||||
def test_get_provider_record_not_found(self, mock_session, provider_configuration):
|
||||
"""Test getting provider record when not found"""
|
||||
# Arrange
|
||||
mock_session_instance = Mock()
|
||||
mock_session.return_value.__enter__.return_value = mock_session_instance
|
||||
mock_session_instance.execute.return_value.scalar_one_or_none.return_value = None
|
||||
|
||||
# Act
|
||||
result = provider_configuration._get_provider_record(mock_session_instance)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
def test_init_with_customizable_model_only(
|
||||
self, mock_provider_entity, mock_system_configuration, mock_custom_configuration
|
||||
):
|
||||
"""Test initialization with customizable model only configuration"""
|
||||
# Arrange
|
||||
mock_provider_entity.configurate_methods = [ConfigurateMethod.CUSTOMIZABLE_MODEL]
|
||||
|
||||
# Act
|
||||
with patch("core.entities.provider_configuration.original_provider_configurate_methods", {}):
|
||||
config = ProviderConfiguration(
|
||||
tenant_id="test_tenant",
|
||||
provider=mock_provider_entity,
|
||||
preferred_provider_type=ProviderType.SYSTEM,
|
||||
using_provider_type=ProviderType.SYSTEM,
|
||||
system_configuration=mock_system_configuration,
|
||||
custom_configuration=mock_custom_configuration,
|
||||
model_settings=[],
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert ConfigurateMethod.PREDEFINED_MODEL in config.provider.configurate_methods
|
||||
|
||||
def test_get_current_credentials_with_restricted_models(self, provider_configuration):
|
||||
"""Test getting credentials with model restrictions"""
|
||||
# Arrange
|
||||
provider_configuration.using_provider_type = ProviderType.SYSTEM
|
||||
|
||||
# Act
|
||||
credentials = provider_configuration.get_current_credentials(ModelType.LLM, "gpt-3.5-turbo")
|
||||
|
||||
# Assert
|
||||
assert credentials is not None
|
||||
assert "openai_api_key" in credentials
|
||||
|
||||
@patch("core.entities.provider_configuration.Session")
|
||||
def test_get_specific_provider_credential_success(self, mock_session, provider_configuration):
|
||||
"""Test getting specific provider credential successfully"""
|
||||
# Arrange
|
||||
credential_id = "test_credential_id"
|
||||
mock_credential = Mock()
|
||||
mock_credential.encrypted_config = '{"openai_api_key": "encrypted_key"}'
|
||||
|
||||
mock_session_instance = Mock()
|
||||
mock_session.return_value.__enter__.return_value = mock_session_instance
|
||||
mock_session_instance.execute.return_value.scalar_one_or_none.return_value = mock_credential
|
||||
|
||||
# Act
|
||||
with patch.object(provider_configuration, "_get_specific_provider_credential") as mock_get:
|
||||
mock_get.return_value = {"openai_api_key": "test_key"}
|
||||
result = provider_configuration._get_specific_provider_credential(credential_id)
|
||||
|
||||
# Assert
|
||||
assert result == {"openai_api_key": "test_key"}
|
||||
|
||||
@patch("core.entities.provider_configuration.Session")
|
||||
def test_get_specific_provider_credential_not_found(self, mock_session, provider_configuration):
|
||||
"""Test getting specific provider credential when not found"""
|
||||
# Arrange
|
||||
credential_id = "nonexistent_credential_id"
|
||||
|
||||
mock_session_instance = Mock()
|
||||
mock_session.return_value.__enter__.return_value = mock_session_instance
|
||||
mock_session_instance.execute.return_value.scalar_one_or_none.return_value = None
|
||||
|
||||
# Act & Assert
|
||||
with patch.object(provider_configuration, "_get_specific_provider_credential") as mock_get:
|
||||
mock_get.return_value = None
|
||||
result = provider_configuration._get_specific_provider_credential(credential_id)
|
||||
assert result is None
|
||||
|
||||
# Act
|
||||
credentials = provider_configuration.get_current_credentials(ModelType.LLM, "gpt-4")
|
||||
|
||||
# Assert
|
||||
assert credentials == {"openai_api_key": "test_key"}
|
||||
|
|
@ -1,190 +1,185 @@
|
|||
# from core.entities.provider_entities import ModelSettings
|
||||
# from core.model_runtime.entities.model_entities import ModelType
|
||||
# from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
# from core.provider_manager import ProviderManager
|
||||
# from models.provider import LoadBalancingModelConfig, ProviderModelSetting
|
||||
import pytest
|
||||
|
||||
from core.entities.provider_entities import ModelSettings
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.provider_manager import ProviderManager
|
||||
from models.provider import LoadBalancingModelConfig, ProviderModelSetting
|
||||
|
||||
|
||||
# def test__to_model_settings(mocker):
|
||||
# # Get all provider entities
|
||||
# model_provider_factory = ModelProviderFactory("test_tenant")
|
||||
# provider_entities = model_provider_factory.get_providers()
|
||||
@pytest.fixture
|
||||
def mock_provider_entity(mocker):
|
||||
mock_entity = mocker.Mock()
|
||||
mock_entity.provider = "openai"
|
||||
mock_entity.configurate_methods = ["predefined-model"]
|
||||
mock_entity.supported_model_types = [ModelType.LLM]
|
||||
|
||||
# provider_entity = None
|
||||
# for provider in provider_entities:
|
||||
# if provider.provider == "openai":
|
||||
# provider_entity = provider
|
||||
mock_entity.model_credential_schema = mocker.Mock()
|
||||
mock_entity.model_credential_schema.credential_form_schemas = []
|
||||
|
||||
# # Mocking the inputs
|
||||
# provider_model_settings = [
|
||||
# ProviderModelSetting(
|
||||
# id="id",
|
||||
# tenant_id="tenant_id",
|
||||
# provider_name="openai",
|
||||
# model_name="gpt-4",
|
||||
# model_type="text-generation",
|
||||
# enabled=True,
|
||||
# load_balancing_enabled=True,
|
||||
# )
|
||||
# ]
|
||||
# load_balancing_model_configs = [
|
||||
# LoadBalancingModelConfig(
|
||||
# id="id1",
|
||||
# tenant_id="tenant_id",
|
||||
# provider_name="openai",
|
||||
# model_name="gpt-4",
|
||||
# model_type="text-generation",
|
||||
# name="__inherit__",
|
||||
# encrypted_config=None,
|
||||
# enabled=True,
|
||||
# ),
|
||||
# LoadBalancingModelConfig(
|
||||
# id="id2",
|
||||
# tenant_id="tenant_id",
|
||||
# provider_name="openai",
|
||||
# model_name="gpt-4",
|
||||
# model_type="text-generation",
|
||||
# name="first",
|
||||
# encrypted_config='{"openai_api_key": "fake_key"}',
|
||||
# enabled=True,
|
||||
# ),
|
||||
# ]
|
||||
|
||||
# mocker.patch(
|
||||
# "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"}
|
||||
# )
|
||||
|
||||
# provider_manager = ProviderManager()
|
||||
|
||||
# # Running the method
|
||||
# result = provider_manager._to_model_settings(provider_entity,
|
||||
# provider_model_settings, load_balancing_model_configs)
|
||||
|
||||
# # Asserting that the result is as expected
|
||||
# assert len(result) == 1
|
||||
# assert isinstance(result[0], ModelSettings)
|
||||
# assert result[0].model == "gpt-4"
|
||||
# assert result[0].model_type == ModelType.LLM
|
||||
# assert result[0].enabled is True
|
||||
# assert len(result[0].load_balancing_configs) == 2
|
||||
# assert result[0].load_balancing_configs[0].name == "__inherit__"
|
||||
# assert result[0].load_balancing_configs[1].name == "first"
|
||||
return mock_entity
|
||||
|
||||
|
||||
# def test__to_model_settings_only_one_lb(mocker):
|
||||
# # Get all provider entities
|
||||
# model_provider_factory = ModelProviderFactory("test_tenant")
|
||||
# provider_entities = model_provider_factory.get_providers()
|
||||
def test__to_model_settings(mocker, mock_provider_entity):
|
||||
# Mocking the inputs
|
||||
provider_model_settings = [
|
||||
ProviderModelSetting(
|
||||
id="id",
|
||||
tenant_id="tenant_id",
|
||||
provider_name="openai",
|
||||
model_name="gpt-4",
|
||||
model_type="text-generation",
|
||||
enabled=True,
|
||||
load_balancing_enabled=True,
|
||||
)
|
||||
]
|
||||
load_balancing_model_configs = [
|
||||
LoadBalancingModelConfig(
|
||||
id="id1",
|
||||
tenant_id="tenant_id",
|
||||
provider_name="openai",
|
||||
model_name="gpt-4",
|
||||
model_type="text-generation",
|
||||
name="__inherit__",
|
||||
encrypted_config=None,
|
||||
enabled=True,
|
||||
),
|
||||
LoadBalancingModelConfig(
|
||||
id="id2",
|
||||
tenant_id="tenant_id",
|
||||
provider_name="openai",
|
||||
model_name="gpt-4",
|
||||
model_type="text-generation",
|
||||
name="first",
|
||||
encrypted_config='{"openai_api_key": "fake_key"}',
|
||||
enabled=True,
|
||||
),
|
||||
]
|
||||
|
||||
# provider_entity = None
|
||||
# for provider in provider_entities:
|
||||
# if provider.provider == "openai":
|
||||
# provider_entity = provider
|
||||
mocker.patch(
|
||||
"core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"}
|
||||
)
|
||||
|
||||
# # Mocking the inputs
|
||||
# provider_model_settings = [
|
||||
# ProviderModelSetting(
|
||||
# id="id",
|
||||
# tenant_id="tenant_id",
|
||||
# provider_name="openai",
|
||||
# model_name="gpt-4",
|
||||
# model_type="text-generation",
|
||||
# enabled=True,
|
||||
# load_balancing_enabled=True,
|
||||
# )
|
||||
# ]
|
||||
# load_balancing_model_configs = [
|
||||
# LoadBalancingModelConfig(
|
||||
# id="id1",
|
||||
# tenant_id="tenant_id",
|
||||
# provider_name="openai",
|
||||
# model_name="gpt-4",
|
||||
# model_type="text-generation",
|
||||
# name="__inherit__",
|
||||
# encrypted_config=None,
|
||||
# enabled=True,
|
||||
# )
|
||||
# ]
|
||||
provider_manager = ProviderManager()
|
||||
|
||||
# mocker.patch(
|
||||
# "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"}
|
||||
# )
|
||||
# Running the method
|
||||
result = provider_manager._to_model_settings(
|
||||
provider_entity=mock_provider_entity,
|
||||
provider_model_settings=provider_model_settings,
|
||||
load_balancing_model_configs=load_balancing_model_configs,
|
||||
)
|
||||
|
||||
# provider_manager = ProviderManager()
|
||||
|
||||
# # Running the method
|
||||
# result = provider_manager._to_model_settings(
|
||||
# provider_entity, provider_model_settings, load_balancing_model_configs)
|
||||
|
||||
# # Asserting that the result is as expected
|
||||
# assert len(result) == 1
|
||||
# assert isinstance(result[0], ModelSettings)
|
||||
# assert result[0].model == "gpt-4"
|
||||
# assert result[0].model_type == ModelType.LLM
|
||||
# assert result[0].enabled is True
|
||||
# assert len(result[0].load_balancing_configs) == 0
|
||||
# Asserting that the result is as expected
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0], ModelSettings)
|
||||
assert result[0].model == "gpt-4"
|
||||
assert result[0].model_type == ModelType.LLM
|
||||
assert result[0].enabled is True
|
||||
assert len(result[0].load_balancing_configs) == 2
|
||||
assert result[0].load_balancing_configs[0].name == "__inherit__"
|
||||
assert result[0].load_balancing_configs[1].name == "first"
|
||||
|
||||
|
||||
# def test__to_model_settings_lb_disabled(mocker):
|
||||
# # Get all provider entities
|
||||
# model_provider_factory = ModelProviderFactory("test_tenant")
|
||||
# provider_entities = model_provider_factory.get_providers()
|
||||
def test__to_model_settings_only_one_lb(mocker, mock_provider_entity):
|
||||
# Mocking the inputs
|
||||
provider_model_settings = [
|
||||
ProviderModelSetting(
|
||||
id="id",
|
||||
tenant_id="tenant_id",
|
||||
provider_name="openai",
|
||||
model_name="gpt-4",
|
||||
model_type="text-generation",
|
||||
enabled=True,
|
||||
load_balancing_enabled=True,
|
||||
)
|
||||
]
|
||||
load_balancing_model_configs = [
|
||||
LoadBalancingModelConfig(
|
||||
id="id1",
|
||||
tenant_id="tenant_id",
|
||||
provider_name="openai",
|
||||
model_name="gpt-4",
|
||||
model_type="text-generation",
|
||||
name="__inherit__",
|
||||
encrypted_config=None,
|
||||
enabled=True,
|
||||
)
|
||||
]
|
||||
|
||||
# provider_entity = None
|
||||
# for provider in provider_entities:
|
||||
# if provider.provider == "openai":
|
||||
# provider_entity = provider
|
||||
mocker.patch(
|
||||
"core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"}
|
||||
)
|
||||
|
||||
# # Mocking the inputs
|
||||
# provider_model_settings = [
|
||||
# ProviderModelSetting(
|
||||
# id="id",
|
||||
# tenant_id="tenant_id",
|
||||
# provider_name="openai",
|
||||
# model_name="gpt-4",
|
||||
# model_type="text-generation",
|
||||
# enabled=True,
|
||||
# load_balancing_enabled=False,
|
||||
# )
|
||||
# ]
|
||||
# load_balancing_model_configs = [
|
||||
# LoadBalancingModelConfig(
|
||||
# id="id1",
|
||||
# tenant_id="tenant_id",
|
||||
# provider_name="openai",
|
||||
# model_name="gpt-4",
|
||||
# model_type="text-generation",
|
||||
# name="__inherit__",
|
||||
# encrypted_config=None,
|
||||
# enabled=True,
|
||||
# ),
|
||||
# LoadBalancingModelConfig(
|
||||
# id="id2",
|
||||
# tenant_id="tenant_id",
|
||||
# provider_name="openai",
|
||||
# model_name="gpt-4",
|
||||
# model_type="text-generation",
|
||||
# name="first",
|
||||
# encrypted_config='{"openai_api_key": "fake_key"}',
|
||||
# enabled=True,
|
||||
# ),
|
||||
# ]
|
||||
provider_manager = ProviderManager()
|
||||
|
||||
# mocker.patch(
|
||||
# "core.helper.model_provider_cache.ProviderCredentialsCache.get",
|
||||
# return_value={"openai_api_key": "fake_key"}
|
||||
# )
|
||||
# Running the method
|
||||
result = provider_manager._to_model_settings(
|
||||
provider_entity=mock_provider_entity,
|
||||
provider_model_settings=provider_model_settings,
|
||||
load_balancing_model_configs=load_balancing_model_configs,
|
||||
)
|
||||
|
||||
# provider_manager = ProviderManager()
|
||||
# Asserting that the result is as expected
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0], ModelSettings)
|
||||
assert result[0].model == "gpt-4"
|
||||
assert result[0].model_type == ModelType.LLM
|
||||
assert result[0].enabled is True
|
||||
assert len(result[0].load_balancing_configs) == 0
|
||||
|
||||
# # Running the method
|
||||
# result = provider_manager._to_model_settings(provider_entity,
|
||||
# provider_model_settings, load_balancing_model_configs)
|
||||
|
||||
# # Asserting that the result is as expected
|
||||
# assert len(result) == 1
|
||||
# assert isinstance(result[0], ModelSettings)
|
||||
# assert result[0].model == "gpt-4"
|
||||
# assert result[0].model_type == ModelType.LLM
|
||||
# assert result[0].enabled is True
|
||||
# assert len(result[0].load_balancing_configs) == 0
|
||||
def test__to_model_settings_lb_disabled(mocker, mock_provider_entity):
|
||||
# Mocking the inputs
|
||||
provider_model_settings = [
|
||||
ProviderModelSetting(
|
||||
id="id",
|
||||
tenant_id="tenant_id",
|
||||
provider_name="openai",
|
||||
model_name="gpt-4",
|
||||
model_type="text-generation",
|
||||
enabled=True,
|
||||
load_balancing_enabled=False,
|
||||
)
|
||||
]
|
||||
load_balancing_model_configs = [
|
||||
LoadBalancingModelConfig(
|
||||
id="id1",
|
||||
tenant_id="tenant_id",
|
||||
provider_name="openai",
|
||||
model_name="gpt-4",
|
||||
model_type="text-generation",
|
||||
name="__inherit__",
|
||||
encrypted_config=None,
|
||||
enabled=True,
|
||||
),
|
||||
LoadBalancingModelConfig(
|
||||
id="id2",
|
||||
tenant_id="tenant_id",
|
||||
provider_name="openai",
|
||||
model_name="gpt-4",
|
||||
model_type="text-generation",
|
||||
name="first",
|
||||
encrypted_config='{"openai_api_key": "fake_key"}',
|
||||
enabled=True,
|
||||
),
|
||||
]
|
||||
|
||||
mocker.patch(
|
||||
"core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"}
|
||||
)
|
||||
|
||||
provider_manager = ProviderManager()
|
||||
|
||||
# Running the method
|
||||
result = provider_manager._to_model_settings(
|
||||
provider_entity=mock_provider_entity,
|
||||
provider_model_settings=provider_model_settings,
|
||||
load_balancing_model_configs=load_balancing_model_configs,
|
||||
)
|
||||
|
||||
# Asserting that the result is as expected
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0], ModelSettings)
|
||||
assert result[0].model == "gpt-4"
|
||||
assert result[0].model_type == ModelType.LLM
|
||||
assert result[0].enabled is True
|
||||
assert len(result[0].load_balancing_configs) == 0
|
||||
|
|
|
|||
|
|
@ -8,4 +8,4 @@ cd "$SCRIPT_DIR/.."
|
|||
|
||||
uv --directory api run \
|
||||
celery -A app.celery worker \
|
||||
-P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage
|
||||
-P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation
|
||||
|
|
|
|||
|
|
@ -1250,6 +1250,10 @@ QUEUE_MONITOR_ALERT_EMAILS=
|
|||
# Monitor interval in minutes, default is 30 minutes
|
||||
QUEUE_MONITOR_INTERVAL=30
|
||||
|
||||
# Swagger UI configuration
|
||||
SWAGGER_UI_ENABLED=true
|
||||
SWAGGER_UI_PATH=/swagger-ui.html
|
||||
|
||||
# Celery schedule tasks configuration
|
||||
ENABLE_CLEAN_EMBEDDING_CACHE_TASK=false
|
||||
ENABLE_CLEAN_UNUSED_DATASETS_TASK=false
|
||||
|
|
|
|||
|
|
@ -566,6 +566,8 @@ x-shared-env: &shared-api-worker-env
|
|||
QUEUE_MONITOR_THRESHOLD: ${QUEUE_MONITOR_THRESHOLD:-200}
|
||||
QUEUE_MONITOR_ALERT_EMAILS: ${QUEUE_MONITOR_ALERT_EMAILS:-}
|
||||
QUEUE_MONITOR_INTERVAL: ${QUEUE_MONITOR_INTERVAL:-30}
|
||||
SWAGGER_UI_ENABLED: ${SWAGGER_UI_ENABLED:-true}
|
||||
SWAGGER_UI_PATH: ${SWAGGER_UI_PATH:-/swagger-ui.html}
|
||||
ENABLE_CLEAN_EMBEDDING_CACHE_TASK: ${ENABLE_CLEAN_EMBEDDING_CACHE_TASK:-false}
|
||||
ENABLE_CLEAN_UNUSED_DATASETS_TASK: ${ENABLE_CLEAN_UNUSED_DATASETS_TASK:-false}
|
||||
ENABLE_CREATE_TIDB_SERVERLESS_TASK: ${ENABLE_CREATE_TIDB_SERVERLESS_TASK:-false}
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ COPY --from=packages /app/web/ .
|
|||
COPY . .
|
||||
|
||||
ENV NODE_OPTIONS="--max-old-space-size=4096"
|
||||
RUN pnpm build
|
||||
RUN pnpm build:docker
|
||||
|
||||
|
||||
# production stage
|
||||
|
|
|
|||
|
|
@ -30,6 +30,8 @@ const AvatarWithEdit = ({ onSave, ...props }: AvatarWithEditProps) => {
|
|||
const [isShowDeleteConfirm, setIsShowDeleteConfirm] = useState(false)
|
||||
const [hoverArea, setHoverArea] = useState<string>('left')
|
||||
|
||||
const [onAvatarError, setOnAvatarError] = useState(false)
|
||||
|
||||
const handleImageInput: OnImageInput = useCallback(async (isCropped: boolean, fileOrTempUrl: string | File, croppedAreaPixels?: Area, fileName?: string) => {
|
||||
setInputImageInfo(
|
||||
isCropped
|
||||
|
|
@ -98,10 +100,15 @@ const AvatarWithEdit = ({ onSave, ...props }: AvatarWithEditProps) => {
|
|||
<>
|
||||
<div>
|
||||
<div className="group relative">
|
||||
<Avatar {...props} />
|
||||
<Avatar {...props} onError={(x: boolean) => setOnAvatarError(x)} />
|
||||
<div
|
||||
className="absolute inset-0 flex cursor-pointer items-center justify-center rounded-full bg-black/50 opacity-0 transition-opacity group-hover:opacity-100"
|
||||
onClick={() => hoverArea === 'right' ? setIsShowDeleteConfirm(true) : setIsShowAvatarPicker(true)}
|
||||
onClick={() => {
|
||||
if (hoverArea === 'right' && !onAvatarError)
|
||||
setIsShowDeleteConfirm(true)
|
||||
else
|
||||
setIsShowAvatarPicker(true)
|
||||
}}
|
||||
onMouseMove={(e) => {
|
||||
const rect = e.currentTarget.getBoundingClientRect()
|
||||
const x = e.clientX - rect.left
|
||||
|
|
@ -109,12 +116,15 @@ const AvatarWithEdit = ({ onSave, ...props }: AvatarWithEditProps) => {
|
|||
setHoverArea(isRight ? 'right' : 'left')
|
||||
}}
|
||||
>
|
||||
{hoverArea === 'right' ? <span className="text-xs text-white">
|
||||
<RiDeleteBin5Line />
|
||||
</span> : <span className="text-xs text-white">
|
||||
<RiPencilLine />
|
||||
</span>}
|
||||
|
||||
{hoverArea === 'right' && !onAvatarError ? (
|
||||
<span className="text-xs text-white">
|
||||
<RiDeleteBin5Line />
|
||||
</span>
|
||||
) : (
|
||||
<span className="text-xs text-white">
|
||||
<RiPencilLine />
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
|
|
|||
|
|
@ -120,7 +120,7 @@ const AppIconPicker: FC<AppIconPickerProps> = ({
|
|||
<button
|
||||
key={tab.key}
|
||||
className={`
|
||||
flex h-8 flex-1 shrink-0 items-center justify-center rounded-xl p-2 text-sm font-medium
|
||||
flex h-8 flex-1 shrink-0 items-center justify-center rounded-lg p-2 text-sm font-medium
|
||||
${activeTab === tab.key && 'bg-components-main-nav-nav-button-bg-active shadow-md'}
|
||||
`}
|
||||
onClick={() => setActiveTab(tab.key as AppIconType)}
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ export type AvatarProps = {
|
|||
size?: number
|
||||
className?: string
|
||||
textClassName?: string
|
||||
onError?: (x: boolean) => void
|
||||
}
|
||||
const Avatar = ({
|
||||
name,
|
||||
|
|
@ -15,6 +16,7 @@ const Avatar = ({
|
|||
size = 30,
|
||||
className,
|
||||
textClassName,
|
||||
onError,
|
||||
}: AvatarProps) => {
|
||||
const avatarClassName = 'shrink-0 flex items-center rounded-full bg-primary-600'
|
||||
const style = { width: `${size}px`, height: `${size}px`, fontSize: `${size}px`, lineHeight: `${size}px` }
|
||||
|
|
@ -22,6 +24,7 @@ const Avatar = ({
|
|||
|
||||
const handleError = () => {
|
||||
setImgError(true)
|
||||
onError?.(true)
|
||||
}
|
||||
|
||||
if (avatar && !imgError) {
|
||||
|
|
@ -32,6 +35,7 @@ const Avatar = ({
|
|||
alt={name}
|
||||
src={avatar}
|
||||
onError={handleError}
|
||||
onLoad={() => onError?.(false)}
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -48,7 +48,7 @@ const BaseField = ({
|
|||
inputClassName,
|
||||
formSchema,
|
||||
field,
|
||||
disabled,
|
||||
disabled: propsDisabled,
|
||||
}: BaseFieldProps) => {
|
||||
const renderI18nObject = useRenderI18nObject()
|
||||
const {
|
||||
|
|
@ -67,8 +67,10 @@ const BaseField = ({
|
|||
selfFormProps,
|
||||
onChange,
|
||||
tooltip,
|
||||
disabled: formSchemaDisabled,
|
||||
} = formSchema
|
||||
const type = typeof typeOrFn === 'function' ? typeOrFn(field.form) : typeOrFn
|
||||
const disabled = propsDisabled || formSchemaDisabled
|
||||
|
||||
const memorizedLabel = useMemo(() => {
|
||||
if (isValidElement(label))
|
||||
|
|
@ -107,7 +109,7 @@ const BaseField = ({
|
|||
})
|
||||
const memorizedOptions = useMemo(() => {
|
||||
return options?.filter((option) => {
|
||||
if (!option.show_on?.length)
|
||||
if (!option.show_on || option.show_on.length === 0)
|
||||
return true
|
||||
|
||||
return option.show_on.every((condition) => {
|
||||
|
|
@ -120,7 +122,7 @@ const BaseField = ({
|
|||
value: option.value,
|
||||
}
|
||||
}) || []
|
||||
}, [options, renderI18nObject])
|
||||
}, [options, renderI18nObject, optionValues])
|
||||
const value = useStore(field.form.store, s => s.values[field.name])
|
||||
const values = useStore(field.form.store, (s) => {
|
||||
return (Array.isArray(show_on) ? show_on : show_on(field.form)).reduce((acc, condition) => {
|
||||
|
|
@ -135,9 +137,11 @@ const BaseField = ({
|
|||
})
|
||||
}, [values, show_on, field.name])
|
||||
const handleChange = useCallback((value: any) => {
|
||||
if (disabled)
|
||||
return
|
||||
field.handleChange(value)
|
||||
onChange?.(field.form, value)
|
||||
}, [field, onChange])
|
||||
}, [field, onChange, disabled])
|
||||
|
||||
const selfProps = typeof selfFormProps === 'function' ? selfFormProps(field.form) : selfFormProps
|
||||
|
||||
|
|
@ -294,6 +298,7 @@ const BaseField = ({
|
|||
className={cn(
|
||||
'system-sm-regular hover:bg-components-option-card-option-hover-bg hover:border-components-option-card-option-hover-border flex h-8 flex-[1] grow cursor-pointer items-center justify-center rounded-lg border border-components-option-card-option-border bg-components-option-card-option-bg p-2 text-text-secondary',
|
||||
value === option.value && 'border-components-option-card-option-selected-border bg-components-option-card-option-selected-bg text-text-primary shadow-xs',
|
||||
disabled && 'cursor-not-allowed opacity-50',
|
||||
inputClassName,
|
||||
formInputClassName,
|
||||
)}
|
||||
|
|
|
|||
|
|
@ -1,34 +1,52 @@
|
|||
import { useCallback } from 'react'
|
||||
import {
|
||||
isValidElement,
|
||||
useCallback,
|
||||
} from 'react'
|
||||
import type { ReactNode } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import type { FormSchema } from '../types'
|
||||
import { useRenderI18nObject } from '@/hooks/use-i18n'
|
||||
|
||||
export const useGetValidators = () => {
|
||||
const { t } = useTranslation()
|
||||
const renderI18nObject = useRenderI18nObject()
|
||||
const getLabel = useCallback((label: string | Record<string, string> | ReactNode) => {
|
||||
if (isValidElement(label))
|
||||
return ''
|
||||
|
||||
if (typeof label === 'string')
|
||||
return label
|
||||
|
||||
if (typeof label === 'object' && label !== null)
|
||||
return renderI18nObject(label as Record<string, string>)
|
||||
}, [])
|
||||
const getValidators = useCallback((formSchema: FormSchema) => {
|
||||
const {
|
||||
name,
|
||||
validators,
|
||||
required,
|
||||
label,
|
||||
} = formSchema
|
||||
let mergedValidators = validators
|
||||
const memorizedLabel = getLabel(label)
|
||||
if (required && !validators) {
|
||||
mergedValidators = {
|
||||
onMount: ({ value }: any) => {
|
||||
if (!value)
|
||||
return t('common.errorMsg.fieldRequired', { field: name })
|
||||
return t('common.errorMsg.fieldRequired', { field: memorizedLabel || name })
|
||||
},
|
||||
onChange: ({ value }: any) => {
|
||||
if (!value)
|
||||
return t('common.errorMsg.fieldRequired', { field: name })
|
||||
return t('common.errorMsg.fieldRequired', { field: memorizedLabel || name })
|
||||
},
|
||||
onBlur: ({ value }: any) => {
|
||||
if (!value)
|
||||
return t('common.errorMsg.fieldRequired', { field: name })
|
||||
return t('common.errorMsg.fieldRequired', { field: memorizedLabel })
|
||||
},
|
||||
}
|
||||
}
|
||||
return mergedValidators
|
||||
}, [t])
|
||||
}, [t, getLabel])
|
||||
|
||||
return {
|
||||
getValidators,
|
||||
|
|
|
|||
|
|
@ -70,6 +70,8 @@ export type FormSchema = {
|
|||
validators?: AnyValidators
|
||||
selfFormProps?: ((form: AnyFormApi) => Record<string, any>) | Record<string, any>
|
||||
onChange?: (form: AnyFormApi, v: any) => void
|
||||
showRadioUI?: boolean
|
||||
disabled?: boolean
|
||||
}
|
||||
|
||||
export type FormValues = Record<string, any>
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ import IndexFailed from '@/app/components/datasets/common/document-status-with-a
|
|||
import { useProviderContext } from '@/context/provider-context'
|
||||
import cn from '@/utils/classnames'
|
||||
import { useDocumentList, useInvalidDocumentDetailKey, useInvalidDocumentList } from '@/service/knowledge/use-document'
|
||||
import { useIndexStatus } from './list'
|
||||
import { useInvalid } from '@/service/use-base'
|
||||
import { useChildSegmentListKey, useSegmentListKey } from '@/service/knowledge/use-segment'
|
||||
import useDocumentListQueryState from './hooks/use-document-list-query-state'
|
||||
|
|
@ -32,6 +33,9 @@ import DatasetMetadataDrawer from '../metadata/metadata-dataset/dataset-metadata
|
|||
import StatusWithAction from '../common/document-status-with-action/status-with-action'
|
||||
import { useDocLink } from '@/context/i18n'
|
||||
import { useFetchDefaultProcessRule } from '@/service/knowledge/use-create-dataset'
|
||||
import { SimpleSelect } from '../../base/select'
|
||||
import StatusItem from './detail/completed/status-item'
|
||||
import type { Item } from '@/app/components/base/select'
|
||||
|
||||
const FolderPlusIcon = ({ className }: React.SVGProps<SVGElement>) => {
|
||||
return <svg width="20" height="20" viewBox="0 0 20 20" fill="none" xmlns="http://www.w3.org/2000/svg" className={className ?? ''}>
|
||||
|
|
@ -91,6 +95,8 @@ const Documents: FC<IDocumentsProps> = ({ datasetId }) => {
|
|||
const isFreePlan = plan.type === 'sandbox'
|
||||
const [inputValue, setInputValue] = useState<string>('') // the input value
|
||||
const [searchValue, setSearchValue] = useState<string>('')
|
||||
const [statusFilter, setStatusFilter] = useState<Item>({ value: 'all', name: 'All Status' })
|
||||
const DOC_INDEX_STATUS_MAP = useIndexStatus()
|
||||
|
||||
// Use the new hook for URL state management
|
||||
const { query, updateQuery } = useDocumentListQueryState()
|
||||
|
|
@ -107,6 +113,18 @@ const Documents: FC<IDocumentsProps> = ({ datasetId }) => {
|
|||
const embeddingAvailable = !!dataset?.embedding_available
|
||||
const debouncedSearchValue = useDebounce(searchValue, { wait: 500 })
|
||||
|
||||
const statusFilterItems: Item[] = useMemo(() => [
|
||||
{ value: 'all', name: 'All Status' },
|
||||
{ value: 'queuing', name: DOC_INDEX_STATUS_MAP.queuing.text },
|
||||
{ value: 'indexing', name: DOC_INDEX_STATUS_MAP.indexing.text },
|
||||
{ value: 'paused', name: DOC_INDEX_STATUS_MAP.paused.text },
|
||||
{ value: 'error', name: DOC_INDEX_STATUS_MAP.error.text },
|
||||
{ value: 'available', name: DOC_INDEX_STATUS_MAP.available.text },
|
||||
{ value: 'enabled', name: DOC_INDEX_STATUS_MAP.enabled.text },
|
||||
{ value: 'disabled', name: DOC_INDEX_STATUS_MAP.disabled.text },
|
||||
{ value: 'archived', name: DOC_INDEX_STATUS_MAP.archived.text },
|
||||
], [DOC_INDEX_STATUS_MAP, t])
|
||||
|
||||
// Initialize search value from URL on mount
|
||||
useEffect(() => {
|
||||
if (query.keyword) {
|
||||
|
|
@ -322,14 +340,28 @@ const Documents: FC<IDocumentsProps> = ({ datasetId }) => {
|
|||
</div>
|
||||
<div className='flex flex-1 flex-col px-6 py-4'>
|
||||
<div className='flex flex-wrap items-center justify-between'>
|
||||
<Input
|
||||
showLeftIcon
|
||||
showClearIcon
|
||||
wrapperClassName='!w-[200px]'
|
||||
value={inputValue}
|
||||
onChange={e => handleInputChange(e.target.value)}
|
||||
onClear={() => handleInputChange('')}
|
||||
/>
|
||||
<div className='flex items-center gap-2'>
|
||||
<SimpleSelect
|
||||
placeholder={t('datasetDocuments.list.table.header.status')}
|
||||
onSelect={(item) => {
|
||||
setStatusFilter(item)
|
||||
}}
|
||||
items={statusFilterItems}
|
||||
defaultValue={statusFilter.value}
|
||||
wrapperClassName='w-[160px] h-8'
|
||||
renderOption={({ item, selected }) => <StatusItem item={item} selected={selected} />}
|
||||
optionClassName='p-0'
|
||||
notClearable
|
||||
/>
|
||||
<Input
|
||||
showLeftIcon
|
||||
showClearIcon
|
||||
wrapperClassName='!w-[200px]'
|
||||
value={inputValue}
|
||||
onChange={e => handleInputChange(e.target.value)}
|
||||
onClear={() => handleInputChange('')}
|
||||
/>
|
||||
</div>
|
||||
<div className='flex !h-8 items-center justify-center gap-2'>
|
||||
{!isFreePlan && <AutoDisabledDocument datasetId={datasetId} />}
|
||||
<IndexFailed datasetId={datasetId} />
|
||||
|
|
@ -372,6 +404,8 @@ const Documents: FC<IDocumentsProps> = ({ datasetId }) => {
|
|||
onUpdate={handleUpdate}
|
||||
selectedIds={selectedIds}
|
||||
onSelectedIdChange={setSelectedIds}
|
||||
statusFilter={statusFilter}
|
||||
onStatusFilterChange={setStatusFilter}
|
||||
pagination={{
|
||||
total,
|
||||
limit,
|
||||
|
|
|
|||
|
|
@ -7,7 +7,6 @@ import { pick, uniq } from 'lodash-es'
|
|||
import {
|
||||
RiArchive2Line,
|
||||
RiDeleteBinLine,
|
||||
RiDownloadLine,
|
||||
RiEditLine,
|
||||
RiEqualizer2Line,
|
||||
RiLoopLeftLine,
|
||||
|
|
@ -31,11 +30,11 @@ import Popover from '@/app/components/base/popover'
|
|||
import Confirm from '@/app/components/base/confirm'
|
||||
import Tooltip from '@/app/components/base/tooltip'
|
||||
import Toast, { ToastContext } from '@/app/components/base/toast'
|
||||
import type { Item } from '@/app/components/base/select'
|
||||
import type { ColorMap, IndicatorProps } from '@/app/components/header/indicator'
|
||||
import Indicator from '@/app/components/header/indicator'
|
||||
import { asyncRunSafe } from '@/utils'
|
||||
import { formatNumber } from '@/utils/format'
|
||||
import { useDocumentDownload } from '@/service/knowledge/use-document'
|
||||
import NotionIcon from '@/app/components/base/notion-icon'
|
||||
import ProgressBar from '@/app/components/base/progress-bar'
|
||||
import { ChunkingMode, DataSourceType, DocumentActionType, type DocumentDisplayStatus, type SimpleDocumentDetail } from '@/models/datasets'
|
||||
|
|
@ -189,7 +188,6 @@ export const OperationAction: FC<{
|
|||
scene?: 'list' | 'detail'
|
||||
className?: string
|
||||
}> = ({ embeddingAvailable, datasetId, detail, onUpdate, scene = 'list', className = '' }) => {
|
||||
const downloadDocument = useDocumentDownload()
|
||||
const { id, enabled = false, archived = false, data_source_type, display_status } = detail || {}
|
||||
const [showModal, setShowModal] = useState(false)
|
||||
const [deleting, setDeleting] = useState(false)
|
||||
|
|
@ -298,32 +296,6 @@ export const OperationAction: FC<{
|
|||
)}
|
||||
{embeddingAvailable && (
|
||||
<>
|
||||
<Tooltip
|
||||
popupContent={t('datasetDocuments.list.action.download')}
|
||||
popupClassName='text-text-secondary system-xs-medium'
|
||||
needsDelay={false}
|
||||
>
|
||||
<button
|
||||
className={cn('mr-2 cursor-pointer rounded-lg',
|
||||
!isListScene
|
||||
? 'shadow-shadow-3 border-[0.5px] border-components-button-secondary-border bg-components-button-secondary-bg p-2 shadow-xs backdrop-blur-[5px] hover:border-components-button-secondary-border-hover hover:bg-components-button-secondary-bg-hover'
|
||||
: 'p-0.5 hover:bg-state-base-hover')}
|
||||
onClick={() => {
|
||||
downloadDocument.mutateAsync({
|
||||
datasetId,
|
||||
documentId: detail.id,
|
||||
}).then((response) => {
|
||||
if (response.download_url)
|
||||
window.location.href = response.download_url
|
||||
}).catch((error) => {
|
||||
console.error(error)
|
||||
notify({ type: 'error', message: t('common.actionMsg.downloadFailed') })
|
||||
})
|
||||
}}
|
||||
>
|
||||
<RiDownloadLine className='h-4 w-4 text-components-button-secondary-text' />
|
||||
</button>
|
||||
</Tooltip>
|
||||
<Tooltip
|
||||
popupContent={t('datasetDocuments.list.action.settings')}
|
||||
popupClassName='text-text-secondary system-xs-medium'
|
||||
|
|
@ -455,6 +427,8 @@ type IDocumentListProps = {
|
|||
pagination: PaginationProps
|
||||
onUpdate: () => void
|
||||
onManageMetadata: () => void
|
||||
statusFilter: Item
|
||||
onStatusFilterChange: (filter: string) => void
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -469,6 +443,7 @@ const DocumentList: FC<IDocumentListProps> = ({
|
|||
pagination,
|
||||
onUpdate,
|
||||
onManageMetadata,
|
||||
statusFilter,
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
const { formatTime } = useTimestamp()
|
||||
|
|
@ -480,6 +455,7 @@ const DocumentList: FC<IDocumentListProps> = ({
|
|||
const [localDocs, setLocalDocs] = useState<LocalDoc[]>(documents)
|
||||
const [sortField, setSortField] = useState<'name' | 'word_count' | 'hit_count' | 'created_at' | null>('created_at')
|
||||
const [sortOrder, setSortOrder] = useState<'asc' | 'desc'>('desc')
|
||||
|
||||
const {
|
||||
isShowEditModal,
|
||||
showEditModal,
|
||||
|
|
@ -494,12 +470,22 @@ const DocumentList: FC<IDocumentListProps> = ({
|
|||
})
|
||||
|
||||
useEffect(() => {
|
||||
let filteredDocs = documents
|
||||
|
||||
if (statusFilter.value !== 'all') {
|
||||
filteredDocs = filteredDocs.filter(doc =>
|
||||
typeof doc.display_status === 'string'
|
||||
&& typeof statusFilter.value === 'string'
|
||||
&& doc.display_status.toLowerCase() === statusFilter.value.toLowerCase(),
|
||||
)
|
||||
}
|
||||
|
||||
if (!sortField) {
|
||||
setLocalDocs(documents)
|
||||
setLocalDocs(filteredDocs)
|
||||
return
|
||||
}
|
||||
|
||||
const sortedDocs = [...documents].sort((a, b) => {
|
||||
const sortedDocs = [...filteredDocs].sort((a, b) => {
|
||||
let aValue: any
|
||||
let bValue: any
|
||||
|
||||
|
|
@ -535,7 +521,7 @@ const DocumentList: FC<IDocumentListProps> = ({
|
|||
})
|
||||
|
||||
setLocalDocs(sortedDocs)
|
||||
}, [documents, sortField, sortOrder])
|
||||
}, [documents, sortField, sortOrder, statusFilter])
|
||||
|
||||
const handleSort = (field: 'name' | 'word_count' | 'hit_count' | 'created_at') => {
|
||||
if (sortField === field) {
|
||||
|
|
@ -692,7 +678,11 @@ const DocumentList: FC<IDocumentListProps> = ({
|
|||
{doc?.data_source_type === DataSourceType.FILE && <FileTypeIcon type={extensionToFileType(doc?.data_source_info?.upload_file?.extension ?? fileType)} className='mr-1.5' />}
|
||||
{doc?.data_source_type === DataSourceType.WEB && <Globe01 className='mr-1.5 mt-[-3px] inline-flex align-middle' />}
|
||||
</div>
|
||||
<span className='grow-1 truncate text-sm'>{doc.name}</span>
|
||||
<Tooltip
|
||||
popupContent={doc.name}
|
||||
>
|
||||
<span className='grow-1 truncate text-sm'>{doc.name}</span>
|
||||
</Tooltip>
|
||||
<div className='hidden shrink-0 group-hover:ml-auto group-hover:flex'>
|
||||
<Tooltip
|
||||
popupContent={t('datasetDocuments.list.table.rename')}
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ import Button from '@/app/components/base/button'
|
|||
import type { LangGeniusVersionResponse } from '@/models/common'
|
||||
import { IS_CE_EDITION } from '@/config'
|
||||
import DifyLogo from '@/app/components/base/logo/dify-logo'
|
||||
import { noop } from 'lodash-es'
|
||||
|
||||
import { useGlobalPublicStore } from '@/context/global-public-context'
|
||||
|
||||
type IAccountSettingProps = {
|
||||
|
|
@ -27,11 +27,11 @@ export default function AccountAbout({
|
|||
return (
|
||||
<Modal
|
||||
isShow
|
||||
onClose={noop}
|
||||
onClose={onCancel}
|
||||
className='!w-[480px] !max-w-[480px] !px-6 !py-4'
|
||||
>
|
||||
<div>
|
||||
<div className='absolute right-4 top-4 flex h-8 w-8 cursor-pointer items-center justify-center' onClick={onCancel}>
|
||||
<div className='relative'>
|
||||
<div className='absolute right-0 top-0 flex h-8 w-8 cursor-pointer items-center justify-center' onClick={onCancel}>
|
||||
<RiCloseLine className='h-4 w-4 text-text-tertiary' />
|
||||
</div>
|
||||
<div className='flex flex-col items-center gap-4 py-8'>
|
||||
|
|
|
|||
|
|
@ -86,6 +86,7 @@ export enum ModelStatusEnum {
|
|||
quotaExceeded = 'quota-exceeded',
|
||||
noPermission = 'no-permission',
|
||||
disabled = 'disabled',
|
||||
credentialRemoved = 'credential-removed',
|
||||
}
|
||||
|
||||
export const MODEL_STATUS_TEXT: { [k: string]: TypeWithI18N } = {
|
||||
|
|
@ -153,6 +154,7 @@ export type ModelItem = {
|
|||
model_properties: Record<string, string | number>
|
||||
load_balancing_enabled: boolean
|
||||
deprecated?: boolean
|
||||
has_invalid_load_balancing_configs?: boolean
|
||||
}
|
||||
|
||||
export enum PreferredProviderTypeEnum {
|
||||
|
|
@ -181,6 +183,29 @@ export type QuotaConfiguration = {
|
|||
is_valid: boolean
|
||||
}
|
||||
|
||||
export type Credential = {
|
||||
credential_id: string
|
||||
credential_name?: string
|
||||
from_enterprise?: boolean
|
||||
not_allowed_to_use?: boolean
|
||||
}
|
||||
|
||||
export type CustomModel = {
|
||||
model: string
|
||||
model_type: ModelTypeEnum
|
||||
}
|
||||
|
||||
export type CustomModelCredential = CustomModel & {
|
||||
credentials?: Record<string, any>
|
||||
available_model_credentials?: Credential[]
|
||||
current_credential_id?: string
|
||||
}
|
||||
|
||||
export type CredentialWithModel = Credential & {
|
||||
model: string
|
||||
model_type: ModelTypeEnum
|
||||
}
|
||||
|
||||
export type ModelProvider = {
|
||||
provider: string
|
||||
label: TypeWithI18N
|
||||
|
|
@ -207,12 +232,17 @@ export type ModelProvider = {
|
|||
preferred_provider_type: PreferredProviderTypeEnum
|
||||
custom_configuration: {
|
||||
status: CustomConfigurationStatusEnum
|
||||
current_credential_id?: string
|
||||
current_credential_name?: string
|
||||
available_credentials?: Credential[]
|
||||
custom_models?: CustomModelCredential[]
|
||||
}
|
||||
system_configuration: {
|
||||
enabled: boolean
|
||||
current_quota_type: CurrentSystemQuotaTypeEnum
|
||||
quota_configurations: QuotaConfiguration[]
|
||||
}
|
||||
allow_custom_token?: boolean
|
||||
}
|
||||
|
||||
export type Model = {
|
||||
|
|
@ -272,9 +302,24 @@ export type ModelLoadBalancingConfigEntry = {
|
|||
in_cooldown?: boolean
|
||||
/** cooldown time (in seconds) */
|
||||
ttl?: number
|
||||
credential_id?: string
|
||||
}
|
||||
|
||||
export type ModelLoadBalancingConfig = {
|
||||
enabled: boolean
|
||||
configs: ModelLoadBalancingConfigEntry[]
|
||||
}
|
||||
|
||||
export type ProviderCredential = {
|
||||
credentials: Record<string, any>
|
||||
name: string
|
||||
credential_id: string
|
||||
}
|
||||
|
||||
export type ModelCredential = {
|
||||
credentials: Record<string, any>
|
||||
load_balancing: ModelLoadBalancingConfig
|
||||
available_credentials: Credential[]
|
||||
current_credential_id?: string
|
||||
current_credential_name?: string
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,7 +7,9 @@ import {
|
|||
import useSWR, { useSWRConfig } from 'swr'
|
||||
import { useContext } from 'use-context-selector'
|
||||
import type {
|
||||
Credential,
|
||||
CustomConfigurationModelFixedFields,
|
||||
CustomModel,
|
||||
DefaultModel,
|
||||
DefaultModelResponse,
|
||||
Model,
|
||||
|
|
@ -77,16 +79,17 @@ export const useProviderCredentialsAndLoadBalancing = (
|
|||
configurationMethod: ConfigurationMethodEnum,
|
||||
configured?: boolean,
|
||||
currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields,
|
||||
credentialId?: string,
|
||||
) => {
|
||||
const { data: predefinedFormSchemasValue, mutate: mutatePredefined } = useSWR(
|
||||
(configurationMethod === ConfigurationMethodEnum.predefinedModel && configured)
|
||||
? `/workspaces/current/model-providers/${provider}/credentials`
|
||||
const { data: predefinedFormSchemasValue, mutate: mutatePredefined, isLoading: isPredefinedLoading } = useSWR(
|
||||
(configurationMethod === ConfigurationMethodEnum.predefinedModel && configured && credentialId)
|
||||
? `/workspaces/current/model-providers/${provider}/credentials${credentialId ? `?credential_id=${credentialId}` : ''}`
|
||||
: null,
|
||||
fetchModelProviderCredentials,
|
||||
)
|
||||
const { data: customFormSchemasValue, mutate: mutateCustomized } = useSWR(
|
||||
(configurationMethod === ConfigurationMethodEnum.customizableModel && currentCustomConfigurationModelFixedFields)
|
||||
? `/workspaces/current/model-providers/${provider}/models/credentials?model=${currentCustomConfigurationModelFixedFields?.__model_name}&model_type=${currentCustomConfigurationModelFixedFields?.__model_type}`
|
||||
const { data: customFormSchemasValue, mutate: mutateCustomized, isLoading: isCustomizedLoading } = useSWR(
|
||||
(configurationMethod === ConfigurationMethodEnum.customizableModel && currentCustomConfigurationModelFixedFields && credentialId)
|
||||
? `/workspaces/current/model-providers/${provider}/models/credentials?model=${currentCustomConfigurationModelFixedFields?.__model_name}&model_type=${currentCustomConfigurationModelFixedFields?.__model_type}${credentialId ? `&credential_id=${credentialId}` : ''}`
|
||||
: null,
|
||||
fetchModelProviderCredentials,
|
||||
)
|
||||
|
|
@ -102,6 +105,7 @@ export const useProviderCredentialsAndLoadBalancing = (
|
|||
: undefined
|
||||
}, [
|
||||
configurationMethod,
|
||||
credentialId,
|
||||
currentCustomConfigurationModelFixedFields,
|
||||
customFormSchemasValue?.credentials,
|
||||
predefinedFormSchemasValue?.credentials,
|
||||
|
|
@ -119,6 +123,7 @@ export const useProviderCredentialsAndLoadBalancing = (
|
|||
: customFormSchemasValue
|
||||
)?.load_balancing,
|
||||
mutate,
|
||||
isLoading: isPredefinedLoading || isCustomizedLoading,
|
||||
}
|
||||
// as ([Record<string, string | boolean | undefined> | undefined, ModelLoadBalancingConfig | undefined])
|
||||
}
|
||||
|
|
@ -313,40 +318,59 @@ export const useMarketplaceAllPlugins = (providers: ModelProvider[], searchText:
|
|||
}
|
||||
}
|
||||
|
||||
export const useModelModalHandler = () => {
|
||||
const setShowModelModal = useModalContextSelector(state => state.setShowModelModal)
|
||||
export const useRefreshModel = () => {
|
||||
const { eventEmitter } = useEventEmitterContextContext()
|
||||
const updateModelProviders = useUpdateModelProviders()
|
||||
const updateModelList = useUpdateModelList()
|
||||
const { eventEmitter } = useEventEmitterContextContext()
|
||||
const handleRefreshModel = useCallback((provider: ModelProvider, configurationMethod: ConfigurationMethodEnum, CustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields) => {
|
||||
updateModelProviders()
|
||||
|
||||
provider.supported_model_types.forEach((type) => {
|
||||
updateModelList(type)
|
||||
})
|
||||
|
||||
if (configurationMethod === ConfigurationMethodEnum.customizableModel
|
||||
&& provider.custom_configuration.status === CustomConfigurationStatusEnum.active) {
|
||||
eventEmitter?.emit({
|
||||
type: UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST,
|
||||
payload: provider.provider,
|
||||
} as any)
|
||||
|
||||
if (CustomConfigurationModelFixedFields?.__model_type)
|
||||
updateModelList(CustomConfigurationModelFixedFields.__model_type)
|
||||
}
|
||||
}, [eventEmitter, updateModelList, updateModelProviders])
|
||||
|
||||
return {
|
||||
handleRefreshModel,
|
||||
}
|
||||
}
|
||||
|
||||
export const useModelModalHandler = () => {
|
||||
const setShowModelModal = useModalContextSelector(state => state.setShowModelModal)
|
||||
const { handleRefreshModel } = useRefreshModel()
|
||||
|
||||
return (
|
||||
provider: ModelProvider,
|
||||
configurationMethod: ConfigurationMethodEnum,
|
||||
CustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields,
|
||||
isModelCredential?: boolean,
|
||||
credential?: Credential,
|
||||
model?: CustomModel,
|
||||
onUpdate?: () => void,
|
||||
) => {
|
||||
setShowModelModal({
|
||||
payload: {
|
||||
currentProvider: provider,
|
||||
currentConfigurationMethod: configurationMethod,
|
||||
currentCustomConfigurationModelFixedFields: CustomConfigurationModelFixedFields,
|
||||
isModelCredential,
|
||||
credential,
|
||||
model,
|
||||
},
|
||||
onSaveCallback: () => {
|
||||
updateModelProviders()
|
||||
|
||||
provider.supported_model_types.forEach((type) => {
|
||||
updateModelList(type)
|
||||
})
|
||||
|
||||
if (configurationMethod === ConfigurationMethodEnum.customizableModel
|
||||
&& provider.custom_configuration.status === CustomConfigurationStatusEnum.active) {
|
||||
eventEmitter?.emit({
|
||||
type: UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST,
|
||||
payload: provider.provider,
|
||||
} as any)
|
||||
|
||||
if (CustomConfigurationModelFixedFields?.__model_type)
|
||||
updateModelList(CustomConfigurationModelFixedFields.__model_type)
|
||||
}
|
||||
handleRefreshModel(provider, configurationMethod, CustomConfigurationModelFixedFields)
|
||||
onUpdate?.()
|
||||
},
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,8 +8,6 @@ import {
|
|||
import SystemModelSelector from './system-model-selector'
|
||||
import ProviderAddedCard from './provider-added-card'
|
||||
import type {
|
||||
ConfigurationMethodEnum,
|
||||
CustomConfigurationModelFixedFields,
|
||||
ModelProvider,
|
||||
} from './declarations'
|
||||
import {
|
||||
|
|
@ -18,7 +16,6 @@ import {
|
|||
} from './declarations'
|
||||
import {
|
||||
useDefaultModel,
|
||||
useModelModalHandler,
|
||||
} from './hooks'
|
||||
import InstallFromMarketplace from './install-from-marketplace'
|
||||
import { useProviderContext } from '@/context/provider-context'
|
||||
|
|
@ -84,8 +81,6 @@ const ModelProviderPage = ({ searchText }: Props) => {
|
|||
return [filteredConfiguredProviders, filteredNotConfiguredProviders]
|
||||
}, [configuredProviders, debouncedSearchText, notConfiguredProviders])
|
||||
|
||||
const handleOpenModal = useModelModalHandler()
|
||||
|
||||
return (
|
||||
<div className='relative -mt-2 pt-1'>
|
||||
<div className={cn('mb-2 flex items-center')}>
|
||||
|
|
@ -126,7 +121,6 @@ const ModelProviderPage = ({ searchText }: Props) => {
|
|||
<ProviderAddedCard
|
||||
key={provider.provider}
|
||||
provider={provider}
|
||||
onOpenModal={(configurationMethod: ConfigurationMethodEnum, currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields) => handleOpenModal(provider, configurationMethod, currentCustomConfigurationModelFixedFields)}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
|
|
@ -140,7 +134,6 @@ const ModelProviderPage = ({ searchText }: Props) => {
|
|||
notConfigured
|
||||
key={provider.provider}
|
||||
provider={provider}
|
||||
onOpenModal={(configurationMethod: ConfigurationMethodEnum, currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields) => handleOpenModal(provider, configurationMethod, currentCustomConfigurationModelFixedFields)}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
|
|
|
|||
|
|
@ -0,0 +1,115 @@
|
|||
import {
|
||||
memo,
|
||||
useCallback,
|
||||
useMemo,
|
||||
} from 'react'
|
||||
import { RiAddLine } from '@remixicon/react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { Authorized } from '@/app/components/header/account-setting/model-provider-page/model-auth'
|
||||
import cn from '@/utils/classnames'
|
||||
import type {
|
||||
Credential,
|
||||
CustomModelCredential,
|
||||
ModelCredential,
|
||||
ModelProvider,
|
||||
} from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||
import { ConfigurationMethodEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||
import Tooltip from '@/app/components/base/tooltip'
|
||||
|
||||
type AddCredentialInLoadBalancingProps = {
|
||||
provider: ModelProvider
|
||||
model: CustomModelCredential
|
||||
configurationMethod: ConfigurationMethodEnum
|
||||
modelCredential: ModelCredential
|
||||
onSelectCredential: (credential: Credential) => void
|
||||
onUpdate?: () => void
|
||||
}
|
||||
const AddCredentialInLoadBalancing = ({
|
||||
provider,
|
||||
model,
|
||||
configurationMethod,
|
||||
modelCredential,
|
||||
onSelectCredential,
|
||||
onUpdate,
|
||||
}: AddCredentialInLoadBalancingProps) => {
|
||||
const { t } = useTranslation()
|
||||
const {
|
||||
available_credentials,
|
||||
} = modelCredential
|
||||
const customModel = configurationMethod === ConfigurationMethodEnum.customizableModel
|
||||
const notAllowCustomCredential = provider.allow_custom_token === false
|
||||
|
||||
const ButtonComponent = useMemo(() => {
|
||||
const Item = (
|
||||
<div className={cn(
|
||||
'system-sm-medium flex h-8 items-center rounded-lg px-3 text-text-accent hover:bg-state-base-hover',
|
||||
notAllowCustomCredential && 'cursor-not-allowed opacity-50',
|
||||
)}>
|
||||
<RiAddLine className='mr-2 h-4 w-4' />
|
||||
{
|
||||
customModel
|
||||
? t('common.modelProvider.auth.addCredential')
|
||||
: t('common.modelProvider.auth.addApiKey')
|
||||
}
|
||||
</div>
|
||||
)
|
||||
|
||||
if (notAllowCustomCredential) {
|
||||
return (
|
||||
<Tooltip
|
||||
asChild
|
||||
popupContent={t('plugin.auth.credentialUnavailable')}
|
||||
>
|
||||
{Item}
|
||||
</Tooltip>
|
||||
)
|
||||
}
|
||||
return Item
|
||||
}, [notAllowCustomCredential, t, customModel])
|
||||
|
||||
const renderTrigger = useCallback((open?: boolean) => {
|
||||
const Item = (
|
||||
<div className={cn(
|
||||
'system-sm-medium flex h-8 items-center rounded-lg px-3 text-text-accent hover:bg-state-base-hover',
|
||||
open && 'bg-state-base-hover',
|
||||
)}>
|
||||
<RiAddLine className='mr-2 h-4 w-4' />
|
||||
{
|
||||
customModel
|
||||
? t('common.modelProvider.auth.addCredential')
|
||||
: t('common.modelProvider.auth.addApiKey')
|
||||
}
|
||||
</div>
|
||||
)
|
||||
|
||||
return Item
|
||||
}, [t, customModel])
|
||||
|
||||
if (!available_credentials?.length)
|
||||
return ButtonComponent
|
||||
|
||||
return (
|
||||
<Authorized
|
||||
provider={provider}
|
||||
renderTrigger={renderTrigger}
|
||||
items={[
|
||||
{
|
||||
title: customModel ? t('common.modelProvider.auth.modelCredentials') : t('common.modelProvider.auth.apiKeys'),
|
||||
model: customModel ? model : undefined,
|
||||
credentials: available_credentials ?? [],
|
||||
},
|
||||
]}
|
||||
configurationMethod={configurationMethod}
|
||||
currentCustomConfigurationModelFixedFields={customModel ? {
|
||||
__model_name: model.model,
|
||||
__model_type: model.model_type,
|
||||
} : undefined}
|
||||
onItemClick={onSelectCredential}
|
||||
placement='bottom-start'
|
||||
onUpdate={onUpdate}
|
||||
isModelCredential={customModel}
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
export default memo(AddCredentialInLoadBalancing)
|
||||
|
|
@ -0,0 +1,111 @@
|
|||
import {
|
||||
memo,
|
||||
useCallback,
|
||||
useMemo,
|
||||
} from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import {
|
||||
RiAddCircleFill,
|
||||
} from '@remixicon/react'
|
||||
import {
|
||||
Button,
|
||||
} from '@/app/components/base/button'
|
||||
import type {
|
||||
CustomConfigurationModelFixedFields,
|
||||
ModelProvider,
|
||||
} from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||
import { ConfigurationMethodEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||
import Authorized from './authorized'
|
||||
import {
|
||||
useAuth,
|
||||
useCustomModels,
|
||||
} from './hooks'
|
||||
import cn from '@/utils/classnames'
|
||||
import Tooltip from '@/app/components/base/tooltip'
|
||||
|
||||
type AddCustomModelProps = {
|
||||
provider: ModelProvider,
|
||||
configurationMethod: ConfigurationMethodEnum,
|
||||
currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields,
|
||||
}
|
||||
const AddCustomModel = ({
|
||||
provider,
|
||||
configurationMethod,
|
||||
currentCustomConfigurationModelFixedFields,
|
||||
}: AddCustomModelProps) => {
|
||||
const { t } = useTranslation()
|
||||
const customModels = useCustomModels(provider)
|
||||
const noModels = !customModels.length
|
||||
const {
|
||||
handleOpenModal,
|
||||
} = useAuth(provider, configurationMethod, currentCustomConfigurationModelFixedFields, true)
|
||||
const notAllowCustomCredential = provider.allow_custom_token === false
|
||||
const handleClick = useCallback(() => {
|
||||
if (notAllowCustomCredential)
|
||||
return
|
||||
|
||||
handleOpenModal()
|
||||
}, [handleOpenModal, notAllowCustomCredential])
|
||||
const ButtonComponent = useMemo(() => {
|
||||
const Item = (
|
||||
<Button
|
||||
variant='ghost-accent'
|
||||
size='small'
|
||||
onClick={handleClick}
|
||||
className={cn(
|
||||
notAllowCustomCredential && 'cursor-not-allowed opacity-50',
|
||||
)}
|
||||
>
|
||||
<RiAddCircleFill className='mr-1 h-3.5 w-3.5' />
|
||||
{t('common.modelProvider.addModel')}
|
||||
</Button>
|
||||
)
|
||||
if (notAllowCustomCredential) {
|
||||
return (
|
||||
<Tooltip
|
||||
asChild
|
||||
popupContent={t('plugin.auth.credentialUnavailable')}
|
||||
>
|
||||
{Item}
|
||||
</Tooltip>
|
||||
)
|
||||
}
|
||||
return Item
|
||||
}, [handleClick, notAllowCustomCredential, t])
|
||||
|
||||
const renderTrigger = useCallback((open?: boolean) => {
|
||||
const Item = (
|
||||
<Button
|
||||
variant='ghost'
|
||||
size='small'
|
||||
className={cn(
|
||||
open && 'bg-components-button-ghost-bg-hover',
|
||||
)}
|
||||
>
|
||||
<RiAddCircleFill className='mr-1 h-3.5 w-3.5' />
|
||||
{t('common.modelProvider.addModel')}
|
||||
</Button>
|
||||
)
|
||||
return Item
|
||||
}, [t])
|
||||
|
||||
if (noModels)
|
||||
return ButtonComponent
|
||||
|
||||
return (
|
||||
<Authorized
|
||||
provider={provider}
|
||||
configurationMethod={ConfigurationMethodEnum.customizableModel}
|
||||
items={customModels.map(model => ({
|
||||
model,
|
||||
credentials: model.available_model_credentials ?? [],
|
||||
}))}
|
||||
renderTrigger={renderTrigger}
|
||||
isModelCredential
|
||||
enableAddModelCredential
|
||||
bottomAddModelCredentialText={t('common.modelProvider.auth.addNewModel')}
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
export default memo(AddCustomModel)
|
||||
|
|
@ -0,0 +1,101 @@
|
|||
import {
|
||||
memo,
|
||||
useCallback,
|
||||
} from 'react'
|
||||
import { RiAddLine } from '@remixicon/react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import CredentialItem from './credential-item'
|
||||
import type {
|
||||
Credential,
|
||||
CustomModel,
|
||||
CustomModelCredential,
|
||||
} from '../../declarations'
|
||||
import Button from '@/app/components/base/button'
|
||||
import Tooltip from '@/app/components/base/tooltip'
|
||||
|
||||
type AuthorizedItemProps = {
|
||||
model?: CustomModelCredential
|
||||
title?: string
|
||||
disabled?: boolean
|
||||
onDelete?: (credential?: Credential, model?: CustomModel) => void
|
||||
onEdit?: (credential?: Credential, model?: CustomModel) => void
|
||||
showItemSelectedIcon?: boolean
|
||||
selectedCredentialId?: string
|
||||
credentials: Credential[]
|
||||
onItemClick?: (credential: Credential, model?: CustomModel) => void
|
||||
enableAddModelCredential?: boolean
|
||||
notAllowCustomCredential?: boolean
|
||||
}
|
||||
export const AuthorizedItem = ({
|
||||
model,
|
||||
title,
|
||||
credentials,
|
||||
disabled,
|
||||
onDelete,
|
||||
onEdit,
|
||||
showItemSelectedIcon,
|
||||
selectedCredentialId,
|
||||
onItemClick,
|
||||
enableAddModelCredential,
|
||||
notAllowCustomCredential,
|
||||
}: AuthorizedItemProps) => {
|
||||
const { t } = useTranslation()
|
||||
const handleEdit = useCallback((credential?: Credential) => {
|
||||
onEdit?.(credential, model)
|
||||
}, [onEdit, model])
|
||||
const handleDelete = useCallback((credential?: Credential) => {
|
||||
onDelete?.(credential, model)
|
||||
}, [onDelete, model])
|
||||
const handleItemClick = useCallback((credential: Credential) => {
|
||||
onItemClick?.(credential, model)
|
||||
}, [onItemClick, model])
|
||||
|
||||
return (
|
||||
<div className='p-1'>
|
||||
<div
|
||||
className='flex h-9 items-center'
|
||||
>
|
||||
<div className='h-5 w-5 shrink-0'></div>
|
||||
<div
|
||||
className='system-md-medium mx-1 grow truncate text-text-primary'
|
||||
title={title ?? model?.model}
|
||||
>
|
||||
{title ?? model?.model}
|
||||
</div>
|
||||
{
|
||||
enableAddModelCredential && !notAllowCustomCredential && (
|
||||
<Tooltip
|
||||
asChild
|
||||
popupContent={t('common.modelProvider.auth.addModelCredential')}
|
||||
>
|
||||
<Button
|
||||
className='h-6 w-6 shrink-0 rounded-full p-0'
|
||||
size='small'
|
||||
variant='secondary-accent'
|
||||
onClick={() => handleEdit?.()}
|
||||
>
|
||||
<RiAddLine className='h-4 w-4' />
|
||||
</Button>
|
||||
</Tooltip>
|
||||
)
|
||||
}
|
||||
</div>
|
||||
{
|
||||
credentials.map(credential => (
|
||||
<CredentialItem
|
||||
key={credential.credential_id}
|
||||
credential={credential}
|
||||
disabled={disabled}
|
||||
onDelete={handleDelete}
|
||||
onEdit={handleEdit}
|
||||
showSelectedIcon={showItemSelectedIcon}
|
||||
selectedCredentialId={selectedCredentialId}
|
||||
onItemClick={handleItemClick}
|
||||
/>
|
||||
))
|
||||
}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default memo(AuthorizedItem)
|
||||
|
|
@ -0,0 +1,137 @@
|
|||
import {
|
||||
memo,
|
||||
useMemo,
|
||||
} from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import {
|
||||
RiCheckLine,
|
||||
RiDeleteBinLine,
|
||||
RiEqualizer2Line,
|
||||
} from '@remixicon/react'
|
||||
import Indicator from '@/app/components/header/indicator'
|
||||
import ActionButton from '@/app/components/base/action-button'
|
||||
import Tooltip from '@/app/components/base/tooltip'
|
||||
import cn from '@/utils/classnames'
|
||||
import type { Credential } from '../../declarations'
|
||||
import Badge from '@/app/components/base/badge'
|
||||
|
||||
type CredentialItemProps = {
|
||||
credential: Credential
|
||||
disabled?: boolean
|
||||
onDelete?: (credential: Credential) => void
|
||||
onEdit?: (credential?: Credential) => void
|
||||
onItemClick?: (credential: Credential) => void
|
||||
disableRename?: boolean
|
||||
disableEdit?: boolean
|
||||
disableDelete?: boolean
|
||||
showSelectedIcon?: boolean
|
||||
selectedCredentialId?: string
|
||||
}
|
||||
const CredentialItem = ({
|
||||
credential,
|
||||
disabled,
|
||||
onDelete,
|
||||
onEdit,
|
||||
onItemClick,
|
||||
disableRename,
|
||||
disableEdit,
|
||||
disableDelete,
|
||||
showSelectedIcon,
|
||||
selectedCredentialId,
|
||||
}: CredentialItemProps) => {
|
||||
const { t } = useTranslation()
|
||||
const showAction = useMemo(() => {
|
||||
return !(disableRename && disableEdit && disableDelete)
|
||||
}, [disableRename, disableEdit, disableDelete])
|
||||
|
||||
const Item = (
|
||||
<div
|
||||
key={credential.credential_id}
|
||||
className={cn(
|
||||
'group flex h-8 items-center rounded-lg p-1 hover:bg-state-base-hover',
|
||||
(disabled || credential.not_allowed_to_use) && 'cursor-not-allowed opacity-50',
|
||||
)}
|
||||
onClick={() => {
|
||||
if (disabled || credential.not_allowed_to_use)
|
||||
return
|
||||
onItemClick?.(credential)
|
||||
}}
|
||||
>
|
||||
<div className='flex w-0 grow items-center space-x-1.5'>
|
||||
{
|
||||
showSelectedIcon && (
|
||||
<div className='h-4 w-4'>
|
||||
{
|
||||
selectedCredentialId === credential.credential_id && (
|
||||
<RiCheckLine className='h-4 w-4 text-text-accent' />
|
||||
)
|
||||
}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
<Indicator className='ml-2 mr-1.5 shrink-0' />
|
||||
<div
|
||||
className='system-md-regular truncate text-text-secondary'
|
||||
title={credential.credential_name}
|
||||
>
|
||||
{credential.credential_name}
|
||||
</div>
|
||||
</div>
|
||||
{
|
||||
credential.from_enterprise && (
|
||||
<Badge className='shrink-0'>
|
||||
Enterprise
|
||||
</Badge>
|
||||
)
|
||||
}
|
||||
{
|
||||
showAction && (
|
||||
<div className='ml-2 hidden shrink-0 items-center group-hover:flex'>
|
||||
{
|
||||
!disableEdit && !credential.not_allowed_to_use && !credential.from_enterprise && (
|
||||
<Tooltip popupContent={t('common.operation.edit')}>
|
||||
<ActionButton
|
||||
disabled={disabled}
|
||||
onClick={(e) => {
|
||||
e.stopPropagation()
|
||||
onEdit?.(credential)
|
||||
}}
|
||||
>
|
||||
<RiEqualizer2Line className='h-4 w-4 text-text-tertiary' />
|
||||
</ActionButton>
|
||||
</Tooltip>
|
||||
)
|
||||
}
|
||||
{
|
||||
!disableDelete && !credential.from_enterprise && (
|
||||
<Tooltip popupContent={t('common.operation.delete')}>
|
||||
<ActionButton
|
||||
className='hover:bg-transparent'
|
||||
disabled={disabled}
|
||||
onClick={(e) => {
|
||||
e.stopPropagation()
|
||||
onDelete?.(credential)
|
||||
}}
|
||||
>
|
||||
<RiDeleteBinLine className='h-4 w-4 text-text-tertiary hover:text-text-destructive' />
|
||||
</ActionButton>
|
||||
</Tooltip>
|
||||
)
|
||||
}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
</div>
|
||||
)
|
||||
|
||||
if (credential.not_allowed_to_use) {
|
||||
return (
|
||||
<Tooltip popupContent={t('plugin.auth.customCredentialUnavailable')}>
|
||||
{Item}
|
||||
</Tooltip>
|
||||
)
|
||||
}
|
||||
return Item
|
||||
}
|
||||
|
||||
export default memo(CredentialItem)
|
||||
|
|
@ -0,0 +1,222 @@
|
|||
import {
|
||||
memo,
|
||||
useCallback,
|
||||
useMemo,
|
||||
useState,
|
||||
} from 'react'
|
||||
import {
|
||||
RiAddLine,
|
||||
RiEqualizer2Line,
|
||||
} from '@remixicon/react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import {
|
||||
PortalToFollowElem,
|
||||
PortalToFollowElemContent,
|
||||
PortalToFollowElemTrigger,
|
||||
} from '@/app/components/base/portal-to-follow-elem'
|
||||
import type {
|
||||
PortalToFollowElemOptions,
|
||||
} from '@/app/components/base/portal-to-follow-elem'
|
||||
import Button from '@/app/components/base/button'
|
||||
import cn from '@/utils/classnames'
|
||||
import Confirm from '@/app/components/base/confirm'
|
||||
import type {
|
||||
ConfigurationMethodEnum,
|
||||
Credential,
|
||||
CustomConfigurationModelFixedFields,
|
||||
CustomModel,
|
||||
ModelProvider,
|
||||
} from '../../declarations'
|
||||
import { useAuth } from '../hooks'
|
||||
import AuthorizedItem from './authorized-item'
|
||||
|
||||
type AuthorizedProps = {
|
||||
provider: ModelProvider,
|
||||
configurationMethod: ConfigurationMethodEnum,
|
||||
currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields,
|
||||
isModelCredential?: boolean
|
||||
items: {
|
||||
title?: string
|
||||
model?: CustomModel
|
||||
credentials: Credential[]
|
||||
}[]
|
||||
selectedCredential?: Credential
|
||||
disabled?: boolean
|
||||
renderTrigger?: (open?: boolean) => React.ReactNode
|
||||
isOpen?: boolean
|
||||
onOpenChange?: (open: boolean) => void
|
||||
offset?: PortalToFollowElemOptions['offset']
|
||||
placement?: PortalToFollowElemOptions['placement']
|
||||
triggerPopupSameWidth?: boolean
|
||||
popupClassName?: string
|
||||
showItemSelectedIcon?: boolean
|
||||
onUpdate?: () => void
|
||||
onItemClick?: (credential: Credential, model?: CustomModel) => void
|
||||
enableAddModelCredential?: boolean
|
||||
bottomAddModelCredentialText?: string
|
||||
}
|
||||
const Authorized = ({
|
||||
provider,
|
||||
configurationMethod,
|
||||
currentCustomConfigurationModelFixedFields,
|
||||
items,
|
||||
isModelCredential,
|
||||
selectedCredential,
|
||||
disabled,
|
||||
renderTrigger,
|
||||
isOpen,
|
||||
onOpenChange,
|
||||
offset = 8,
|
||||
placement = 'bottom-end',
|
||||
triggerPopupSameWidth = false,
|
||||
popupClassName,
|
||||
showItemSelectedIcon,
|
||||
onUpdate,
|
||||
onItemClick,
|
||||
enableAddModelCredential,
|
||||
bottomAddModelCredentialText,
|
||||
}: AuthorizedProps) => {
|
||||
const { t } = useTranslation()
|
||||
const [isLocalOpen, setIsLocalOpen] = useState(false)
|
||||
const mergedIsOpen = isOpen ?? isLocalOpen
|
||||
const setMergedIsOpen = useCallback((open: boolean) => {
|
||||
if (onOpenChange)
|
||||
onOpenChange(open)
|
||||
|
||||
setIsLocalOpen(open)
|
||||
}, [onOpenChange])
|
||||
const {
|
||||
openConfirmDelete,
|
||||
closeConfirmDelete,
|
||||
doingAction,
|
||||
handleActiveCredential,
|
||||
handleConfirmDelete,
|
||||
deleteCredentialId,
|
||||
handleOpenModal,
|
||||
} = useAuth(provider, configurationMethod, currentCustomConfigurationModelFixedFields, isModelCredential, onUpdate)
|
||||
|
||||
const handleEdit = useCallback((credential?: Credential, model?: CustomModel) => {
|
||||
handleOpenModal(credential, model)
|
||||
setMergedIsOpen(false)
|
||||
}, [handleOpenModal, setMergedIsOpen])
|
||||
|
||||
const handleItemClick = useCallback((credential: Credential, model?: CustomModel) => {
|
||||
if (onItemClick)
|
||||
onItemClick(credential, model)
|
||||
else
|
||||
handleActiveCredential(credential, model)
|
||||
|
||||
setMergedIsOpen(false)
|
||||
}, [handleActiveCredential, onItemClick, setMergedIsOpen])
|
||||
const notAllowCustomCredential = provider.allow_custom_token === false
|
||||
|
||||
const Trigger = useMemo(() => {
|
||||
const Item = (
|
||||
<Button
|
||||
className='grow'
|
||||
size='small'
|
||||
>
|
||||
<RiEqualizer2Line className='mr-1 h-3.5 w-3.5' />
|
||||
{t('common.operation.config')}
|
||||
</Button>
|
||||
)
|
||||
return Item
|
||||
}, [t])
|
||||
|
||||
return (
|
||||
<>
|
||||
<PortalToFollowElem
|
||||
open={mergedIsOpen}
|
||||
onOpenChange={setMergedIsOpen}
|
||||
placement={placement}
|
||||
offset={offset}
|
||||
triggerPopupSameWidth={triggerPopupSameWidth}
|
||||
>
|
||||
<PortalToFollowElemTrigger
|
||||
onClick={() => {
|
||||
setMergedIsOpen(!mergedIsOpen)
|
||||
}}
|
||||
asChild
|
||||
>
|
||||
{
|
||||
renderTrigger
|
||||
? renderTrigger(mergedIsOpen)
|
||||
: Trigger
|
||||
}
|
||||
</PortalToFollowElemTrigger>
|
||||
<PortalToFollowElemContent className='z-[100]'>
|
||||
<div className={cn(
|
||||
'w-[360px] rounded-xl border-[0.5px] border-components-panel-border bg-components-panel-bg-blur shadow-lg',
|
||||
popupClassName,
|
||||
)}>
|
||||
<div className='max-h-[304px] overflow-y-auto'>
|
||||
{
|
||||
items.map((item, index) => (
|
||||
<AuthorizedItem
|
||||
key={index}
|
||||
title={item.title}
|
||||
model={item.model}
|
||||
credentials={item.credentials}
|
||||
disabled={disabled}
|
||||
onDelete={openConfirmDelete}
|
||||
onEdit={handleEdit}
|
||||
showItemSelectedIcon={showItemSelectedIcon}
|
||||
selectedCredentialId={selectedCredential?.credential_id}
|
||||
onItemClick={handleItemClick}
|
||||
enableAddModelCredential={enableAddModelCredential}
|
||||
notAllowCustomCredential={notAllowCustomCredential}
|
||||
/>
|
||||
))
|
||||
}
|
||||
</div>
|
||||
<div className='h-[1px] bg-divider-subtle'></div>
|
||||
{
|
||||
isModelCredential && !notAllowCustomCredential && (
|
||||
<div
|
||||
onClick={() => handleEdit(
|
||||
undefined,
|
||||
currentCustomConfigurationModelFixedFields
|
||||
? {
|
||||
model: currentCustomConfigurationModelFixedFields.__model_name,
|
||||
model_type: currentCustomConfigurationModelFixedFields.__model_type,
|
||||
}
|
||||
: undefined,
|
||||
)}
|
||||
className='system-xs-medium flex h-[30px] cursor-pointer items-center px-3 text-text-accent-light-mode-only'
|
||||
>
|
||||
<RiAddLine className='mr-1 h-4 w-4' />
|
||||
{bottomAddModelCredentialText ?? t('common.modelProvider.auth.addModelCredential')}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
{
|
||||
!isModelCredential && !notAllowCustomCredential && (
|
||||
<div className='p-2'>
|
||||
<Button
|
||||
onClick={() => handleEdit()}
|
||||
className='w-full'
|
||||
>
|
||||
{t('common.modelProvider.auth.addApiKey')}
|
||||
</Button>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
</div>
|
||||
</PortalToFollowElemContent>
|
||||
</PortalToFollowElem>
|
||||
{
|
||||
deleteCredentialId && (
|
||||
<Confirm
|
||||
isShow
|
||||
title={t('common.modelProvider.confirmDelete')}
|
||||
isDisabled={doingAction}
|
||||
onCancel={closeConfirmDelete}
|
||||
onConfirm={handleConfirmDelete}
|
||||
/>
|
||||
)
|
||||
}
|
||||
</>
|
||||
)
|
||||
}
|
||||
|
||||
export default memo(Authorized)
|
||||
|
|
@ -0,0 +1,76 @@
|
|||
import { memo } from 'react'
|
||||
import {
|
||||
RiEqualizer2Line,
|
||||
RiScales3Line,
|
||||
} from '@remixicon/react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import Button from '@/app/components/base/button'
|
||||
import Indicator from '@/app/components/header/indicator'
|
||||
import cn from '@/utils/classnames'
|
||||
|
||||
type ConfigModelProps = {
|
||||
onClick?: () => void
|
||||
loadBalancingEnabled?: boolean
|
||||
loadBalancingInvalid?: boolean
|
||||
credentialRemoved?: boolean
|
||||
}
|
||||
const ConfigModel = ({
|
||||
onClick,
|
||||
loadBalancingEnabled,
|
||||
loadBalancingInvalid,
|
||||
credentialRemoved,
|
||||
}: ConfigModelProps) => {
|
||||
const { t } = useTranslation()
|
||||
|
||||
if (loadBalancingInvalid) {
|
||||
return (
|
||||
<div
|
||||
className='system-2xs-medium-uppercase relative flex h-[18px] items-center rounded-[5px] border border-text-warning bg-components-badge-bg-dimm px-1.5 text-text-warning'
|
||||
onClick={onClick}
|
||||
>
|
||||
<RiScales3Line className='mr-0.5 h-3 w-3' />
|
||||
{t('common.modelProvider.auth.authorizationError')}
|
||||
<Indicator color='orange' className='absolute right-[-1px] top-[-1px] h-1.5 w-1.5' />
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
return (
|
||||
<Button
|
||||
variant='secondary'
|
||||
size='small'
|
||||
className={cn(
|
||||
'hidden shrink-0 group-hover:flex',
|
||||
credentialRemoved && 'flex',
|
||||
)}
|
||||
onClick={onClick}
|
||||
>
|
||||
{
|
||||
credentialRemoved && (
|
||||
<>
|
||||
{t('common.modelProvider.auth.credentialRemoved')}
|
||||
<Indicator color='red' className='ml-2' />
|
||||
</>
|
||||
)
|
||||
}
|
||||
{
|
||||
!loadBalancingEnabled && !credentialRemoved && !loadBalancingInvalid && (
|
||||
<>
|
||||
<RiEqualizer2Line className='mr-1 h-4 w-4' />
|
||||
{t('common.operation.config')}
|
||||
</>
|
||||
)
|
||||
}
|
||||
{
|
||||
loadBalancingEnabled && !credentialRemoved && !loadBalancingInvalid && (
|
||||
<>
|
||||
<RiScales3Line className='mr-1 h-4 w-4' />
|
||||
{t('common.modelProvider.auth.configLoadBalancing')}
|
||||
</>
|
||||
)
|
||||
}
|
||||
</Button>
|
||||
)
|
||||
}
|
||||
|
||||
export default memo(ConfigModel)
|
||||
|
|
@ -0,0 +1,96 @@
|
|||
import {
|
||||
memo,
|
||||
useCallback,
|
||||
useMemo,
|
||||
} from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import {
|
||||
RiEqualizer2Line,
|
||||
} from '@remixicon/react'
|
||||
import {
|
||||
Button,
|
||||
} from '@/app/components/base/button'
|
||||
import type {
|
||||
CustomConfigurationModelFixedFields,
|
||||
ModelProvider,
|
||||
} from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||
import { ConfigurationMethodEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||
import Authorized from './authorized'
|
||||
import { useAuth, useCredentialStatus } from './hooks'
|
||||
import Tooltip from '@/app/components/base/tooltip'
|
||||
import cn from '@/utils/classnames'
|
||||
|
||||
type ConfigProviderProps = {
|
||||
provider: ModelProvider,
|
||||
configurationMethod: ConfigurationMethodEnum,
|
||||
currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields,
|
||||
}
|
||||
const ConfigProvider = ({
|
||||
provider,
|
||||
configurationMethod,
|
||||
currentCustomConfigurationModelFixedFields,
|
||||
}: ConfigProviderProps) => {
|
||||
const { t } = useTranslation()
|
||||
const {
|
||||
handleOpenModal,
|
||||
} = useAuth(provider, configurationMethod, currentCustomConfigurationModelFixedFields)
|
||||
const {
|
||||
hasCredential,
|
||||
authorized,
|
||||
current_credential_id,
|
||||
current_credential_name,
|
||||
available_credentials,
|
||||
} = useCredentialStatus(provider)
|
||||
const notAllowCustomCredential = provider.allow_custom_token === false
|
||||
const handleClick = useCallback(() => {
|
||||
if (!hasCredential && !notAllowCustomCredential)
|
||||
handleOpenModal()
|
||||
}, [handleOpenModal, hasCredential, notAllowCustomCredential])
|
||||
const ButtonComponent = useMemo(() => {
|
||||
const Item = (
|
||||
<Button
|
||||
className={cn('grow', notAllowCustomCredential && 'cursor-not-allowed opacity-50')}
|
||||
size='small'
|
||||
onClick={handleClick}
|
||||
variant={!authorized ? 'secondary-accent' : 'secondary'}
|
||||
>
|
||||
<RiEqualizer2Line className='mr-1 h-3.5 w-3.5' />
|
||||
{t('common.operation.setup')}
|
||||
</Button>
|
||||
)
|
||||
if (notAllowCustomCredential) {
|
||||
return (
|
||||
<Tooltip
|
||||
asChild
|
||||
popupContent={t('plugin.auth.credentialUnavailable')}
|
||||
>
|
||||
{Item}
|
||||
</Tooltip>
|
||||
)
|
||||
}
|
||||
return Item
|
||||
}, [handleClick, authorized, notAllowCustomCredential, t])
|
||||
|
||||
if (!hasCredential)
|
||||
return ButtonComponent
|
||||
|
||||
return (
|
||||
<Authorized
|
||||
provider={provider}
|
||||
configurationMethod={ConfigurationMethodEnum.predefinedModel}
|
||||
items={[
|
||||
{
|
||||
title: t('common.modelProvider.auth.apiKeys'),
|
||||
credentials: available_credentials ?? [],
|
||||
},
|
||||
]}
|
||||
selectedCredential={{
|
||||
credential_id: current_credential_id ?? '',
|
||||
credential_name: current_credential_name ?? '',
|
||||
}}
|
||||
showItemSelectedIcon
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
export default memo(ConfigProvider)
|
||||
|
|
@ -0,0 +1,6 @@
|
|||
export * from './use-model-form-schemas'
|
||||
export * from './use-credential-status'
|
||||
export * from './use-custom-models'
|
||||
export * from './use-auth'
|
||||
export * from './use-auth-service'
|
||||
export * from './use-credential-data'
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue